libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

dec_external_image.cc (18775B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_external_image.h"
      7 
      8 #include <jxl/types.h>
      9 #include <string.h>
     10 
     11 #include <algorithm>
     12 #include <utility>
     13 #include <vector>
     14 
     15 #include "lib/jxl/base/status.h"
     16 
     17 #undef HWY_TARGET_INCLUDE
     18 #define HWY_TARGET_INCLUDE "lib/jxl/dec_external_image.cc"
     19 #include <hwy/foreach_target.h>
     20 #include <hwy/highway.h>
     21 
     22 #include "lib/jxl/alpha.h"
     23 #include "lib/jxl/base/byte_order.h"
     24 #include "lib/jxl/base/common.h"
     25 #include "lib/jxl/base/compiler_specific.h"
     26 #include "lib/jxl/base/printf_macros.h"
     27 #include "lib/jxl/sanitizers.h"
     28 
     29 HWY_BEFORE_NAMESPACE();
     30 namespace jxl {
     31 namespace HWY_NAMESPACE {
     32 
     33 // These templates are not found via ADL.
     34 using hwy::HWY_NAMESPACE::Clamp;
     35 using hwy::HWY_NAMESPACE::Mul;
     36 using hwy::HWY_NAMESPACE::NearestInt;
     37 
     38 // TODO(jon): check if this can be replaced by a FloatToU16 function
     39 void FloatToU32(const float* in, uint32_t* out, size_t num, float mul,
     40                 size_t bits_per_sample) {
     41   const HWY_FULL(float) d;
     42   const hwy::HWY_NAMESPACE::Rebind<uint32_t, decltype(d)> du;
     43 
     44   // Unpoison accessing partially-uninitialized vectors with memory sanitizer.
     45   // This is because we run NearestInt() on the vector, which triggers MSAN even
     46   // it is safe to do so since the values are not mixed between lanes.
     47   const size_t num_round_up = RoundUpTo(num, Lanes(d));
     48   msan::UnpoisonMemory(in + num, sizeof(in[0]) * (num_round_up - num));
     49 
     50   const auto one = Set(d, 1.0f);
     51   const auto scale = Set(d, mul);
     52   for (size_t x = 0; x < num; x += Lanes(d)) {
     53     auto v = Load(d, in + x);
     54     // Clamp turns NaN to 'min'.
     55     v = Clamp(v, Zero(d), one);
     56     auto i = NearestInt(Mul(v, scale));
     57     Store(BitCast(du, i), du, out + x);
     58   }
     59 
     60   // Poison back the output.
     61   msan::PoisonMemory(out + num, sizeof(out[0]) * (num_round_up - num));
     62 }
     63 
     64 void FloatToF16(const float* in, hwy::float16_t* out, size_t num) {
     65   const HWY_FULL(float) d;
     66   const hwy::HWY_NAMESPACE::Rebind<hwy::float16_t, decltype(d)> du;
     67 
     68   // Unpoison accessing partially-uninitialized vectors with memory sanitizer.
     69   // This is because we run DemoteTo() on the vector which triggers msan.
     70   const size_t num_round_up = RoundUpTo(num, Lanes(d));
     71   msan::UnpoisonMemory(in + num, sizeof(in[0]) * (num_round_up - num));
     72 
     73   for (size_t x = 0; x < num; x += Lanes(d)) {
     74     auto v = Load(d, in + x);
     75     auto v16 = DemoteTo(du, v);
     76     Store(v16, du, out + x);
     77   }
     78 
     79   // Poison back the output.
     80   msan::PoisonMemory(out + num, sizeof(out[0]) * (num_round_up - num));
     81 }
     82 
     83 // NOLINTNEXTLINE(google-readability-namespace-comments)
     84 }  // namespace HWY_NAMESPACE
     85 }  // namespace jxl
     86 HWY_AFTER_NAMESPACE();
     87 
     88 #if HWY_ONCE
     89 
     90 namespace jxl {
     91 namespace {
     92 
     93 // Stores a float in big endian
     94 void StoreBEFloat(float value, uint8_t* p) {
     95   uint32_t u;
     96   memcpy(&u, &value, 4);
     97   StoreBE32(u, p);
     98 }
     99 
    100 // Stores a float in little endian
    101 void StoreLEFloat(float value, uint8_t* p) {
    102   uint32_t u;
    103   memcpy(&u, &value, 4);
    104   StoreLE32(u, p);
    105 }
    106 
    107 // The orientation may not be identity.
    108 // TODO(lode): SIMDify where possible
    109 template <typename T>
    110 Status UndoOrientation(jxl::Orientation undo_orientation, const Plane<T>& image,
    111                        Plane<T>& out, jxl::ThreadPool* pool) {
    112   const size_t xsize = image.xsize();
    113   const size_t ysize = image.ysize();
    114 
    115   if (undo_orientation == Orientation::kFlipHorizontal) {
    116     JXL_ASSIGN_OR_RETURN(out, Plane<T>::Create(xsize, ysize));
    117     JXL_RETURN_IF_ERROR(RunOnPool(
    118         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit,
    119         [&](const uint32_t task, size_t /*thread*/) {
    120           const int64_t y = task;
    121           const T* JXL_RESTRICT row_in = image.Row(y);
    122           T* JXL_RESTRICT row_out = out.Row(y);
    123           for (size_t x = 0; x < xsize; ++x) {
    124             row_out[xsize - x - 1] = row_in[x];
    125           }
    126         },
    127         "UndoOrientation"));
    128   } else if (undo_orientation == Orientation::kRotate180) {
    129     JXL_ASSIGN_OR_RETURN(out, Plane<T>::Create(xsize, ysize));
    130     JXL_RETURN_IF_ERROR(RunOnPool(
    131         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit,
    132         [&](const uint32_t task, size_t /*thread*/) {
    133           const int64_t y = task;
    134           const T* JXL_RESTRICT row_in = image.Row(y);
    135           T* JXL_RESTRICT row_out = out.Row(ysize - y - 1);
    136           for (size_t x = 0; x < xsize; ++x) {
    137             row_out[xsize - x - 1] = row_in[x];
    138           }
    139         },
    140         "UndoOrientation"));
    141   } else if (undo_orientation == Orientation::kFlipVertical) {
    142     JXL_ASSIGN_OR_RETURN(out, Plane<T>::Create(xsize, ysize));
    143     JXL_RETURN_IF_ERROR(RunOnPool(
    144         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit,
    145         [&](const uint32_t task, size_t /*thread*/) {
    146           const int64_t y = task;
    147           const T* JXL_RESTRICT row_in = image.Row(y);
    148           T* JXL_RESTRICT row_out = out.Row(ysize - y - 1);
    149           for (size_t x = 0; x < xsize; ++x) {
    150             row_out[x] = row_in[x];
    151           }
    152         },
    153         "UndoOrientation"));
    154   } else if (undo_orientation == Orientation::kTranspose) {
    155     JXL_ASSIGN_OR_RETURN(out, Plane<T>::Create(ysize, xsize));
    156     JXL_RETURN_IF_ERROR(RunOnPool(
    157         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit,
    158         [&](const uint32_t task, size_t /*thread*/) {
    159           const int64_t y = task;
    160           const T* JXL_RESTRICT row_in = image.Row(y);
    161           for (size_t x = 0; x < xsize; ++x) {
    162             out.Row(x)[y] = row_in[x];
    163           }
    164         },
    165         "UndoOrientation"));
    166   } else if (undo_orientation == Orientation::kRotate90) {
    167     JXL_ASSIGN_OR_RETURN(out, Plane<T>::Create(ysize, xsize));
    168     JXL_RETURN_IF_ERROR(RunOnPool(
    169         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit,
    170         [&](const uint32_t task, size_t /*thread*/) {
    171           const int64_t y = task;
    172           const T* JXL_RESTRICT row_in = image.Row(y);
    173           for (size_t x = 0; x < xsize; ++x) {
    174             out.Row(x)[ysize - y - 1] = row_in[x];
    175           }
    176         },
    177         "UndoOrientation"));
    178   } else if (undo_orientation == Orientation::kAntiTranspose) {
    179     JXL_ASSIGN_OR_RETURN(out, Plane<T>::Create(ysize, xsize));
    180     JXL_RETURN_IF_ERROR(RunOnPool(
    181         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit,
    182         [&](const uint32_t task, size_t /*thread*/) {
    183           const int64_t y = task;
    184           const T* JXL_RESTRICT row_in = image.Row(y);
    185           for (size_t x = 0; x < xsize; ++x) {
    186             out.Row(xsize - x - 1)[ysize - y - 1] = row_in[x];
    187           }
    188         },
    189         "UndoOrientation"));
    190   } else if (undo_orientation == Orientation::kRotate270) {
    191     JXL_ASSIGN_OR_RETURN(out, Plane<T>::Create(ysize, xsize));
    192     JXL_RETURN_IF_ERROR(RunOnPool(
    193         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit,
    194         [&](const uint32_t task, size_t /*thread*/) {
    195           const int64_t y = task;
    196           const T* JXL_RESTRICT row_in = image.Row(y);
    197           for (size_t x = 0; x < xsize; ++x) {
    198             out.Row(xsize - x - 1)[y] = row_in[x];
    199           }
    200         },
    201         "UndoOrientation"));
    202   }
    203   return true;
    204 }
    205 }  // namespace
    206 
    207 HWY_EXPORT(FloatToU32);
    208 HWY_EXPORT(FloatToF16);
    209 
    210 namespace {
    211 
    212 using StoreFuncType = void(uint32_t value, uint8_t* dest);
    213 template <StoreFuncType StoreFunc>
    214 void StoreUintRow(uint32_t* JXL_RESTRICT* rows_u32, size_t num_channels,
    215                   size_t xsize, size_t bytes_per_sample,
    216                   uint8_t* JXL_RESTRICT out) {
    217   for (size_t x = 0; x < xsize; ++x) {
    218     for (size_t c = 0; c < num_channels; c++) {
    219       StoreFunc(rows_u32[c][x],
    220                 out + (num_channels * x + c) * bytes_per_sample);
    221     }
    222   }
    223 }
    224 
    225 template <void(StoreFunc)(float, uint8_t*)>
    226 void StoreFloatRow(const float* JXL_RESTRICT* rows_in, size_t num_channels,
    227                    size_t xsize, uint8_t* JXL_RESTRICT out) {
    228   for (size_t x = 0; x < xsize; ++x) {
    229     for (size_t c = 0; c < num_channels; c++) {
    230       StoreFunc(rows_in[c][x], out + (num_channels * x + c) * sizeof(float));
    231     }
    232   }
    233 }
    234 
    235 void JXL_INLINE Store8(uint32_t value, uint8_t* dest) { *dest = value & 0xff; }
    236 
    237 }  // namespace
    238 
    239 Status ConvertChannelsToExternal(const ImageF* in_channels[],
    240                                  size_t num_channels, size_t bits_per_sample,
    241                                  bool float_out, JxlEndianness endianness,
    242                                  size_t stride, jxl::ThreadPool* pool,
    243                                  void* out_image, size_t out_size,
    244                                  const PixelCallback& out_callback,
    245                                  jxl::Orientation undo_orientation) {
    246   JXL_DASSERT(num_channels != 0 && num_channels <= kConvertMaxChannels);
    247   JXL_DASSERT(in_channels[0] != nullptr);
    248   JXL_CHECK(float_out ? bits_per_sample == 16 || bits_per_sample == 32
    249                       : bits_per_sample > 0 && bits_per_sample <= 16);
    250   const bool has_out_image = (out_image != nullptr);
    251   if (has_out_image == out_callback.IsPresent()) {
    252     return JXL_FAILURE(
    253         "Must provide either an out_image or an out_callback, but not both.");
    254   }
    255   std::vector<const ImageF*> channels;
    256   channels.assign(in_channels, in_channels + num_channels);
    257 
    258   const size_t bytes_per_channel = DivCeil(bits_per_sample, jxl::kBitsPerByte);
    259   const size_t bytes_per_pixel = num_channels * bytes_per_channel;
    260 
    261   std::vector<std::vector<uint8_t>> row_out_callback;
    262   const auto FreeCallbackOpaque = [&out_callback](void* p) {
    263     out_callback.destroy(p);
    264   };
    265   std::unique_ptr<void, decltype(FreeCallbackOpaque)> out_run_opaque(
    266       nullptr, FreeCallbackOpaque);
    267   auto InitOutCallback = [&](size_t num_threads) -> Status {
    268     if (out_callback.IsPresent()) {
    269       out_run_opaque.reset(out_callback.Init(num_threads, stride));
    270       JXL_RETURN_IF_ERROR(out_run_opaque != nullptr);
    271       row_out_callback.resize(num_threads);
    272       for (size_t i = 0; i < num_threads; ++i) {
    273         row_out_callback[i].resize(stride);
    274       }
    275     }
    276     return true;
    277   };
    278 
    279   // Channels used to store the transformed original channels if needed.
    280   ImageF temp_channels[kConvertMaxChannels];
    281   if (undo_orientation != Orientation::kIdentity) {
    282     for (size_t c = 0; c < num_channels; ++c) {
    283       if (channels[c]) {
    284         JXL_RETURN_IF_ERROR(UndoOrientation(undo_orientation, *channels[c],
    285                                             temp_channels[c], pool));
    286         channels[c] = &(temp_channels[c]);
    287       }
    288     }
    289   }
    290 
    291   // First channel may not be nullptr.
    292   size_t xsize = channels[0]->xsize();
    293   size_t ysize = channels[0]->ysize();
    294   if (stride < bytes_per_pixel * xsize) {
    295     return JXL_FAILURE("stride is smaller than scanline width in bytes: %" PRIuS
    296                        " vs %" PRIuS,
    297                        stride, bytes_per_pixel * xsize);
    298   }
    299   if (!out_callback.IsPresent() &&
    300       out_size < (ysize - 1) * stride + bytes_per_pixel * xsize) {
    301     return JXL_FAILURE("out_size is too small to store image");
    302   }
    303 
    304   const bool little_endian =
    305       endianness == JXL_LITTLE_ENDIAN ||
    306       (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian());
    307 
    308   // Handle the case where a channel is nullptr by creating a single row with
    309   // ones to use instead.
    310   ImageF ones;
    311   for (size_t c = 0; c < num_channels; ++c) {
    312     if (!channels[c]) {
    313       JXL_ASSIGN_OR_RETURN(ones, ImageF::Create(xsize, 1));
    314       FillImage(1.0f, &ones);
    315       break;
    316     }
    317   }
    318 
    319   if (float_out) {
    320     if (bits_per_sample == 16) {
    321       bool swap_endianness = little_endian != IsLittleEndian();
    322       Plane<hwy::float16_t> f16_cache;
    323       JXL_RETURN_IF_ERROR(RunOnPool(
    324           pool, 0, static_cast<uint32_t>(ysize),
    325           [&](size_t num_threads) {
    326             StatusOr<Plane<hwy::float16_t>> f16_cache_or =
    327                 Plane<hwy::float16_t>::Create(xsize,
    328                                               num_channels * num_threads);
    329             if (!f16_cache_or.ok()) return false;
    330             f16_cache = std::move(f16_cache_or).value();
    331             return !!InitOutCallback(num_threads);
    332           },
    333           [&](const uint32_t task, const size_t thread) {
    334             const int64_t y = task;
    335             const float* JXL_RESTRICT row_in[kConvertMaxChannels];
    336             for (size_t c = 0; c < num_channels; c++) {
    337               row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0);
    338             }
    339             hwy::float16_t* JXL_RESTRICT row_f16[kConvertMaxChannels];
    340             for (size_t c = 0; c < num_channels; c++) {
    341               row_f16[c] = f16_cache.Row(c + thread * num_channels);
    342               HWY_DYNAMIC_DISPATCH(FloatToF16)
    343               (row_in[c], row_f16[c], xsize);
    344             }
    345             uint8_t* row_out =
    346                 out_callback.IsPresent()
    347                     ? row_out_callback[thread].data()
    348                     : &(reinterpret_cast<uint8_t*>(out_image))[stride * y];
    349             // interleave the one scanline
    350             hwy::float16_t* row_f16_out =
    351                 reinterpret_cast<hwy::float16_t*>(row_out);
    352             for (size_t x = 0; x < xsize; x++) {
    353               for (size_t c = 0; c < num_channels; c++) {
    354                 row_f16_out[x * num_channels + c] = row_f16[c][x];
    355               }
    356             }
    357             if (swap_endianness) {
    358               size_t size = xsize * num_channels * 2;
    359               for (size_t i = 0; i < size; i += 2) {
    360                 std::swap(row_out[i + 0], row_out[i + 1]);
    361               }
    362             }
    363             if (out_callback.IsPresent()) {
    364               out_callback.run(out_run_opaque.get(), thread, 0, y, xsize,
    365                                row_out);
    366             }
    367           },
    368           "ConvertF16"));
    369     } else if (bits_per_sample == 32) {
    370       JXL_RETURN_IF_ERROR(RunOnPool(
    371           pool, 0, static_cast<uint32_t>(ysize),
    372           [&](size_t num_threads) { return InitOutCallback(num_threads); },
    373           [&](const uint32_t task, const size_t thread) {
    374             const int64_t y = task;
    375             uint8_t* row_out =
    376                 out_callback.IsPresent()
    377                     ? row_out_callback[thread].data()
    378                     : &(reinterpret_cast<uint8_t*>(out_image))[stride * y];
    379             const float* JXL_RESTRICT row_in[kConvertMaxChannels];
    380             for (size_t c = 0; c < num_channels; c++) {
    381               row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0);
    382             }
    383             if (little_endian) {
    384               StoreFloatRow<StoreLEFloat>(row_in, num_channels, xsize, row_out);
    385             } else {
    386               StoreFloatRow<StoreBEFloat>(row_in, num_channels, xsize, row_out);
    387             }
    388             if (out_callback.IsPresent()) {
    389               out_callback.run(out_run_opaque.get(), thread, 0, y, xsize,
    390                                row_out);
    391             }
    392           },
    393           "ConvertFloat"));
    394     } else {
    395       return JXL_FAILURE("float other than 16-bit and 32-bit not supported");
    396     }
    397   } else {
    398     // Multiplier to convert from floating point 0-1 range to the integer
    399     // range.
    400     float mul = (1ull << bits_per_sample) - 1;
    401     Plane<uint32_t> u32_cache;
    402     JXL_RETURN_IF_ERROR(RunOnPool(
    403         pool, 0, static_cast<uint32_t>(ysize),
    404         [&](size_t num_threads) {
    405           StatusOr<Plane<uint32_t>> u32_cache_or =
    406               Plane<uint32_t>::Create(xsize, num_channels * num_threads);
    407           if (!u32_cache_or.ok()) return false;
    408           u32_cache = std::move(u32_cache_or).value();
    409           return !!InitOutCallback(num_threads);
    410         },
    411         [&](const uint32_t task, const size_t thread) {
    412           const int64_t y = task;
    413           uint8_t* row_out =
    414               out_callback.IsPresent()
    415                   ? row_out_callback[thread].data()
    416                   : &(reinterpret_cast<uint8_t*>(out_image))[stride * y];
    417           const float* JXL_RESTRICT row_in[kConvertMaxChannels];
    418           for (size_t c = 0; c < num_channels; c++) {
    419             row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0);
    420           }
    421           uint32_t* JXL_RESTRICT row_u32[kConvertMaxChannels];
    422           for (size_t c = 0; c < num_channels; c++) {
    423             row_u32[c] = u32_cache.Row(c + thread * num_channels);
    424             // row_u32[] is a per-thread temporary row storage, this isn't
    425             // intended to be initialized on a previous run.
    426             msan::PoisonMemory(row_u32[c], xsize * sizeof(row_u32[c][0]));
    427             HWY_DYNAMIC_DISPATCH(FloatToU32)
    428             (row_in[c], row_u32[c], xsize, mul, bits_per_sample);
    429           }
    430           if (bits_per_sample <= 8) {
    431             StoreUintRow<Store8>(row_u32, num_channels, xsize, 1, row_out);
    432           } else {
    433             if (little_endian) {
    434               StoreUintRow<StoreLE16>(row_u32, num_channels, xsize, 2, row_out);
    435             } else {
    436               StoreUintRow<StoreBE16>(row_u32, num_channels, xsize, 2, row_out);
    437             }
    438           }
    439           if (out_callback.IsPresent()) {
    440             out_callback.run(out_run_opaque.get(), thread, 0, y, xsize,
    441                              row_out);
    442           }
    443         },
    444         "ConvertUint"));
    445   }
    446   return true;
    447 }
    448 
    449 Status ConvertToExternal(const jxl::ImageBundle& ib, size_t bits_per_sample,
    450                          bool float_out, size_t num_channels,
    451                          JxlEndianness endianness, size_t stride,
    452                          jxl::ThreadPool* pool, void* out_image,
    453                          size_t out_size, const PixelCallback& out_callback,
    454                          jxl::Orientation undo_orientation,
    455                          bool unpremul_alpha) {
    456   bool want_alpha = num_channels == 2 || num_channels == 4;
    457   size_t color_channels = num_channels <= 2 ? 1 : 3;
    458 
    459   const Image3F* color = &ib.color();
    460   // Undo premultiplied alpha.
    461   Image3F unpremul;
    462   if (ib.AlphaIsPremultiplied() && ib.HasAlpha() && unpremul_alpha) {
    463     JXL_ASSIGN_OR_RETURN(unpremul,
    464                          Image3F::Create(color->xsize(), color->ysize()));
    465     CopyImageTo(*color, &unpremul);
    466     for (size_t y = 0; y < unpremul.ysize(); y++) {
    467       UnpremultiplyAlpha(unpremul.PlaneRow(0, y), unpremul.PlaneRow(1, y),
    468                          unpremul.PlaneRow(2, y), ib.alpha().Row(y),
    469                          unpremul.xsize());
    470     }
    471     color = &unpremul;
    472   }
    473 
    474   const ImageF* channels[kConvertMaxChannels];
    475   size_t c = 0;
    476   for (; c < color_channels; c++) {
    477     channels[c] = &color->Plane(c);
    478   }
    479   if (want_alpha) {
    480     channels[c++] = ib.HasAlpha() ? &ib.alpha() : nullptr;
    481   }
    482   JXL_ASSERT(num_channels == c);
    483 
    484   return ConvertChannelsToExternal(
    485       channels, num_channels, bits_per_sample, float_out, endianness, stride,
    486       pool, out_image, out_size, out_callback, undo_orientation);
    487 }
    488 
    489 }  // namespace jxl
    490 #endif  // HWY_ONCE