libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

rct.cc (5032B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/modular/transform/rct.h"
      7 #undef HWY_TARGET_INCLUDE
      8 #define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/rct.cc"
      9 #include <hwy/foreach_target.h>
     10 #include <hwy/highway.h>
     11 HWY_BEFORE_NAMESPACE();
     12 namespace jxl {
     13 namespace HWY_NAMESPACE {
     14 
     15 // These templates are not found via ADL.
     16 using hwy::HWY_NAMESPACE::Add;
     17 using hwy::HWY_NAMESPACE::ShiftRight;
     18 using hwy::HWY_NAMESPACE::Sub;
     19 
     20 template <int transform_type>
     21 void InvRCTRow(const pixel_type* in0, const pixel_type* in1,
     22                const pixel_type* in2, pixel_type* out0, pixel_type* out1,
     23                pixel_type* out2, size_t w) {
     24   static_assert(transform_type >= 0 && transform_type < 7,
     25                 "Invalid transform type");
     26   int second = transform_type >> 1;
     27   int third = transform_type & 1;
     28 
     29   size_t x = 0;
     30   const HWY_FULL(pixel_type) d;
     31   const size_t N = Lanes(d);
     32   for (; x + N - 1 < w; x += N) {
     33     if (transform_type == 6) {
     34       auto Y = Load(d, in0 + x);
     35       auto Co = Load(d, in1 + x);
     36       auto Cg = Load(d, in2 + x);
     37       Y = Sub(Y, ShiftRight<1>(Cg));
     38       auto G = Add(Cg, Y);
     39       Y = Sub(Y, ShiftRight<1>(Co));
     40       auto R = Add(Y, Co);
     41       Store(R, d, out0 + x);
     42       Store(G, d, out1 + x);
     43       Store(Y, d, out2 + x);
     44     } else {
     45       auto First = Load(d, in0 + x);
     46       auto Second = Load(d, in1 + x);
     47       auto Third = Load(d, in2 + x);
     48       if (third) Third = Add(Third, First);
     49       if (second == 1) {
     50         Second = Add(Second, First);
     51       } else if (second == 2) {
     52         Second = Add(Second, ShiftRight<1>(Add(First, Third)));
     53       }
     54       Store(First, d, out0 + x);
     55       Store(Second, d, out1 + x);
     56       Store(Third, d, out2 + x);
     57     }
     58   }
     59   for (; x < w; x++) {
     60     if (transform_type == 6) {
     61       pixel_type Y = in0[x];
     62       pixel_type Co = in1[x];
     63       pixel_type Cg = in2[x];
     64       pixel_type tmp = PixelAdd(Y, -(Cg >> 1));
     65       pixel_type G = PixelAdd(Cg, tmp);
     66       pixel_type B = PixelAdd(tmp, -(Co >> 1));
     67       pixel_type R = PixelAdd(B, Co);
     68       out0[x] = R;
     69       out1[x] = G;
     70       out2[x] = B;
     71     } else {
     72       pixel_type First = in0[x];
     73       pixel_type Second = in1[x];
     74       pixel_type Third = in2[x];
     75       if (third) Third = PixelAdd(Third, First);
     76       if (second == 1) {
     77         Second = PixelAdd(Second, First);
     78       } else if (second == 2) {
     79         Second = PixelAdd(Second, (PixelAdd(First, Third) >> 1));
     80       }
     81       out0[x] = First;
     82       out1[x] = Second;
     83       out2[x] = Third;
     84     }
     85   }
     86 }
     87 
     88 Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) {
     89   JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, begin_c + 2));
     90   size_t m = begin_c;
     91   Channel& c0 = input.channel[m + 0];
     92   size_t w = c0.w;
     93   size_t h = c0.h;
     94   if (rct_type == 0) {  // noop
     95     return true;
     96   }
     97   // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR
     98   int permutation = rct_type / 7;
     99   JXL_CHECK(permutation < 6);
    100   // 0-5 values have the low bit corresponding to Third and the high bits
    101   // corresponding to Second. 6 corresponds to YCoCg.
    102   //
    103   // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird
    104   //
    105   // Third: 0=nop, 1=SubtractFirst
    106   int custom = rct_type % 7;
    107   // Special case: permute-only. Swap channels around.
    108   if (custom == 0) {
    109     Channel ch0 = std::move(input.channel[m]);
    110     Channel ch1 = std::move(input.channel[m + 1]);
    111     Channel ch2 = std::move(input.channel[m + 2]);
    112     input.channel[m + (permutation % 3)] = std::move(ch0);
    113     input.channel[m + ((permutation + 1 + permutation / 3) % 3)] =
    114         std::move(ch1);
    115     input.channel[m + ((permutation + 2 - permutation / 3) % 3)] =
    116         std::move(ch2);
    117     return true;
    118   }
    119   constexpr decltype(&InvRCTRow<0>) inv_rct_row[] = {
    120       InvRCTRow<0>, InvRCTRow<1>, InvRCTRow<2>, InvRCTRow<3>,
    121       InvRCTRow<4>, InvRCTRow<5>, InvRCTRow<6>};
    122   JXL_RETURN_IF_ERROR(RunOnPool(
    123       pool, 0, h, ThreadPool::NoInit,
    124       [&](const uint32_t task, size_t /* thread */) {
    125         const size_t y = task;
    126         const pixel_type* in0 = input.channel[m].Row(y);
    127         const pixel_type* in1 = input.channel[m + 1].Row(y);
    128         const pixel_type* in2 = input.channel[m + 2].Row(y);
    129         pixel_type* out0 = input.channel[m + (permutation % 3)].Row(y);
    130         pixel_type* out1 =
    131             input.channel[m + ((permutation + 1 + permutation / 3) % 3)].Row(y);
    132         pixel_type* out2 =
    133             input.channel[m + ((permutation + 2 - permutation / 3) % 3)].Row(y);
    134         inv_rct_row[custom](in0, in1, in2, out0, out1, out2, w);
    135       },
    136       "InvRCT"));
    137   return true;
    138 }
    139 
    140 }  // namespace HWY_NAMESPACE
    141 }  // namespace jxl
    142 HWY_AFTER_NAMESPACE();
    143 
    144 #if HWY_ONCE
    145 namespace jxl {
    146 
    147 HWY_EXPORT(InvRCT);
    148 Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) {
    149   return HWY_DYNAMIC_DISPATCH(InvRCT)(input, begin_c, rct_type, pool);
    150 }
    151 
    152 }  // namespace jxl
    153 #endif