libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

local_tone_map.cc (13573B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include <jxl/cms.h>
      7 #include <stdio.h>
      8 #include <stdlib.h>
      9 
     10 #include "lib/jxl/base/status.h"
     11 #include "tools/file_io.h"
     12 
     13 #undef HWY_TARGET_INCLUDE
     14 #define HWY_TARGET_INCLUDE "tools/hdr/local_tone_map.cc"
     15 #include <hwy/foreach_target.h>
     16 #include <hwy/highway.h>
     17 
     18 #include "lib/extras/codec.h"
     19 #include "lib/extras/packed_image_convert.h"
     20 #include "lib/extras/tone_mapping.h"
     21 #include "lib/jxl/base/fast_math-inl.h"
     22 #include "lib/jxl/convolve.h"
     23 #include "lib/jxl/image_bundle.h"
     24 #include "tools/cmdline.h"
     25 #include "tools/thread_pool_internal.h"
     26 
     27 HWY_BEFORE_NAMESPACE();
     28 namespace jxl {
     29 namespace HWY_NAMESPACE {
     30 namespace {
     31 
     32 using ::hwy::HWY_NAMESPACE::Add;
     33 using ::hwy::HWY_NAMESPACE::Div;
     34 using ::hwy::HWY_NAMESPACE::Lt;
     35 using ::hwy::HWY_NAMESPACE::Max;
     36 using ::hwy::HWY_NAMESPACE::Min;
     37 using ::hwy::HWY_NAMESPACE::Mul;
     38 using ::hwy::HWY_NAMESPACE::MulAdd;
     39 using ::hwy::HWY_NAMESPACE::Sub;
     40 
     41 constexpr size_t kDownsampling = 128;
     42 
     43 // Color components must be in linear Rec. 2020.
     44 template <typename V>
     45 V ComputeLuminance(const float intensity_target, const V r, const V g,
     46                    const V b) {
     47   hwy::HWY_NAMESPACE::DFromV<V> df;
     48   const auto luminance =
     49       Mul(Set(df, intensity_target),
     50           MulAdd(Set(df, 0.2627f), r,
     51                  MulAdd(Set(df, 0.6780f), g, Mul(Set(df, 0.0593f), b))));
     52   return Max(Set(df, 1e-12f), luminance);
     53 }
     54 
     55 ImageF DownsampledLuminances(const Image3F& image,
     56                              const float intensity_target) {
     57   HWY_CAPPED(float, kDownsampling) d;
     58   JXL_ASSIGN_OR_DIE(ImageF result,
     59                     ImageF::Create(DivCeil(image.xsize(), kDownsampling),
     60                                    DivCeil(image.ysize(), kDownsampling)));
     61   FillImage(kDefaultIntensityTarget, &result);
     62   for (size_t y = 0; y < image.ysize(); ++y) {
     63     const float* const JXL_RESTRICT rows[3] = {image.ConstPlaneRow(0, y),
     64                                                image.ConstPlaneRow(1, y),
     65                                                image.ConstPlaneRow(2, y)};
     66     float* const JXL_RESTRICT result_row = result.Row(y / kDownsampling);
     67 
     68     for (size_t x = 0; x < image.xsize(); x += kDownsampling) {
     69       auto max = Set(d, result_row[x / kDownsampling]);
     70       for (size_t kx = 0; kx < kDownsampling && x + kx < image.xsize();
     71            kx += Lanes(d)) {
     72         max =
     73             Max(max, ComputeLuminance(
     74                          intensity_target, Load(d, rows[0] + x + kx),
     75                          Load(d, rows[1] + x + kx), Load(d, rows[2] + x + kx)));
     76       }
     77       result_row[x / kDownsampling] = GetLane(MaxOfLanes(d, max));
     78     }
     79   }
     80   HWY_FULL(float) df;
     81   for (size_t y = 0; y < result.ysize(); ++y) {
     82     float* const JXL_RESTRICT row = result.Row(y);
     83     for (size_t x = 0; x < result.xsize(); x += Lanes(df)) {
     84       Store(FastLog2f(df, Load(df, row + x)), df, row + x);
     85     }
     86   }
     87   return result;
     88 }
     89 
     90 ImageF Upsample(const ImageF& image, ThreadPool* pool) {
     91   JXL_ASSIGN_OR_DIE(ImageF upsampled_horizontally,
     92                     ImageF::Create(2 * image.xsize(), image.ysize()));
     93   const auto BoundX = [&image](ssize_t x) {
     94     return Clamp1<ssize_t>(x, 0, image.xsize() - 1);
     95   };
     96   JXL_CHECK(RunOnPool(
     97       pool, 0, image.ysize(), &ThreadPool::NoInit,
     98       [&](const int32_t y, const int32_t /*thread_id*/) {
     99         const float* const JXL_RESTRICT in_row = image.ConstRow(y);
    100         float* const JXL_RESTRICT out_row = upsampled_horizontally.Row(y);
    101 
    102         for (ssize_t x = 0; x < static_cast<ssize_t>(image.xsize()); ++x) {
    103           out_row[2 * x] = in_row[x];
    104           out_row[2 * x + 1] =
    105               0.5625f * (in_row[x] + in_row[BoundX(x + 1)]) -
    106               0.0625f * (in_row[BoundX(x - 1)] + in_row[BoundX(x + 2)]);
    107         }
    108       },
    109       "UpsampleHorizontally"));
    110 
    111   HWY_FULL(float) df;
    112   JXL_ASSIGN_OR_DIE(ImageF upsampled,
    113                     ImageF::Create(2 * image.xsize(), 2 * image.ysize()));
    114   const auto BoundY = [&image](ssize_t y) {
    115     return Clamp1<ssize_t>(y, 0, image.ysize() - 1);
    116   };
    117   JXL_CHECK(RunOnPool(
    118       pool, 0, image.ysize(), &ThreadPool::NoInit,
    119       [&](const int32_t y, const int32_t /*thread_id*/) {
    120         const float* const JXL_RESTRICT in_rows[4] = {
    121             upsampled_horizontally.ConstRow(BoundY(y - 1)),
    122             upsampled_horizontally.ConstRow(y),
    123             upsampled_horizontally.ConstRow(BoundY(y + 1)),
    124             upsampled_horizontally.ConstRow(BoundY(y + 2)),
    125         };
    126         float* const JXL_RESTRICT out_rows[2] = {
    127             upsampled.Row(2 * y),
    128             upsampled.Row(2 * y + 1),
    129         };
    130 
    131         for (ssize_t x = 0;
    132              x < static_cast<ssize_t>(upsampled_horizontally.xsize());
    133              x += Lanes(df)) {
    134           Store(Load(df, in_rows[1] + x), df, out_rows[0] + x);
    135           Store(MulAdd(Set(df, 0.5625f),
    136                        Add(Load(df, in_rows[1] + x), Load(df, in_rows[2] + x)),
    137                        Mul(Set(df, -0.0625f), Add(Load(df, in_rows[0] + x),
    138                                                   Load(df, in_rows[3] + x)))),
    139                 df, out_rows[1] + x);
    140         }
    141       },
    142       "UpsampleVertically"));
    143   return upsampled;
    144 }
    145 
    146 float ComputeOffset(const ImageF& original_luminances,
    147                     const ImageF& upsampled_blurred_luminances) {
    148   HWY_CAPPED(float, kDownsampling) df;
    149   float max_difference = 0.f;
    150   for (size_t y = 0; y < original_luminances.ysize(); ++y) {
    151     const float* const JXL_RESTRICT original_row =
    152         original_luminances.ConstRow(y);
    153     for (size_t x = 0; x < original_luminances.xsize(); ++x) {
    154       auto block_min = Set(df, std::numeric_limits<float>::infinity());
    155       for (size_t ky = 0; ky < kDownsampling; ++ky) {
    156         const float* const JXL_RESTRICT blurred_row =
    157             upsampled_blurred_luminances.ConstRow(kDownsampling * y + ky);
    158         for (size_t kx = 0; kx < kDownsampling; kx += Lanes(df)) {
    159           block_min =
    160               Min(block_min, Load(df, blurred_row + kDownsampling * x + kx));
    161         }
    162       }
    163 
    164       const float difference =
    165           original_row[x] - GetLane(MinOfLanes(df, block_min));
    166       if (difference > max_difference) max_difference = difference;
    167     }
    168   }
    169   return max_difference;
    170 }
    171 
    172 Status ApplyLocalToneMapping(const ImageF& blurred_luminances,
    173                              const float intensity_target,
    174                              const float max_difference, Image3F* color,
    175                              ThreadPool* pool) {
    176   HWY_FULL(float) df;
    177 
    178   const auto log_default_intensity_target =
    179       Set(df, FastLog2f(kDefaultIntensityTarget));
    180   const auto log_10000 = Set(df, FastLog2f(10000.f));
    181   JXL_RETURN_IF_ERROR(RunOnPool(
    182       pool, 0, color->ysize(), &ThreadPool::NoInit,
    183       [&](const int32_t y, const int32_t /*thread_id*/) {
    184         float* const JXL_RESTRICT rows[3] = {color->PlaneRow(0, y),
    185                                              color->PlaneRow(1, y),
    186                                              color->PlaneRow(2, y)};
    187         const float* const JXL_RESTRICT blurred_lum_row =
    188             blurred_luminances.ConstRow(y);
    189 
    190         for (size_t x = 0; x < color->xsize(); x += Lanes(df)) {
    191           const auto log_local_max =
    192               Add(Load(df, blurred_lum_row + x), Set(df, max_difference));
    193           const auto luminance =
    194               ComputeLuminance(intensity_target, Load(df, rows[0] + x),
    195                                Load(df, rows[1] + x), Load(df, rows[2] + x));
    196           const auto log_luminance = FastLog2f(df, luminance);
    197           const auto log_knee =
    198               Mul(log_default_intensity_target,
    199                   MulAdd(Set(df, -0.85f),
    200                          Div(Sub(log_local_max, log_default_intensity_target),
    201                              Sub(log_10000, log_default_intensity_target)),
    202                          Set(df, 1.f)));
    203           const auto second_segment_position =
    204               Div(Sub(log_luminance, log_knee), Sub(log_local_max, log_knee));
    205           const auto log_new_luminance = IfThenElse(
    206               Lt(log_luminance, log_knee), log_luminance,
    207               MulAdd(
    208                   second_segment_position,
    209                   MulAdd(Sub(log_default_intensity_target, log_knee),
    210                          second_segment_position, Sub(log_knee, log_luminance)),
    211                   log_luminance));
    212           const auto new_luminance = FastPow2f(df, log_new_luminance);
    213           const auto ratio =
    214               Div(Mul(Set(df, intensity_target), new_luminance),
    215                   Mul(luminance, Set(df, kDefaultIntensityTarget)));
    216           for (int c = 0; c < 3; ++c) {
    217             Store(Mul(ratio, Load(df, rows[c] + x)), df, rows[c] + x);
    218           }
    219         }
    220       },
    221       "ApplyLocalToneMapping"));
    222 
    223   return true;
    224 }
    225 
    226 }  // namespace
    227 }  // namespace HWY_NAMESPACE
    228 }  // namespace jxl
    229 HWY_AFTER_NAMESPACE();
    230 
    231 #if HWY_ONCE
    232 
    233 namespace jxl {
    234 namespace {
    235 
// Export the per-target versions of the SIMD kernels defined above so that
// ProcessFrame can call them through HWY_DYNAMIC_DISPATCH.
HWY_EXPORT(DownsampledLuminances);
HWY_EXPORT(Upsample);
HWY_EXPORT(ComputeOffset);
HWY_EXPORT(ApplyLocalToneMapping);
    240 
    241 void Blur(ImageF* image) {
    242   static constexpr WeightsSeparable5 kBlurFilter = {
    243       {HWY_REP4(.375f), HWY_REP4(.25f), HWY_REP4(.0625f)},
    244       {HWY_REP4(.375f), HWY_REP4(.25f), HWY_REP4(.0625f)}};
    245   JXL_ASSIGN_OR_DIE(ImageF blurred_once,
    246                     ImageF::Create(image->xsize(), image->ysize()));
    247   Separable5(*image, Rect(*image), kBlurFilter, nullptr, &blurred_once);
    248   Separable5(blurred_once, Rect(blurred_once), kBlurFilter, nullptr, image);
    249 }
    250 
    251 void ProcessFrame(CodecInOut* image, float preserve_saturation,
    252                   ThreadPool* pool) {
    253   ColorEncoding linear_rec2020;
    254   JXL_CHECK(linear_rec2020.SetWhitePointType(WhitePoint::kD65));
    255   JXL_CHECK(linear_rec2020.SetPrimariesType(Primaries::k2100));
    256   linear_rec2020.Tf().SetTransferFunction(TransferFunction::kLinear);
    257   JXL_CHECK(linear_rec2020.CreateICC());
    258   JXL_CHECK(
    259       image->Main().TransformTo(linear_rec2020, *JxlGetDefaultCms(), pool));
    260 
    261   const float intensity_target = image->metadata.m.IntensityTarget();
    262 
    263   Image3F color = std::move(*image->Main().color());
    264   ImageF subsampled_image =
    265       HWY_DYNAMIC_DISPATCH(DownsampledLuminances)(color, intensity_target);
    266   JXL_ASSIGN_OR_DIE(
    267       ImageF original_luminances,
    268       ImageF::Create(subsampled_image.xsize(), subsampled_image.ysize()));
    269   CopyImageTo(subsampled_image, &original_luminances);
    270 
    271   Blur(&subsampled_image);
    272   const auto& Upsample = HWY_DYNAMIC_DISPATCH(Upsample);
    273   ImageF blurred_luminances = std::move(subsampled_image);
    274   for (int downsampling = HWY_NAMESPACE::kDownsampling; downsampling > 1;
    275        downsampling >>= 1) {
    276     blurred_luminances =
    277         Upsample(blurred_luminances, downsampling > 4 ? nullptr : pool);
    278   }
    279 
    280   const float max_difference = HWY_DYNAMIC_DISPATCH(ComputeOffset)(
    281       original_luminances, blurred_luminances);
    282 
    283   JXL_CHECK(HWY_DYNAMIC_DISPATCH(ApplyLocalToneMapping)(
    284       blurred_luminances, intensity_target, max_difference, &color, pool));
    285 
    286   image->SetFromImage(std::move(color), linear_rec2020);
    287   image->metadata.m.color_encoding = linear_rec2020;
    288   image->metadata.m.SetIntensityTarget(kDefaultIntensityTarget);
    289 
    290   JXL_CHECK(GamutMap(image, preserve_saturation, pool));
    291 
    292   ColorEncoding rec2020_srgb = linear_rec2020;
    293   rec2020_srgb.Tf().SetTransferFunction(TransferFunction::kSRGB);
    294   JXL_CHECK(rec2020_srgb.CreateICC());
    295   JXL_CHECK(image->Main().TransformTo(rec2020_srgb, *JxlGetDefaultCms(), pool));
    296   image->metadata.m.color_encoding = rec2020_srgb;
    297 }
    298 
    299 }  // namespace
    300 }  // namespace jxl
    301 
    302 int main(int argc, const char** argv) {
    303   jpegxl::tools::ThreadPoolInternal pool(8);
    304 
    305   jpegxl::tools::CommandLineParser parser;
    306   float preserve_saturation = .4f;
    307   parser.AddOptionValue(
    308       's', "preserve_saturation", "0..1",
    309       "to what extent to try and preserve saturation over luminance",
    310       &preserve_saturation, &jpegxl::tools::ParseFloat, 0);
    311   const char* input_filename = nullptr;
    312   auto input_filename_option = parser.AddPositionalOption(
    313       "input", true, "input image", &input_filename, 0);
    314   const char* output_filename = nullptr;
    315   auto output_filename_option = parser.AddPositionalOption(
    316       "output", true, "output image", &output_filename, 0);
    317 
    318   if (!parser.Parse(argc, argv)) {
    319     fprintf(stderr, "See -h for help.\n");
    320     return EXIT_FAILURE;
    321   }
    322 
    323   if (parser.HelpFlagPassed()) {
    324     parser.PrintHelp();
    325     return EXIT_SUCCESS;
    326   }
    327 
    328   if (!parser.GetOption(input_filename_option)->matched()) {
    329     fprintf(stderr, "Missing input filename.\nSee -h for help.\n");
    330     return EXIT_FAILURE;
    331   }
    332   if (!parser.GetOption(output_filename_option)->matched()) {
    333     fprintf(stderr, "Missing output filename.\nSee -h for help.\n");
    334     return EXIT_FAILURE;
    335   }
    336 
    337   jxl::CodecInOut image;
    338   jxl::extras::ColorHints color_hints;
    339   color_hints.Add("color_space", "RGB_D65_202_Rel_PeQ");
    340   std::vector<uint8_t> encoded;
    341   JXL_CHECK(jpegxl::tools::ReadFile(input_filename, &encoded));
    342   JXL_CHECK(jxl::SetFromBytes(jxl::Bytes(encoded), color_hints, &image, &pool));
    343 
    344   jxl::ProcessFrame(&image, preserve_saturation, &pool);
    345 
    346   JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0};
    347   jxl::extras::PackedPixelFile ppf =
    348       jxl::extras::ConvertImage3FToPackedPixelFile(
    349           *image.Main().color(), image.metadata.m.color_encoding, format,
    350           &pool);
    351   JXL_CHECK(jxl::Encode(ppf, output_filename, &encoded, &pool));
    352   JXL_CHECK(jpegxl::tools::WriteFile(output_filename, encoded));
    353 }
    354 
    355 #endif