libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

squeeze.cc (18212B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/modular/transform/squeeze.h"
      7 
      8 #include <stdlib.h>
      9 
     10 #include "lib/jxl/base/common.h"
     11 #include "lib/jxl/base/data_parallel.h"
     12 #include "lib/jxl/base/printf_macros.h"
     13 #include "lib/jxl/modular/modular_image.h"
     14 #include "lib/jxl/modular/transform/transform.h"
     15 #undef HWY_TARGET_INCLUDE
     16 #define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/squeeze.cc"
     17 #include <hwy/foreach_target.h>
     18 #include <hwy/highway.h>
     19 
     20 #include "lib/jxl/simd_util-inl.h"
     21 
     22 HWY_BEFORE_NAMESPACE();
     23 namespace jxl {
     24 namespace HWY_NAMESPACE {
     25 
     26 // These templates are not found via ADL.
     27 using hwy::HWY_NAMESPACE::Abs;
     28 using hwy::HWY_NAMESPACE::Add;
     29 using hwy::HWY_NAMESPACE::And;
     30 using hwy::HWY_NAMESPACE::Gt;
     31 using hwy::HWY_NAMESPACE::IfThenElse;
     32 using hwy::HWY_NAMESPACE::IfThenZeroElse;
     33 using hwy::HWY_NAMESPACE::Lt;
     34 using hwy::HWY_NAMESPACE::MulEven;
     35 using hwy::HWY_NAMESPACE::Ne;
     36 using hwy::HWY_NAMESPACE::Neg;
     37 using hwy::HWY_NAMESPACE::OddEven;
     38 using hwy::HWY_NAMESPACE::RebindToUnsigned;
     39 using hwy::HWY_NAMESPACE::ShiftLeft;
     40 using hwy::HWY_NAMESPACE::ShiftRight;
     41 using hwy::HWY_NAMESPACE::Sub;
     42 using hwy::HWY_NAMESPACE::Xor;
     43 
     44 #if HWY_TARGET != HWY_SCALAR
     45 
        // Vectorized inverse-squeeze step for 8 consecutive positions; it
        // mirrors the scalar loops in InvVSqueeze and in InvHSqueeze's
        // unsqueeze_row. Per lane:
        //   tendency = SmoothTendency(top, avg, next_avg)  (branchless form)
        //   diff     = residual + tendency
        //   *p_out   = avg + diff / 2  (division truncating toward zero)
        //   *p_nout  = *p_out - diff
        // p_residual holds the stored "diff minus tendency" values, p_avg and
        // p_navg the current and next averages, and p_pout the previously
        // reconstructed neighbour ("top"). p_pout and p_nout are not
        // JXL_RESTRICT; note that callers pass p_avg itself as p_pout
        // (InvVSqueeze, first row) and rows of one scratch buffer as both
        // p_pout and p_nout (InvHSqueeze).
     46 JXL_INLINE void FastUnsqueeze(const pixel_type *JXL_RESTRICT p_residual,
     47                               const pixel_type *JXL_RESTRICT p_avg,
     48                               const pixel_type *JXL_RESTRICT p_navg,
     49                               const pixel_type *p_pout,
     50                               pixel_type *JXL_RESTRICT p_out,
     51                               pixel_type *p_nout) {
          // At most 8 lanes; the loop covers the 8 positions in Lanes(d) steps.
     52   const HWY_CAPPED(pixel_type, 8) d;
     53   const RebindToUnsigned<decltype(d)> du;
     54   const size_t N = Lanes(d);
          // 0x55555556 == (2^32 + 2) / 3: the high 32 bits of x * onethird
          // equal x / 3 for 0 <= x < 2^31.
     55   auto onethird = Set(d, 0x55555556);
     56   for (size_t x = 0; x < 8; x += N) {
     57     auto avg = Load(d, p_avg + x);
     58     auto next_avg = Load(d, p_navg + x);
     59     auto top = Load(d, p_pout + x);
     60     // Equivalent to SmoothTendency(top,avg,next_avg), but without branches
            // Ba = top - avg and an = avg - next_avg; the sign bit of nonmono
            // is set exactly when Ba and an have opposite signs.
     61     auto Ba = Sub(top, avg);
     62     auto an = Sub(avg, next_avg);
     63     auto nonmono = Xor(Ba, an);
     64     auto absBa = Abs(Ba);
     65     auto absan = Abs(an);
     66     auto absBn = Abs(Sub(top, next_avg));
     67     // Compute a3 = absBa / 3
            // (multiply-high by onethird). MulEven only multiplies the even
            // lanes, so the odd lanes are moved to even positions with Reverse,
            // multiplied, and reversed back at 64-bit granularity; OddEven then
            // takes the odd lanes from a3o and the even lanes from a3e.
     68     auto a3e = BitCast(d, ShiftRight<32>(MulEven(absBa, onethird)));
     69     auto a3oi = MulEven(Reverse(d, absBa), onethird);
     70     auto a3o = BitCast(
     71         d, Reverse(hwy::HWY_NAMESPACE::Repartition<pixel_type_w, decltype(d)>(),
     72                    a3oi));
     73     auto a3 = OddEven(a3o, a3e);
            // absdiff = (absBa / 3 + |top - next_avg| + 2) >> 2 is the
            // unclamped magnitude of the tendency.
     74     a3 = Add(a3, Add(absBn, Set(d, 2)));
     75     auto absdiff = ShiftRight<2>(a3);
            // The tendency is zero where Ba and an are both non-zero and of
            // opposite sign (top, avg, next_avg are not monotonic).
     76     auto skipdiff = Ne(Ba, Zero(d));
     77     skipdiff = And(skipdiff, Ne(an, Zero(d)));
     78     skipdiff = And(skipdiff, Lt(nonmono, Zero(d)));
            // Clamp: if absdiff > 2*absBa + (absdiff & 1), use 2*absBa + 1;
            // then if absdiff + (absdiff & 1) > 2*absan, use 2*absan.
     79     auto absBa2 = Add(ShiftLeft<1>(absBa), And(absdiff, Set(d, 1)));
     80     absdiff = IfThenElse(Gt(absdiff, absBa2),
     81                          Add(ShiftLeft<1>(absBa), Set(d, 1)), absdiff);
     82     auto absan2 = ShiftLeft<1>(absan);
     83     absdiff = IfThenElse(Gt(Add(absdiff, And(absdiff, Set(d, 1))), absan2),
     84                          absan2, absdiff);
            // The tendency is negative when top < next_avg.
     85     auto diff1 = IfThenElse(Lt(top, next_avg), Neg(absdiff), absdiff);
     86     auto tendency = IfThenZeroElse(skipdiff, diff1);
     87
            // Undo the residual coding: diff = residual + tendency.
     88     auto diff_minus_tendency = Load(d, p_residual + x);
     89     auto diff = Add(diff_minus_tendency, tendency);
            // out = avg + diff / 2 with C-style truncation: adding diff's sign
            // bit (logical shift right by 31) before the arithmetic shift by 1
            // rounds negative values toward zero. The second output is
            // out - diff.
     90     auto out =
     91         Add(avg, ShiftRight<1>(
     92                      Add(diff, BitCast(d, ShiftRight<31>(BitCast(du, diff))))));
     93     Store(out, d, p_out + x);
     94     Store(Sub(out, diff), d, p_nout + x);
     95   }
     96 }
     97 
     98 #endif
     99 
        // Undoes a horizontal squeeze of channel `c` using residual channel
        // `rc`. On success input.channel[c] is replaced by a channel of width
        // chin.w + chin_residual.w whose hshift is one lower. The residual
        // channel itself is left in input.channel (InvSqueeze erases it
        // afterwards). Output column 2x is avg + diff / 2 and column 2x + 1 is
        // that value minus diff, where
        // diff = residual + SmoothTendency(left, avg, next_avg).
        // Returns an error if creating the output channel or running the
        // thread pool fails.
    100 Status InvHSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) {
    101   JXL_ASSERT(c < input.channel.size());
    102   JXL_ASSERT(rc < input.channel.size());
    103   Channel &chin = input.channel[c];
    104   const Channel &chin_residual = input.channel[rc];
    105   // These must be valid since we ran MetaApply already.
    106   JXL_ASSERT(chin.w == DivCeil(chin.w + chin_residual.w, 2));
    107   JXL_ASSERT(chin.h == chin_residual.h);
    108
    109   if (chin_residual.w == 0) {
    110     // Short-circuit: output channel has same dimensions as input.
    111     input.channel[c].hshift--;
    112     return true;
    113   }
    114
    115   // Note: chin.w >= chin_residual.w and at most 1 different.
    116   JXL_ASSIGN_OR_RETURN(Channel chout,
    117                        Channel::Create(chin.w + chin_residual.w, chin.h,
    118                                        chin.hshift - 1, chin.vshift));
    119   JXL_DEBUG_V(4,
    120               "Undoing horizontal squeeze of channel %i using residuals in "
    121               "channel %i (going from width %" PRIuS " to %" PRIuS ")",
    122               c, rc, chin.w, chout.w);
    123
    124   if (chin_residual.h == 0) {
    125     // Short-circuit: channel with no pixels.
    126     input.channel[c] = std::move(chout);
    127     return true;
    128   }
          // Scalar reconstruction of row y, starting at residual column x0.
          // Output columns before 2 * x0 must already be written, because
          // `left` reads the previous odd output column.
    129   auto unsqueeze_row = [&](size_t y, size_t x0) {
    130     const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y);
    131     const pixel_type *JXL_RESTRICT p_avg = chin.Row(y);
    132     pixel_type *JXL_RESTRICT p_out = chout.Row(y);
    133     for (size_t x = x0; x < chin_residual.w; x++) {
    134       pixel_type_w diff_minus_tendency = p_residual[x];
    135       pixel_type_w avg = p_avg[x];
    136       pixel_type_w next_avg = (x + 1 < chin.w ? p_avg[x + 1] : avg);
    137       pixel_type_w left = (x ? p_out[(x << 1) - 1] : avg);
    138       pixel_type_w tendency = SmoothTendency(left, avg, next_avg);
    139       pixel_type_w diff = diff_minus_tendency + tendency;
    140       pixel_type_w A = avg + (diff / 2);
    141       p_out[(x << 1)] = A;
    142       pixel_type_w B = A - diff;
    143       p_out[(x << 1) + 1] = B;
    144     }
            // Odd output width: the last average has no residual partner and
            // is copied through unchanged.
    145     if (chout.w & 1) p_out[chout.w - 1] = p_avg[chin.w - 1];
    146   };
    147
    148   // somewhat complicated trickery just to be able to SIMD this.
    149   // Horizontal unsqueeze has horizontal data dependencies, so we do
    150   // 8 rows at a time and treat it as a vertical unsqueeze of a
    151   // transposed 8x8 block (or 9x8 for one input).
    152   static constexpr const size_t kRowsPerThread = 8;
    153   const auto unsqueeze_span = [&](const uint32_t task, size_t /* thread */) {
    154     const size_t y0 = task * kRowsPerThread;
    155     const size_t rows = std::min(kRowsPerThread, chin.h - y0);
    156     size_t x = 0;
    157
    158 #if HWY_TARGET != HWY_SCALAR
    159     intptr_t onerow_in = chin.plane.PixelsPerRow();
    160     intptr_t onerow_inr = chin_residual.plane.PixelsPerRow();
    161     intptr_t onerow_out = chout.plane.PixelsPerRow();
    162     const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y0);
    163     const pixel_type *JXL_RESTRICT p_avg = chin.Row(y0);
    164     pixel_type *JXL_RESTRICT p_out = chout.Row(y0);
            // Transposed scratch blocks. b_p_avg has a 9th row because
            // FastUnsqueeze for column i also reads the average of column i + 1.
    165     HWY_ALIGN pixel_type b_p_avg[9 * kRowsPerThread];
    166     HWY_ALIGN pixel_type b_p_residual[8 * kRowsPerThread];
    167     HWY_ALIGN pixel_type b_p_out_even[8 * kRowsPerThread];
    168     HWY_ALIGN pixel_type b_p_out_odd[8 * kRowsPerThread];
    169     HWY_ALIGN pixel_type b_p_out_evenT[8 * kRowsPerThread];
    170     HWY_ALIGN pixel_type b_p_out_oddT[8 * kRowsPerThread];
    171     const HWY_CAPPED(pixel_type, 8) d;
    172     const size_t N = Lanes(d);
            // The SIMD path is used only for full 8-row spans of channels with
            // more than 16 residual columns. The bound x < chin_residual.w - 9
            // keeps column x + 8 (read into the 9th b_p_avg row) inside chin.
    173     if (chin_residual.w > 16 && rows == kRowsPerThread) {
    174       for (; x < chin_residual.w - 9; x += 8) {
    175         Transpose8x8Block(p_residual + x, b_p_residual, onerow_inr);
    176         Transpose8x8Block(p_avg + x, b_p_avg, onerow_in);
    177         for (size_t y = 0; y < kRowsPerThread; y++) {
    178           b_p_avg[8 * 8 + y] = p_avg[x + 8 + onerow_in * y];
    179         }
    180         for (size_t i = 0; i < 8; i++) {
                  // Left neighbour: avg itself for image column 0 (as in
                  // unsqueeze_row); otherwise the previous odd output column.
                  // For i == 0 that is row 7 of b_p_out_odd, which still holds
                  // the last column of the previous 8-column block.
    181           FastUnsqueeze(
    182               b_p_residual + 8 * i, b_p_avg + 8 * i, b_p_avg + 8 * (i + 1),
    183               (x + i ? b_p_out_odd + 8 * ((x + i - 1) & 7) : b_p_avg + 8 * i),
    184               b_p_out_even + 8 * i, b_p_out_odd + 8 * i);
    185         }
    186
                // Transpose back and interleave even/odd columns into chout.
    187         Transpose8x8Block(b_p_out_even, b_p_out_evenT, 8);
    188         Transpose8x8Block(b_p_out_odd, b_p_out_oddT, 8);
    189         for (size_t y = 0; y < kRowsPerThread; y++) {
    190           for (size_t i = 0; i < kRowsPerThread; i += N) {
    191             auto even = Load(d, b_p_out_evenT + 8 * y + i);
    192             auto odd = Load(d, b_p_out_oddT + 8 * y + i);
    193             StoreInterleaved(d, even, odd,
    194                              p_out + ((x + i) << 1) + onerow_out * y);
    195           }
    196         }
    197       }
    198     }
    199 #endif
            // Scalar path for every column not handled by the SIMD loop above.
    200     for (size_t y = 0; y < rows; y++) {
    201       unsqueeze_row(y0 + y, x);
    202     }
    203   };
          // One task per span of kRowsPerThread rows.
    204   JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.h, kRowsPerThread),
    205                                 ThreadPool::NoInit, unsqueeze_span,
    206                                 "InvHorizontalSqueeze"));
    207   input.channel[c] = std::move(chout);
    208   return true;
    209 }
    210 
    211 Status InvVSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) {
    212   JXL_ASSERT(c < input.channel.size());
    213   JXL_ASSERT(rc < input.channel.size());
    214   const Channel &chin = input.channel[c];
    215   const Channel &chin_residual = input.channel[rc];
    216   // These must be valid since we ran MetaApply already.
    217   JXL_ASSERT(chin.h == DivCeil(chin.h + chin_residual.h, 2));
    218   JXL_ASSERT(chin.w == chin_residual.w);
    219 
    220   if (chin_residual.h == 0) {
    221     // Short-circuit: output channel has same dimensions as input.
    222     input.channel[c].vshift--;
    223     return true;
    224   }
    225 
    226   // Note: chin.h >= chin_residual.h and at most 1 different.
    227   JXL_ASSIGN_OR_RETURN(Channel chout,
    228                        Channel::Create(chin.w, chin.h + chin_residual.h,
    229                                        chin.hshift, chin.vshift - 1));
    230   JXL_DEBUG_V(
    231       4,
    232       "Undoing vertical squeeze of channel %i using residuals in channel "
    233       "%i (going from height %" PRIuS " to %" PRIuS ")",
    234       c, rc, chin.h, chout.h);
    235 
    236   if (chin_residual.w == 0) {
    237     // Short-circuit: channel with no pixels.
    238     input.channel[c] = std::move(chout);
    239     return true;
    240   }
    241 
    242   static constexpr const int kColsPerThread = 64;
    243   const auto unsqueeze_slice = [&](const uint32_t task, size_t /* thread */) {
    244     const size_t x0 = task * kColsPerThread;
    245     const size_t x1 =
    246         std::min(static_cast<size_t>(task + 1) * kColsPerThread, chin.w);
    247     const size_t w = x1 - x0;
    248     // We only iterate up to std::min(chin_residual.h, chin.h) which is
    249     // always chin_residual.h.
    250     for (size_t y = 0; y < chin_residual.h; y++) {
    251       const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y) + x0;
    252       const pixel_type *JXL_RESTRICT p_avg = chin.Row(y) + x0;
    253       const pixel_type *JXL_RESTRICT p_navg =
    254           chin.Row(y + 1 < chin.h ? y + 1 : y) + x0;
    255       pixel_type *JXL_RESTRICT p_out = chout.Row(y << 1) + x0;
    256       pixel_type *JXL_RESTRICT p_nout = chout.Row((y << 1) + 1) + x0;
    257       const pixel_type *p_pout = y > 0 ? chout.Row((y << 1) - 1) + x0 : p_avg;
    258       size_t x = 0;
    259 #if HWY_TARGET != HWY_SCALAR
    260       for (; x + 7 < w; x += 8) {
    261         FastUnsqueeze(p_residual + x, p_avg + x, p_navg + x, p_pout + x,
    262                       p_out + x, p_nout + x);
    263       }
    264 #endif
    265       for (; x < w; x++) {
    266         pixel_type_w avg = p_avg[x];
    267         pixel_type_w next_avg = p_navg[x];
    268         pixel_type_w top = p_pout[x];
    269         pixel_type_w tendency = SmoothTendency(top, avg, next_avg);
    270         pixel_type_w diff_minus_tendency = p_residual[x];
    271         pixel_type_w diff = diff_minus_tendency + tendency;
    272         pixel_type_w out = avg + (diff / 2);
    273         p_out[x] = out;
    274         // If the chin_residual.h == chin.h, the output has an even number
    275         // of rows so the next line is fine. Otherwise, this loop won't
    276         // write to the last output row which is handled separately.
    277         p_nout[x] = out - diff;
    278       }
    279     }
    280   };
    281   JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.w, kColsPerThread),
    282                                 ThreadPool::NoInit, unsqueeze_slice,
    283                                 "InvVertSqueeze"));
    284 
    285   if (chout.h & 1) {
    286     size_t y = chin.h - 1;
    287     const pixel_type *p_avg = chin.Row(y);
    288     pixel_type *p_out = chout.Row(y << 1);
    289     for (size_t x = 0; x < chin.w; x++) {
    290       p_out[x] = p_avg[x];
    291     }
    292   }
    293   input.channel[c] = std::move(chout);
    294   return true;
    295 }
    296 
    297 Status InvSqueeze(Image &input, const std::vector<SqueezeParams> &parameters,
    298                   ThreadPool *pool) {
    299   for (int i = parameters.size() - 1; i >= 0; i--) {
    300     JXL_RETURN_IF_ERROR(
    301         CheckMetaSqueezeParams(parameters[i], input.channel.size()));
    302     bool horizontal = parameters[i].horizontal;
    303     bool in_place = parameters[i].in_place;
    304     uint32_t beginc = parameters[i].begin_c;
    305     uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1;
    306     uint32_t offset;
    307     if (in_place) {
    308       offset = endc + 1;
    309     } else {
    310       offset = input.channel.size() + beginc - endc - 1;
    311     }
    312     if (beginc < input.nb_meta_channels) {
    313       // This is checked in MetaSqueeze.
    314       JXL_ASSERT(input.nb_meta_channels > parameters[i].num_c);
    315       input.nb_meta_channels -= parameters[i].num_c;
    316     }
    317 
    318     for (uint32_t c = beginc; c <= endc; c++) {
    319       uint32_t rc = offset + c - beginc;
    320       // MetaApply should imply that `rc` is within range, otherwise there's a
    321       // programming bug.
    322       JXL_ASSERT(rc < input.channel.size());
    323       if ((input.channel[c].w < input.channel[rc].w) ||
    324           (input.channel[c].h < input.channel[rc].h)) {
    325         return JXL_FAILURE("Corrupted squeeze transform");
    326       }
    327       if (horizontal) {
    328         JXL_RETURN_IF_ERROR(InvHSqueeze(input, c, rc, pool));
    329       } else {
    330         JXL_RETURN_IF_ERROR(InvVSqueeze(input, c, rc, pool));
    331       }
    332     }
    333     input.channel.erase(input.channel.begin() + offset,
    334                         input.channel.begin() + offset + (endc - beginc + 1));
    335   }
    336   return true;
    337 }
    338 
    339 }  // namespace HWY_NAMESPACE
    340 }  // namespace jxl
    341 HWY_AFTER_NAMESPACE();
    342 
    343 #if HWY_ONCE
    344 
    345 namespace jxl {
    346 
        // Registers the HWY_NAMESPACE::InvSqueeze variants. This file is
        // re-included once per SIMD target via foreach_target.h (see
        // HWY_TARGET_INCLUDE above), so that the variant can be selected at
        // runtime.
    347 HWY_EXPORT(InvSqueeze);
        // Public entry point: undoes the squeeze steps in `parameters` on
        // `input` by forwarding to the InvSqueeze variant chosen at runtime by
        // HWY_DYNAMIC_DISPATCH.
    348 Status InvSqueeze(Image &input, const std::vector<SqueezeParams> &parameters,
    349                   ThreadPool *pool) {
    350   return HWY_DYNAMIC_DISPATCH(InvSqueeze)(input, parameters, pool);
    351 }
    352 
    353 void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters,
    354                               const Image &image) {
    355   int nb_channels = image.channel.size() - image.nb_meta_channels;
    356 
    357   parameters->clear();
    358   size_t w = image.channel[image.nb_meta_channels].w;
    359   size_t h = image.channel[image.nb_meta_channels].h;
    360   JXL_DEBUG_V(
    361       7, "Default squeeze parameters for %" PRIuS "x%" PRIuS " image: ", w, h);
    362 
    363   // do horizontal first on wide images; vertical first on tall images
    364   bool wide = (w > h);
    365 
    366   if (nb_channels > 2 && image.channel[image.nb_meta_channels + 1].w == w &&
    367       image.channel[image.nb_meta_channels + 1].h == h) {
    368     // assume channels 1 and 2 are chroma, and can be squeezed first for 4:2:0
    369     // previews
    370     JXL_DEBUG_V(7, "(4:2:0 chroma), %" PRIuS "x%" PRIuS " image", w, h);
    371     SqueezeParams params;
    372     // horizontal chroma squeeze
    373     params.horizontal = true;
    374     params.in_place = false;
    375     params.begin_c = image.nb_meta_channels + 1;
    376     params.num_c = 2;
    377     parameters->push_back(params);
    378     params.horizontal = false;
    379     // vertical chroma squeeze
    380     parameters->push_back(params);
    381   }
    382   SqueezeParams params;
    383   params.begin_c = image.nb_meta_channels;
    384   params.num_c = nb_channels;
    385   params.in_place = true;
    386 
    387   if (!wide) {
    388     if (h > JXL_MAX_FIRST_PREVIEW_SIZE) {
    389       params.horizontal = false;
    390       parameters->push_back(params);
    391       h = (h + 1) / 2;
    392       JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h);
    393     }
    394   }
    395   while (w > JXL_MAX_FIRST_PREVIEW_SIZE || h > JXL_MAX_FIRST_PREVIEW_SIZE) {
    396     if (w > JXL_MAX_FIRST_PREVIEW_SIZE) {
    397       params.horizontal = true;
    398       parameters->push_back(params);
    399       w = (w + 1) / 2;
    400       JXL_DEBUG_V(7, "Horizontal (%" PRIuS "x%" PRIuS "), ", w, h);
    401     }
    402     if (h > JXL_MAX_FIRST_PREVIEW_SIZE) {
    403       params.horizontal = false;
    404       parameters->push_back(params);
    405       h = (h + 1) / 2;
    406       JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h);
    407     }
    408   }
    409   JXL_DEBUG_V(7, "that's it");
    410 }
    411 
    412 Status CheckMetaSqueezeParams(const SqueezeParams &parameter,
    413                               int num_channels) {
    414   int c1 = parameter.begin_c;
    415   int c2 = parameter.begin_c + parameter.num_c - 1;
    416   if (c1 < 0 || c1 >= num_channels || c2 < 0 || c2 >= num_channels || c2 < c1) {
    417     return JXL_FAILURE("Invalid channel range");
    418   }
    419   return true;
    420 }
    421 
    422 Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters) {
    423   if (parameters->empty()) {
    424     DefaultSqueezeParameters(parameters, image);
    425   }
    426 
    427   for (size_t i = 0; i < parameters->size(); i++) {
    428     JXL_RETURN_IF_ERROR(
    429         CheckMetaSqueezeParams((*parameters)[i], image.channel.size()));
    430     bool horizontal = (*parameters)[i].horizontal;
    431     bool in_place = (*parameters)[i].in_place;
    432     uint32_t beginc = (*parameters)[i].begin_c;
    433     uint32_t endc = (*parameters)[i].begin_c + (*parameters)[i].num_c - 1;
    434 
    435     uint32_t offset;
    436     if (beginc < image.nb_meta_channels) {
    437       if (endc >= image.nb_meta_channels) {
    438         return JXL_FAILURE("Invalid squeeze: mix of meta and nonmeta channels");
    439       }
    440       if (!in_place) {
    441         return JXL_FAILURE(
    442             "Invalid squeeze: meta channels require in-place residuals");
    443       }
    444       image.nb_meta_channels += (*parameters)[i].num_c;
    445     }
    446     if (in_place) {
    447       offset = endc + 1;
    448     } else {
    449       offset = image.channel.size();
    450     }
    451     for (uint32_t c = beginc; c <= endc; c++) {
    452       if (image.channel[c].hshift > 30 || image.channel[c].vshift > 30) {
    453         return JXL_FAILURE("Too many squeezes: shift > 30");
    454       }
    455       size_t w = image.channel[c].w;
    456       size_t h = image.channel[c].h;
    457       if (w == 0 || h == 0) return JXL_FAILURE("Squeezing empty channel");
    458       if (horizontal) {
    459         image.channel[c].w = (w + 1) / 2;
    460         if (image.channel[c].hshift >= 0) image.channel[c].hshift++;
    461         w = w - (w + 1) / 2;
    462       } else {
    463         image.channel[c].h = (h + 1) / 2;
    464         if (image.channel[c].vshift >= 0) image.channel[c].vshift++;
    465         h = h - (h + 1) / 2;
    466       }
    467       JXL_RETURN_IF_ERROR(image.channel[c].shrink());
    468       JXL_ASSIGN_OR_RETURN(Channel placeholder, Channel::Create(w, h));
    469       placeholder.hshift = image.channel[c].hshift;
    470       placeholder.vshift = image.channel[c].vshift;
    471 
    472       image.channel.insert(image.channel.begin() + offset + (c - beginc),
    473                            std::move(placeholder));
    474       JXL_DEBUG_V(8, "MetaSqueeze applied, current image: %s",
    475                   image.DebugString().c_str());
    476     }
    477   }
    478   return true;
    479 }
    480 
    481 }  // namespace jxl
    482 
    483 #endif