local_tone_map.cc (13573B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include <jxl/cms.h> 7 #include <stdio.h> 8 #include <stdlib.h> 9 10 #include "lib/jxl/base/status.h" 11 #include "tools/file_io.h" 12 13 #undef HWY_TARGET_INCLUDE 14 #define HWY_TARGET_INCLUDE "tools/hdr/local_tone_map.cc" 15 #include <hwy/foreach_target.h> 16 #include <hwy/highway.h> 17 18 #include "lib/extras/codec.h" 19 #include "lib/extras/packed_image_convert.h" 20 #include "lib/extras/tone_mapping.h" 21 #include "lib/jxl/base/fast_math-inl.h" 22 #include "lib/jxl/convolve.h" 23 #include "lib/jxl/image_bundle.h" 24 #include "tools/cmdline.h" 25 #include "tools/thread_pool_internal.h" 26 27 HWY_BEFORE_NAMESPACE(); 28 namespace jxl { 29 namespace HWY_NAMESPACE { 30 namespace { 31 32 using ::hwy::HWY_NAMESPACE::Add; 33 using ::hwy::HWY_NAMESPACE::Div; 34 using ::hwy::HWY_NAMESPACE::Lt; 35 using ::hwy::HWY_NAMESPACE::Max; 36 using ::hwy::HWY_NAMESPACE::Min; 37 using ::hwy::HWY_NAMESPACE::Mul; 38 using ::hwy::HWY_NAMESPACE::MulAdd; 39 using ::hwy::HWY_NAMESPACE::Sub; 40 41 constexpr size_t kDownsampling = 128; 42 43 // Color components must be in linear Rec. 2020. 44 template <typename V> 45 V ComputeLuminance(const float intensity_target, const V r, const V g, 46 const V b) { 47 hwy::HWY_NAMESPACE::DFromV<V> df; 48 const auto luminance = 49 Mul(Set(df, intensity_target), 50 MulAdd(Set(df, 0.2627f), r, 51 MulAdd(Set(df, 0.6780f), g, Mul(Set(df, 0.0593f), b)))); 52 return Max(Set(df, 1e-12f), luminance); 53 } 54 55 ImageF DownsampledLuminances(const Image3F& image, 56 const float intensity_target) { 57 HWY_CAPPED(float, kDownsampling) d; 58 JXL_ASSIGN_OR_DIE(ImageF result, 59 ImageF::Create(DivCeil(image.xsize(), kDownsampling), 60 DivCeil(image.ysize(), kDownsampling))); 61 FillImage(kDefaultIntensityTarget, &result); 62 for (size_t y = 0; y < image.ysize(); ++y) { 63 const float* const JXL_RESTRICT rows[3] = {image.ConstPlaneRow(0, y), 64 image.ConstPlaneRow(1, y), 65 image.ConstPlaneRow(2, y)}; 66 float* const JXL_RESTRICT result_row = result.Row(y / kDownsampling); 67 68 for (size_t x = 0; x < image.xsize(); x += kDownsampling) { 69 auto max = Set(d, result_row[x / kDownsampling]); 70 for (size_t kx = 0; kx < kDownsampling && x + kx < image.xsize(); 71 kx += Lanes(d)) { 72 max = 73 Max(max, ComputeLuminance( 74 intensity_target, Load(d, rows[0] + x + kx), 75 Load(d, rows[1] + x + kx), Load(d, rows[2] + x + kx))); 76 } 77 result_row[x / kDownsampling] = GetLane(MaxOfLanes(d, max)); 78 } 79 } 80 HWY_FULL(float) df; 81 for (size_t y = 0; y < result.ysize(); ++y) { 82 float* const JXL_RESTRICT row = result.Row(y); 83 for (size_t x = 0; x < result.xsize(); x += Lanes(df)) { 84 Store(FastLog2f(df, Load(df, row + x)), df, row + x); 85 } 86 } 87 return result; 88 } 89 90 ImageF Upsample(const ImageF& image, ThreadPool* pool) { 91 JXL_ASSIGN_OR_DIE(ImageF upsampled_horizontally, 92 ImageF::Create(2 * image.xsize(), image.ysize())); 93 const auto BoundX = [&image](ssize_t x) { 94 return Clamp1<ssize_t>(x, 0, image.xsize() - 1); 95 }; 96 JXL_CHECK(RunOnPool( 97 pool, 0, image.ysize(), &ThreadPool::NoInit, 98 [&](const int32_t y, const int32_t /*thread_id*/) { 99 const float* const JXL_RESTRICT in_row = image.ConstRow(y); 100 float* const JXL_RESTRICT out_row = upsampled_horizontally.Row(y); 101 102 for (ssize_t x = 0; x < static_cast<ssize_t>(image.xsize()); ++x) { 103 out_row[2 * x] = in_row[x]; 104 out_row[2 * x + 1] = 105 0.5625f * (in_row[x] + in_row[BoundX(x + 1)]) - 106 0.0625f * (in_row[BoundX(x - 1)] + in_row[BoundX(x + 2)]); 107 } 108 }, 109 "UpsampleHorizontally")); 110 111 HWY_FULL(float) df; 112 JXL_ASSIGN_OR_DIE(ImageF upsampled, 113 ImageF::Create(2 * image.xsize(), 2 * image.ysize())); 114 const auto BoundY = [&image](ssize_t y) { 115 return Clamp1<ssize_t>(y, 0, image.ysize() - 1); 116 }; 117 JXL_CHECK(RunOnPool( 118 pool, 0, image.ysize(), &ThreadPool::NoInit, 119 [&](const int32_t y, const int32_t /*thread_id*/) { 120 const float* const JXL_RESTRICT in_rows[4] = { 121 upsampled_horizontally.ConstRow(BoundY(y - 1)), 122 upsampled_horizontally.ConstRow(y), 123 upsampled_horizontally.ConstRow(BoundY(y + 1)), 124 upsampled_horizontally.ConstRow(BoundY(y + 2)), 125 }; 126 float* const JXL_RESTRICT out_rows[2] = { 127 upsampled.Row(2 * y), 128 upsampled.Row(2 * y + 1), 129 }; 130 131 for (ssize_t x = 0; 132 x < static_cast<ssize_t>(upsampled_horizontally.xsize()); 133 x += Lanes(df)) { 134 Store(Load(df, in_rows[1] + x), df, out_rows[0] + x); 135 Store(MulAdd(Set(df, 0.5625f), 136 Add(Load(df, in_rows[1] + x), Load(df, in_rows[2] + x)), 137 Mul(Set(df, -0.0625f), Add(Load(df, in_rows[0] + x), 138 Load(df, in_rows[3] + x)))), 139 df, out_rows[1] + x); 140 } 141 }, 142 "UpsampleVertically")); 143 return upsampled; 144 } 145 146 float ComputeOffset(const ImageF& original_luminances, 147 const ImageF& upsampled_blurred_luminances) { 148 HWY_CAPPED(float, kDownsampling) df; 149 float max_difference = 0.f; 150 for (size_t y = 0; y < original_luminances.ysize(); ++y) { 151 const float* const JXL_RESTRICT original_row = 152 original_luminances.ConstRow(y); 153 for (size_t x = 0; x < original_luminances.xsize(); ++x) { 154 auto block_min = Set(df, std::numeric_limits<float>::infinity()); 155 for (size_t ky = 0; ky < kDownsampling; ++ky) { 156 const float* const JXL_RESTRICT blurred_row = 157 upsampled_blurred_luminances.ConstRow(kDownsampling * y + ky); 158 for (size_t kx = 0; kx < kDownsampling; kx += Lanes(df)) { 159 block_min = 160 Min(block_min, Load(df, blurred_row + kDownsampling * x + kx)); 161 } 162 } 163 164 const float difference = 165 original_row[x] - GetLane(MinOfLanes(df, block_min)); 166 if (difference > max_difference) max_difference = difference; 167 } 168 } 169 return max_difference; 170 } 171 172 Status ApplyLocalToneMapping(const ImageF& blurred_luminances, 173 const float intensity_target, 174 const float max_difference, Image3F* color, 175 ThreadPool* pool) { 176 HWY_FULL(float) df; 177 178 const auto log_default_intensity_target = 179 Set(df, FastLog2f(kDefaultIntensityTarget)); 180 const auto log_10000 = Set(df, FastLog2f(10000.f)); 181 JXL_RETURN_IF_ERROR(RunOnPool( 182 pool, 0, color->ysize(), &ThreadPool::NoInit, 183 [&](const int32_t y, const int32_t /*thread_id*/) { 184 float* const JXL_RESTRICT rows[3] = {color->PlaneRow(0, y), 185 color->PlaneRow(1, y), 186 color->PlaneRow(2, y)}; 187 const float* const JXL_RESTRICT blurred_lum_row = 188 blurred_luminances.ConstRow(y); 189 190 for (size_t x = 0; x < color->xsize(); x += Lanes(df)) { 191 const auto log_local_max = 192 Add(Load(df, blurred_lum_row + x), Set(df, max_difference)); 193 const auto luminance = 194 ComputeLuminance(intensity_target, Load(df, rows[0] + x), 195 Load(df, rows[1] + x), Load(df, rows[2] + x)); 196 const auto log_luminance = FastLog2f(df, luminance); 197 const auto log_knee = 198 Mul(log_default_intensity_target, 199 MulAdd(Set(df, -0.85f), 200 Div(Sub(log_local_max, log_default_intensity_target), 201 Sub(log_10000, log_default_intensity_target)), 202 Set(df, 1.f))); 203 const auto second_segment_position = 204 Div(Sub(log_luminance, log_knee), Sub(log_local_max, log_knee)); 205 const auto log_new_luminance = IfThenElse( 206 Lt(log_luminance, log_knee), log_luminance, 207 MulAdd( 208 second_segment_position, 209 MulAdd(Sub(log_default_intensity_target, log_knee), 210 second_segment_position, Sub(log_knee, log_luminance)), 211 log_luminance)); 212 const auto new_luminance = FastPow2f(df, log_new_luminance); 213 const auto ratio = 214 Div(Mul(Set(df, intensity_target), new_luminance), 215 Mul(luminance, Set(df, kDefaultIntensityTarget))); 216 for (int c = 0; c < 3; ++c) { 217 Store(Mul(ratio, Load(df, rows[c] + x)), df, rows[c] + x); 218 } 219 } 220 }, 221 "ApplyLocalToneMapping")); 222 223 return true; 224 } 225 226 } // namespace 227 } // namespace HWY_NAMESPACE 228 } // namespace jxl 229 HWY_AFTER_NAMESPACE(); 230 231 #if HWY_ONCE 232 233 namespace jxl { 234 namespace { 235 236 HWY_EXPORT(DownsampledLuminances); 237 HWY_EXPORT(Upsample); 238 HWY_EXPORT(ComputeOffset); 239 HWY_EXPORT(ApplyLocalToneMapping); 240 241 void Blur(ImageF* image) { 242 static constexpr WeightsSeparable5 kBlurFilter = { 243 {HWY_REP4(.375f), HWY_REP4(.25f), HWY_REP4(.0625f)}, 244 {HWY_REP4(.375f), HWY_REP4(.25f), HWY_REP4(.0625f)}}; 245 JXL_ASSIGN_OR_DIE(ImageF blurred_once, 246 ImageF::Create(image->xsize(), image->ysize())); 247 Separable5(*image, Rect(*image), kBlurFilter, nullptr, &blurred_once); 248 Separable5(blurred_once, Rect(blurred_once), kBlurFilter, nullptr, image); 249 } 250 251 void ProcessFrame(CodecInOut* image, float preserve_saturation, 252 ThreadPool* pool) { 253 ColorEncoding linear_rec2020; 254 JXL_CHECK(linear_rec2020.SetWhitePointType(WhitePoint::kD65)); 255 JXL_CHECK(linear_rec2020.SetPrimariesType(Primaries::k2100)); 256 linear_rec2020.Tf().SetTransferFunction(TransferFunction::kLinear); 257 JXL_CHECK(linear_rec2020.CreateICC()); 258 JXL_CHECK( 259 image->Main().TransformTo(linear_rec2020, *JxlGetDefaultCms(), pool)); 260 261 const float intensity_target = image->metadata.m.IntensityTarget(); 262 263 Image3F color = std::move(*image->Main().color()); 264 ImageF subsampled_image = 265 HWY_DYNAMIC_DISPATCH(DownsampledLuminances)(color, intensity_target); 266 JXL_ASSIGN_OR_DIE( 267 ImageF original_luminances, 268 ImageF::Create(subsampled_image.xsize(), subsampled_image.ysize())); 269 CopyImageTo(subsampled_image, &original_luminances); 270 271 Blur(&subsampled_image); 272 const auto& Upsample = HWY_DYNAMIC_DISPATCH(Upsample); 273 ImageF blurred_luminances = std::move(subsampled_image); 274 for (int downsampling = HWY_NAMESPACE::kDownsampling; downsampling > 1; 275 downsampling >>= 1) { 276 blurred_luminances = 277 Upsample(blurred_luminances, downsampling > 4 ? nullptr : pool); 278 } 279 280 const float max_difference = HWY_DYNAMIC_DISPATCH(ComputeOffset)( 281 original_luminances, blurred_luminances); 282 283 JXL_CHECK(HWY_DYNAMIC_DISPATCH(ApplyLocalToneMapping)( 284 blurred_luminances, intensity_target, max_difference, &color, pool)); 285 286 image->SetFromImage(std::move(color), linear_rec2020); 287 image->metadata.m.color_encoding = linear_rec2020; 288 image->metadata.m.SetIntensityTarget(kDefaultIntensityTarget); 289 290 JXL_CHECK(GamutMap(image, preserve_saturation, pool)); 291 292 ColorEncoding rec2020_srgb = linear_rec2020; 293 rec2020_srgb.Tf().SetTransferFunction(TransferFunction::kSRGB); 294 JXL_CHECK(rec2020_srgb.CreateICC()); 295 JXL_CHECK(image->Main().TransformTo(rec2020_srgb, *JxlGetDefaultCms(), pool)); 296 image->metadata.m.color_encoding = rec2020_srgb; 297 } 298 299 } // namespace 300 } // namespace jxl 301 302 int main(int argc, const char** argv) { 303 jpegxl::tools::ThreadPoolInternal pool(8); 304 305 jpegxl::tools::CommandLineParser parser; 306 float preserve_saturation = .4f; 307 parser.AddOptionValue( 308 's', "preserve_saturation", "0..1", 309 "to what extent to try and preserve saturation over luminance", 310 &preserve_saturation, &jpegxl::tools::ParseFloat, 0); 311 const char* input_filename = nullptr; 312 auto input_filename_option = parser.AddPositionalOption( 313 "input", true, "input image", &input_filename, 0); 314 const char* output_filename = nullptr; 315 auto output_filename_option = parser.AddPositionalOption( 316 "output", true, "output image", &output_filename, 0); 317 318 if (!parser.Parse(argc, argv)) { 319 fprintf(stderr, "See -h for help.\n"); 320 return EXIT_FAILURE; 321 } 322 323 if (parser.HelpFlagPassed()) { 324 parser.PrintHelp(); 325 return EXIT_SUCCESS; 326 } 327 328 if (!parser.GetOption(input_filename_option)->matched()) { 329 fprintf(stderr, "Missing input filename.\nSee -h for help.\n"); 330 return EXIT_FAILURE; 331 } 332 if (!parser.GetOption(output_filename_option)->matched()) { 333 fprintf(stderr, "Missing output filename.\nSee -h for help.\n"); 334 return EXIT_FAILURE; 335 } 336 337 jxl::CodecInOut image; 338 jxl::extras::ColorHints color_hints; 339 color_hints.Add("color_space", "RGB_D65_202_Rel_PeQ"); 340 std::vector<uint8_t> encoded; 341 JXL_CHECK(jpegxl::tools::ReadFile(input_filename, &encoded)); 342 JXL_CHECK(jxl::SetFromBytes(jxl::Bytes(encoded), color_hints, &image, &pool)); 343 344 jxl::ProcessFrame(&image, preserve_saturation, &pool); 345 346 JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; 347 jxl::extras::PackedPixelFile ppf = 348 jxl::extras::ConvertImage3FToPackedPixelFile( 349 *image.Main().color(), image.metadata.m.color_encoding, format, 350 &pool); 351 JXL_CHECK(jxl::Encode(ppf, output_filename, &encoded, &pool)); 352 JXL_CHECK(jpegxl::tools::WriteFile(output_filename, encoded)); 353 } 354 355 #endif