libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git

enc_fast_lossless.cc (155362B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #ifndef FJXL_SELF_INCLUDE
      7 
      8 #include "lib/jxl/enc_fast_lossless.h"
      9 
     10 #include <assert.h>
     11 #include <stdint.h>
     12 #include <string.h>
     13 
     14 #include <algorithm>
     15 #include <array>
     16 #include <limits>
     17 #include <memory>
     18 #include <vector>
     19 
     20 #if !FJXL_STANDALONE
     21 #include "lib/jxl/encode_internal.h"
     22 #endif
     23 
     24 // Enable NEON and AVX2/AVX512 if not asked to do otherwise and the compilers
     25 // support it.
     26 #if defined(__aarch64__) || defined(_M_ARM64)
     27 #include <arm_neon.h>
     28 
     29 #ifndef FJXL_ENABLE_NEON
     30 #define FJXL_ENABLE_NEON 1
     31 #endif
     32 
     33 #elif (defined(__x86_64__) || defined(_M_X64)) && !defined(_MSC_VER)
     34 #include <immintrin.h>
     35 
     36 // manually add _mm512_cvtsi512_si32 definition if missing
     37 // (e.g. with Xcode on macOS Mojave)
     38 // copied from gcc 11.1.0 include/avx512fintrin.h line 14367-14373
     39 #if defined(__clang__) &&                                           \
     40     ((!defined(__apple_build_version__) && __clang_major__ < 10) || \
     41      (defined(__apple_build_version__) && __apple_build_version__ < 12000032))
     42 inline int __attribute__((__gnu_inline__, __always_inline__, __artificial__))
     43 _mm512_cvtsi512_si32(__m512i __A) {
     44   __v16si __B = (__v16si)__A;
     45   return __B[0];
     46 }
     47 #endif
     48 
     49 // TODO(veluca): MSVC support for dynamic dispatch.
     50 #if defined(__clang__) || defined(__GNUC__)
     51 
     52 #ifndef FJXL_ENABLE_AVX2
     53 #define FJXL_ENABLE_AVX2 1
     54 #endif
     55 
     56 #ifndef FJXL_ENABLE_AVX512
     57 // On clang-7 or earlier, and gcc-10 or earlier, AVX512 seems broken.
     58 #if (defined(__clang__) &&                                             \
     59          (!defined(__apple_build_version__) && __clang_major__ > 7) || \
     60      (defined(__apple_build_version__) &&                              \
     61       __apple_build_version__ > 10010046)) ||                          \
     62     (defined(__GNUC__) && __GNUC__ > 10)
     63 #define FJXL_ENABLE_AVX512 1
     64 #endif
     65 #endif
     66 
     67 #endif
     68 
     69 #endif
     70 
     71 #ifndef FJXL_ENABLE_NEON
     72 #define FJXL_ENABLE_NEON 0
     73 #endif
     74 
     75 #ifndef FJXL_ENABLE_AVX2
     76 #define FJXL_ENABLE_AVX2 0
     77 #endif
     78 
     79 #ifndef FJXL_ENABLE_AVX512
     80 #define FJXL_ENABLE_AVX512 0
     81 #endif
     82 
     83 namespace {
     84 #if defined(_MSC_VER) && !defined(__clang__)
     85 #define FJXL_INLINE __forceinline
     86 FJXL_INLINE uint32_t FloorLog2(uint32_t v) {
     87   unsigned long index;
     88   _BitScanReverse(&index, v);
     89   return index;
     90 }
     91 FJXL_INLINE uint32_t CtzNonZero(uint64_t v) {
     92   unsigned long index;
     93   _BitScanForward(&index, v);
     94   return index;
     95 }
     96 #else
     97 #define FJXL_INLINE inline __attribute__((always_inline))
     98 FJXL_INLINE uint32_t FloorLog2(uint32_t v) {
     99   return v ? 31 - __builtin_clz(v) : 0;
    100 }
    101 FJXL_INLINE uint32_t CtzNonZero(uint64_t v) { return __builtin_ctzll(v); }
    102 #endif
    103 
    104 // Compiles to a memcpy on little-endian systems.
    105 FJXL_INLINE void StoreLE64(uint8_t* tgt, uint64_t data) {
    106 #if (!defined(__BYTE_ORDER__) || (__BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__))
    107   for (int i = 0; i < 8; i++) {
    108     tgt[i] = (data >> (i * 8)) & 0xFF;
    109   }
    110 #else
    111   memcpy(tgt, &data, 8);
    112 #endif
    113 }
    114 
    115 FJXL_INLINE size_t AddBits(uint32_t count, uint64_t bits, uint8_t* data_buf,
    116                            size_t& bits_in_buffer, uint64_t& bit_buffer) {
    117   bit_buffer |= bits << bits_in_buffer;
    118   bits_in_buffer += count;
    119   StoreLE64(data_buf, bit_buffer);
    120   size_t bytes_in_buffer = bits_in_buffer / 8;
    121   bits_in_buffer -= bytes_in_buffer * 8;
    122   bit_buffer >>= bytes_in_buffer * 8;
    123   return bytes_in_buffer;
    124 }
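
        // For illustration: if bits_in_buffer is 5 and AddBits() is called with
        // count = 13, the buffer temporarily holds 18 pending bits. StoreLE64
        // flushes the buffer to data_buf, two whole bytes (16 bits) are counted
        // as written via the return value, and the remaining 2 bits stay queued
        // in bit_buffer for the next call.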
    125 
    126 struct BitWriter {
    127   void Allocate(size_t maximum_bit_size) {
    128     assert(data == nullptr);
    129     // Leave some padding.
    130     data.reset(static_cast<uint8_t*>(malloc(maximum_bit_size / 8 + 64)));
    131   }
    132 
    133   void Write(uint32_t count, uint64_t bits) {
    134     bytes_written += AddBits(count, bits, data.get() + bytes_written,
    135                              bits_in_buffer, buffer);
    136   }
    137 
    138   void ZeroPadToByte() {
    139     if (bits_in_buffer != 0) {
    140       Write(8 - bits_in_buffer, 0);
    141     }
    142   }
    143 
    144   FJXL_INLINE void WriteMultiple(const uint64_t* nbits, const uint64_t* bits,
    145                                  size_t n) {
    146     // Necessary because Write() is only guaranteed to work with <=56 bits.
    147     // Trying to SIMD-fy this code results in lower speed (and definitely less
    148     // clarity).
    149     {
    150       for (size_t i = 0; i < n; i++) {
    151         this->buffer |= bits[i] << this->bits_in_buffer;
    152         memcpy(this->data.get() + this->bytes_written, &this->buffer, 8);
    153         uint64_t shift = 64 - this->bits_in_buffer;
    154         this->bits_in_buffer += nbits[i];
    155         // This `if` seems to be faster than using ternaries.
    156         if (this->bits_in_buffer >= 64) {
    157           uint64_t next_buffer = bits[i] >> shift;
    158           this->buffer = next_buffer;
    159           this->bits_in_buffer -= 64;
    160           this->bytes_written += 8;
    161         }
    162       }
    163       memcpy(this->data.get() + this->bytes_written, &this->buffer, 8);
    164       size_t bytes_in_buffer = this->bits_in_buffer / 8;
    165       this->bits_in_buffer -= bytes_in_buffer * 8;
    166       this->buffer >>= bytes_in_buffer * 8;
    167       this->bytes_written += bytes_in_buffer;
    168     }
    169   }
    170 
    171   std::unique_ptr<uint8_t[], void (*)(void*)> data = {nullptr, free};
    172   size_t bytes_written = 0;
    173   size_t bits_in_buffer = 0;
    174   uint64_t buffer = 0;
    175 };
    176 
    177 size_t SectionSize(const std::array<BitWriter, 4>& group_data) {
    178   size_t sz = 0;
    179   for (size_t j = 0; j < 4; j++) {
    180     const auto& writer = group_data[j];
    181     sz += writer.bytes_written * 8 + writer.bits_in_buffer;
    182   }
    183   sz = (sz + 7) / 8;
    184   return sz;
    185 }
    186 
    187 constexpr size_t kMaxFrameHeaderSize = 5;
    188 
    189 constexpr size_t kGroupSizeOffset[4] = {
    190     static_cast<size_t>(0),
    191     static_cast<size_t>(1024),
    192     static_cast<size_t>(17408),
    193     static_cast<size_t>(4211712),
    194 };
    195 constexpr size_t kTOCBits[4] = {12, 16, 24, 32};
    196 
    197 size_t TOCBucket(size_t group_size) {
    198   size_t bucket = 0;
    199   while (bucket < 3 && group_size >= kGroupSizeOffset[bucket + 1]) ++bucket;
    200   return bucket;
    201 }
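
        // For illustration, the bucket boundaries implied by kGroupSizeOffset
        // and kTOCBits are:
        //   bucket 0: sizes [0, 1024)         -> 2 + 10 TOC bits
        //   bucket 1: sizes [1024, 17408)     -> 2 + 14 TOC bits
        //   bucket 2: sizes [17408, 4211712)  -> 2 + 22 TOC bits
        //   bucket 3: sizes >= 4211712        -> 2 + 30 TOC bits
        // For example, a 20000-byte group lands in bucket 2 and is later written
        // as the 2-bit selector 2 followed by the 22-bit value 20000 - 17408 =
        // 2592 (see the TOC loop in JxlFastLosslessPrepareHeader below).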
    202 
    203 size_t TOCSize(const std::vector<size_t>& group_sizes) {
    204   size_t toc_bits = 0;
    205   for (size_t i = 0; i < group_sizes.size(); i++) {
    206     toc_bits += kTOCBits[TOCBucket(group_sizes[i])];
    207   }
    208   return (toc_bits + 7) / 8;
    209 }
    210 
    211 size_t FrameHeaderSize(bool have_alpha, bool is_last) {
    212   size_t nbits = 28 + (have_alpha ? 4 : 0) + (is_last ? 0 : 2);
    213   return (nbits + 7) / 8;
    214 }
    215 
    216 void ComputeAcGroupDataOffset(size_t dc_global_size, size_t num_dc_groups,
    217                               size_t num_ac_groups, size_t& min_dc_global_size,
    218                               size_t& ac_group_offset) {
    219   // Max AC group size is 768 kB, so max AC group TOC bits is 24.
    220   size_t ac_toc_max_bits = num_ac_groups * 24;
    221   size_t ac_toc_min_bits = num_ac_groups * 12;
    222   size_t max_padding = 1 + (ac_toc_max_bits - ac_toc_min_bits + 7) / 8;
    223   min_dc_global_size = dc_global_size;
    224   size_t dc_global_bucket = TOCBucket(min_dc_global_size);
    225   while (TOCBucket(min_dc_global_size + max_padding) > dc_global_bucket) {
    226     dc_global_bucket = TOCBucket(min_dc_global_size + max_padding);
    227     min_dc_global_size = kGroupSizeOffset[dc_global_bucket];
    228   }
    229   assert(TOCBucket(min_dc_global_size) == dc_global_bucket);
    230   assert(TOCBucket(min_dc_global_size + max_padding) == dc_global_bucket);
    231   size_t max_toc_bits =
    232       kTOCBits[dc_global_bucket] + 12 * (1 + num_dc_groups) + ac_toc_max_bits;
    233   size_t max_toc_size = (max_toc_bits + 7) / 8;
    234   ac_group_offset = kMaxFrameHeaderSize + max_toc_size + min_dc_global_size;
    235 }
    236 
    237 size_t ComputeDcGlobalPadding(const std::vector<size_t>& group_sizes,
    238                               size_t ac_group_data_offset,
    239                               size_t min_dc_global_size, bool have_alpha,
    240                               bool is_last) {
    241   std::vector<size_t> new_group_sizes = group_sizes;
    242   new_group_sizes[0] = min_dc_global_size;
    243   size_t toc_size = TOCSize(new_group_sizes);
    244   size_t actual_offset =
    245       FrameHeaderSize(have_alpha, is_last) + toc_size + group_sizes[0];
    246   return ac_group_data_offset - actual_offset;
    247 }
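
        // In short: ComputeAcGroupDataOffset() reserves room for a worst-case
        // TOC so that the AC group data can be placed at a fixed offset, and
        // ComputeDcGlobalPadding() returns how many padding bytes must be
        // appended to the DC global section so that the actual layout reaches
        // exactly that precomputed offset.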
    248 
    249 constexpr size_t kNumRawSymbols = 19;
    250 constexpr size_t kNumLZ77 = 33;
    251 constexpr size_t kLZ77CacheSize = 32;
    252 
    253 constexpr size_t kLZ77Offset = 224;
    254 constexpr size_t kLZ77MinLength = 7;
    255 
    256 void EncodeHybridUintLZ77(uint32_t value, uint32_t* token, uint32_t* nbits,
    257                           uint32_t* bits) {
    258   // 400 config
    259   uint32_t n = FloorLog2(value);
    260   *token = value < 16 ? value : 16 + n - 4;
    261   *nbits = value < 16 ? 0 : n;
    262   *bits = value < 16 ? 0 : value - (1 << *nbits);
    263 }
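
        // Usage sketch; the values follow directly from the "400 config" above:
        //   uint32_t token, nbits, bits;
        //   EncodeHybridUintLZ77(100, &token, &nbits, &bits);
        //   // FloorLog2(100) == 6, so token == 16 + 6 - 4 == 18, nbits == 6 and
        //   // bits == 100 - (1 << 6) == 36. Values below 16 are stored directly
        //   // in the token, with no extra bits.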
    264 
    265 struct PrefixCode {
    266   uint8_t raw_nbits[kNumRawSymbols] = {};
    267   uint8_t raw_bits[kNumRawSymbols] = {};
    268 
    269   uint8_t lz77_nbits[kNumLZ77] = {};
    270   uint16_t lz77_bits[kNumLZ77] = {};
    271 
    272   uint64_t lz77_cache_bits[kLZ77CacheSize] = {};
    273   uint8_t lz77_cache_nbits[kLZ77CacheSize] = {};
    274 
    275   size_t numraw;
    276 
    277   static uint16_t BitReverse(size_t nbits, uint16_t bits) {
    278     constexpr uint16_t kNibbleLookup[16] = {
    279         0b0000, 0b1000, 0b0100, 0b1100, 0b0010, 0b1010, 0b0110, 0b1110,
    280         0b0001, 0b1001, 0b0101, 0b1101, 0b0011, 0b1011, 0b0111, 0b1111,
    281     };
    282     uint16_t rev16 = (kNibbleLookup[bits & 0xF] << 12) |
    283                      (kNibbleLookup[(bits >> 4) & 0xF] << 8) |
    284                      (kNibbleLookup[(bits >> 8) & 0xF] << 4) |
    285                      (kNibbleLookup[bits >> 12]);
    286     return rev16 >> (16 - nbits);
    287   }
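
          // For example, BitReverse(5, 0b10110) returns 0b01101: the nibble
          // lookup reverses the full 16-bit value to 0x6800, and the final
          // shift keeps only the top nbits = 5 bits of that reversal.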
    288 
    289   // Create the prefix codes given the code lengths.
    290   // Supports the code lengths being split into two halves.
    291   static void ComputeCanonicalCode(const uint8_t* first_chunk_nbits,
    292                                    uint8_t* first_chunk_bits,
    293                                    size_t first_chunk_size,
    294                                    const uint8_t* second_chunk_nbits,
    295                                    uint16_t* second_chunk_bits,
    296                                    size_t second_chunk_size) {
    297     constexpr size_t kMaxCodeLength = 15;
    298     uint8_t code_length_counts[kMaxCodeLength + 1] = {};
    299     for (size_t i = 0; i < first_chunk_size; i++) {
    300       code_length_counts[first_chunk_nbits[i]]++;
    301       assert(first_chunk_nbits[i] <= kMaxCodeLength);
    302       assert(first_chunk_nbits[i] <= 8);
    303       assert(first_chunk_nbits[i] > 0);
    304     }
    305     for (size_t i = 0; i < second_chunk_size; i++) {
    306       code_length_counts[second_chunk_nbits[i]]++;
    307       assert(second_chunk_nbits[i] <= kMaxCodeLength);
    308     }
    309 
    310     uint16_t next_code[kMaxCodeLength + 1] = {};
    311 
    312     uint16_t code = 0;
    313     for (size_t i = 1; i < kMaxCodeLength + 1; i++) {
    314       code = (code + code_length_counts[i - 1]) << 1;
    315       next_code[i] = code;
    316     }
    317 
    318     for (size_t i = 0; i < first_chunk_size; i++) {
    319       first_chunk_bits[i] =
    320           BitReverse(first_chunk_nbits[i], next_code[first_chunk_nbits[i]]++);
    321     }
    322     for (size_t i = 0; i < second_chunk_size; i++) {
    323       second_chunk_bits[i] =
    324           BitReverse(second_chunk_nbits[i], next_code[second_chunk_nbits[i]]++);
    325     }
    326   }
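
          // Worked example: for code lengths {1, 2, 2} the canonical codes are
          // 0b0, 0b10 and 0b11; BitReverse() stores them as 0b0, 0b01 and 0b11,
          // so the LSB-first BitWriter emits each code starting from its most
          // significant bit, as a canonical prefix-code decoder expects.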
    327 
    328   template <typename T>
    329   static void ComputeCodeLengthsNonZeroImpl(const uint64_t* freqs, size_t n,
    330                                             size_t precision, T infty,
    331                                             uint8_t* min_limit,
    332                                             uint8_t* max_limit,
    333                                             uint8_t* nbits) {
    334     assert(precision < 15);
    335     assert(n <= kMaxNumSymbols);
    336     std::vector<T> dynp(((1U << precision) + 1) * (n + 1), infty);
    337     auto d = [&](size_t sym, size_t off) -> T& {
    338       return dynp[sym * ((1 << precision) + 1) + off];
    339     };
    340     d(0, 0) = 0;
    341     for (size_t sym = 0; sym < n; sym++) {
    342       for (T bits = min_limit[sym]; bits <= max_limit[sym]; bits++) {
    343         size_t off_delta = 1U << (precision - bits);
    344         for (size_t off = 0; off + off_delta <= (1U << precision); off++) {
    345           d(sym + 1, off + off_delta) =
    346               std::min(d(sym, off) + static_cast<T>(freqs[sym]) * bits,
    347                        d(sym + 1, off + off_delta));
    348         }
    349       }
    350     }
    351 
    352     size_t sym = n;
    353     size_t off = 1U << precision;
    354 
    355     assert(d(sym, off) != infty);
    356 
    357     while (sym-- > 0) {
    358       assert(off > 0);
    359       for (size_t bits = min_limit[sym]; bits <= max_limit[sym]; bits++) {
    360         size_t off_delta = 1U << (precision - bits);
    361         if (off_delta <= off &&
    362             d(sym + 1, off) == d(sym, off - off_delta) + freqs[sym] * bits) {
    363           off -= off_delta;
    364           nbits[sym] = bits;
    365           break;
    366         }
    367       }
    368     }
    369   }
    370 
    371   // Computes nbits[i] for i < n, subject to min_limit[i] <= nbits[i] <=
    372   // max_limit[i] and sum 2**-nbits[i] == 1, so as to minimize
    373   // sum(nbits[i] * freqs[i]).
    374   static void ComputeCodeLengthsNonZero(const uint64_t* freqs, size_t n,
    375                                         uint8_t* min_limit, uint8_t* max_limit,
    376                                         uint8_t* nbits) {
    377     size_t precision = 0;
    378     size_t shortest_length = 255;
    379     uint64_t freqsum = 0;
    380     for (size_t i = 0; i < n; i++) {
    381       assert(freqs[i] != 0);
    382       freqsum += freqs[i];
    383       if (min_limit[i] < 1) min_limit[i] = 1;
    384       assert(min_limit[i] <= max_limit[i]);
    385       precision = std::max<size_t>(max_limit[i], precision);
    386       shortest_length = std::min<size_t>(min_limit[i], shortest_length);
    387     }
    388     // If all the minimum limits are greater than 1, shift precision so that we
    389     // behave as if the shortest was 1.
    390     precision -= shortest_length - 1;
    391     uint64_t infty = freqsum * precision;
    392     if (infty < std::numeric_limits<uint32_t>::max() / 2) {
    393       ComputeCodeLengthsNonZeroImpl(freqs, n, precision,
    394                                     static_cast<uint32_t>(infty), min_limit,
    395                                     max_limit, nbits);
    396     } else {
    397       ComputeCodeLengthsNonZeroImpl(freqs, n, precision, infty, min_limit,
    398                                     max_limit, nbits);
    399     }
    400   }
    401 
    402   static constexpr size_t kMaxNumSymbols =
    403       kNumRawSymbols + 1 < kNumLZ77 ? kNumLZ77 : kNumRawSymbols + 1;
    404   static void ComputeCodeLengths(const uint64_t* freqs, size_t n,
    405                                  const uint8_t* min_limit_in,
    406                                  const uint8_t* max_limit_in, uint8_t* nbits) {
    407     assert(n <= kMaxNumSymbols);
    408     uint64_t compact_freqs[kMaxNumSymbols];
    409     uint8_t min_limit[kMaxNumSymbols];
    410     uint8_t max_limit[kMaxNumSymbols];
    411     size_t ni = 0;
    412     for (size_t i = 0; i < n; i++) {
    413       if (freqs[i]) {
    414         compact_freqs[ni] = freqs[i];
    415         min_limit[ni] = min_limit_in[i];
    416         max_limit[ni] = max_limit_in[i];
    417         ni++;
    418       }
    419     }
    420     uint8_t num_bits[kMaxNumSymbols] = {};
    421     ComputeCodeLengthsNonZero(compact_freqs, ni, min_limit, max_limit,
    422                               num_bits);
    423     ni = 0;
    424     for (size_t i = 0; i < n; i++) {
    425       nbits[i] = 0;
    426       if (freqs[i]) {
    427         nbits[i] = num_bits[ni++];
    428       }
    429     }
    430   }
    431 
    432   // Invalid code, used to construct arrays.
    433   PrefixCode() = default;
    434 
    435   template <typename BitDepth>
    436   PrefixCode(BitDepth /* bitdepth */, uint64_t* raw_counts,
    437              uint64_t* lz77_counts) {
    438     // "merge" all the lz77 counts into a single symbol for the level 1
    439     // table (containing just the raw symbols, up to length 7).
    440     uint64_t level1_counts[kNumRawSymbols + 1];
    441     memcpy(level1_counts, raw_counts, kNumRawSymbols * sizeof(uint64_t));
    442     numraw = kNumRawSymbols;
    443     while (numraw > 0 && level1_counts[numraw - 1] == 0) numraw--;
    444 
    445     level1_counts[numraw] = 0;
    446     for (size_t i = 0; i < kNumLZ77; i++) {
    447       level1_counts[numraw] += lz77_counts[i];
    448     }
    449     uint8_t level1_nbits[kNumRawSymbols + 1] = {};
    450     ComputeCodeLengths(level1_counts, numraw + 1, BitDepth::kMinRawLength,
    451                        BitDepth::kMaxRawLength, level1_nbits);
    452 
    453     uint8_t level2_nbits[kNumLZ77] = {};
    454     uint8_t min_lengths[kNumLZ77] = {};
    455     uint8_t l = 15 - level1_nbits[numraw];
    456     uint8_t max_lengths[kNumLZ77];
    457     for (size_t i = 0; i < kNumLZ77; i++) {
    458       max_lengths[i] = l;
    459     }
    460     size_t num_lz77 = kNumLZ77;
    461     while (num_lz77 > 0 && lz77_counts[num_lz77 - 1] == 0) num_lz77--;
    462     ComputeCodeLengths(lz77_counts, num_lz77, min_lengths, max_lengths,
    463                        level2_nbits);
    464     for (size_t i = 0; i < numraw; i++) {
    465       raw_nbits[i] = level1_nbits[i];
    466     }
    467     for (size_t i = 0; i < num_lz77; i++) {
    468       lz77_nbits[i] =
    469           level2_nbits[i] ? level1_nbits[numraw] + level2_nbits[i] : 0;
    470     }
    471 
    472     ComputeCanonicalCode(raw_nbits, raw_bits, numraw, lz77_nbits, lz77_bits,
    473                          kNumLZ77);
    474 
    475     // Prepare lz77 cache
    476     for (size_t count = 0; count < kLZ77CacheSize; count++) {
    477       unsigned token, nbits, bits;
    478       EncodeHybridUintLZ77(count, &token, &nbits, &bits);
    479       lz77_cache_nbits[count] = lz77_nbits[token] + nbits + raw_nbits[0];
    480       lz77_cache_bits[count] =
    481           (((bits << lz77_nbits[token]) | lz77_bits[token]) << raw_nbits[0]) |
    482           raw_bits[0];
    483     }
    484   }
    485 
    486   // Max bits written: 2 + 72 + 95 + 24 + 165 = 358
    487   void WriteTo(BitWriter* writer) const {
    488     uint64_t code_length_counts[18] = {};
    489     code_length_counts[17] = 3 + 2 * (kNumLZ77 - 1);
    490     for (size_t i = 0; i < kNumRawSymbols; i++) {
    491       code_length_counts[raw_nbits[i]]++;
    492     }
    493     for (size_t i = 0; i < kNumLZ77; i++) {
    494       code_length_counts[lz77_nbits[i]]++;
    495     }
    496     uint8_t code_length_nbits[18] = {};
    497     uint8_t code_length_nbits_min[18] = {};
    498     uint8_t code_length_nbits_max[18] = {
    499         5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
    500     };
    501     ComputeCodeLengths(code_length_counts, 18, code_length_nbits_min,
    502                        code_length_nbits_max, code_length_nbits);
    503     writer->Write(2, 0b00);  // HSKIP = 0, i.e. don't skip code lengths.
    504 
    505     // As per Brotli RFC.
    506     uint8_t code_length_order[18] = {1, 2, 3, 4,  0,  5,  17, 6,  16,
    507                                      7, 8, 9, 10, 11, 12, 13, 14, 15};
    508     uint8_t code_length_length_nbits[] = {2, 4, 3, 2, 2, 4};
    509     uint8_t code_length_length_bits[] = {0, 7, 3, 2, 1, 15};
    510 
    511     // Encode lengths of code lengths.
    512     size_t num_code_lengths = 18;
    513     while (code_length_nbits[code_length_order[num_code_lengths - 1]] == 0) {
    514       num_code_lengths--;
    515     }
    516     // Max bits written in this loop: 18 * 4 = 72
    517     for (size_t i = 0; i < num_code_lengths; i++) {
    518       int symbol = code_length_nbits[code_length_order[i]];
    519       writer->Write(code_length_length_nbits[symbol],
    520                     code_length_length_bits[symbol]);
    521     }
    522 
    523     // Compute the canonical codes for the codes that represent the lengths of
    524     // the actual codes for data.
    525     uint16_t code_length_bits[18] = {};
    526     ComputeCanonicalCode(nullptr, nullptr, 0, code_length_nbits,
    527                          code_length_bits, 18);
    528     // Encode raw bit code lengths.
    529     // Max bits written in this loop: 19 * 5 = 95
    530     for (size_t i = 0; i < kNumRawSymbols; i++) {
    531       writer->Write(code_length_nbits[raw_nbits[i]],
    532                     code_length_bits[raw_nbits[i]]);
    533     }
    534     size_t num_lz77 = kNumLZ77;
    535     while (lz77_nbits[num_lz77 - 1] == 0) {
    536       num_lz77--;
    537     }
    538     // Encode 0s until 224 (start of LZ77 symbols), i.e. 224 - 19 = 205
    539     // zeros in total.
    540     static_assert(kLZ77Offset == 224, "");
    541     static_assert(kNumRawSymbols == 19, "");
    542     {
    543       // Max bits in this block: 24
    544       writer->Write(code_length_nbits[17], code_length_bits[17]);
    545       writer->Write(3, 0b010);  // 5
    546       writer->Write(code_length_nbits[17], code_length_bits[17]);
    547       writer->Write(3, 0b000);  // (5-2)*8 + 3 = 27
    548       writer->Write(code_length_nbits[17], code_length_bits[17]);
    549       writer->Write(3, 0b010);  // (27-2)*8 + 5 = 205
    550     }
    551     // Encode LZ77 symbols, with values 224+i.
    552     // Max bits written in this loop: 33 * 5 = 165
    553     for (size_t i = 0; i < num_lz77; i++) {
    554       writer->Write(code_length_nbits[lz77_nbits[i]],
    555                     code_length_bits[lz77_nbits[i]]);
    556     }
    557   }
    558 };
    559 
    560 }  // namespace
    561 
    562 extern "C" {
    563 
    564 struct JxlFastLosslessFrameState {
    565   JxlChunkedFrameInputSource input;
    566   size_t width;
    567   size_t height;
    568   size_t num_groups_x;
    569   size_t num_groups_y;
    570   size_t num_dc_groups_x;
    571   size_t num_dc_groups_y;
    572   size_t nb_chans;
    573   size_t bitdepth;
    574   int big_endian;
    575   int effort;
    576   bool collided;
    577   PrefixCode hcode[4];
    578   std::vector<int16_t> lookup;
    579   BitWriter header;
    580   std::vector<std::array<BitWriter, 4>> group_data;
    581   std::vector<size_t> group_sizes;
    582   size_t ac_group_data_offset = 0;
    583   size_t min_dc_global_size = 0;
    584   size_t current_bit_writer = 0;
    585   size_t bit_writer_byte_pos = 0;
    586   size_t bits_in_buffer = 0;
    587   uint64_t bit_buffer = 0;
    588   bool process_done = false;
    589 };
    590 
    591 size_t JxlFastLosslessOutputSize(const JxlFastLosslessFrameState* frame) {
    592   size_t total_size_groups = 0;
    593   for (size_t i = 0; i < frame->group_data.size(); i++) {
    594     total_size_groups += SectionSize(frame->group_data[i]);
    595   }
    596   return frame->header.bytes_written + total_size_groups;
    597 }
    598 
    599 size_t JxlFastLosslessMaxRequiredOutput(
    600     const JxlFastLosslessFrameState* frame) {
    601   return JxlFastLosslessOutputSize(frame) + 32;
    602 }
    603 
    604 void JxlFastLosslessPrepareHeader(JxlFastLosslessFrameState* frame,
    605                                   int add_image_header, int is_last) {
    606   BitWriter* output = &frame->header;
    607   output->Allocate(1000 + frame->group_sizes.size() * 32);
    608 
    609   bool have_alpha = (frame->nb_chans == 2 || frame->nb_chans == 4);
    610 
    611 #if FJXL_STANDALONE
    612   if (add_image_header) {
    613     // Signature
    614     output->Write(16, 0x0AFF);
    615 
    616     // Size header, hand-crafted.
    617     // Not small
    618     output->Write(1, 0);
    619 
    620     auto wsz = [output](size_t size) {
    621       if (size - 1 < (1 << 9)) {
    622         output->Write(2, 0b00);
    623         output->Write(9, size - 1);
    624       } else if (size - 1 < (1 << 13)) {
    625         output->Write(2, 0b01);
    626         output->Write(13, size - 1);
    627       } else if (size - 1 < (1 << 18)) {
    628         output->Write(2, 0b10);
    629         output->Write(18, size - 1);
    630       } else {
    631         output->Write(2, 0b11);
    632         output->Write(30, size - 1);
    633       }
    634     };
    635 
    636     wsz(frame->height);
    637 
    638     // No special ratio.
    639     output->Write(3, 0);
    640 
    641     wsz(frame->width);
    642 
    643     // Hand-crafted ImageMetadata.
    644     output->Write(1, 0);  // all_default
    645     output->Write(1, 0);  // extra_fields
    646     output->Write(1, 0);  // bit_depth.floating_point_sample
    647     if (frame->bitdepth == 8) {
    648       output->Write(2, 0b00);  // bit_depth.bits_per_sample = 8
    649     } else if (frame->bitdepth == 10) {
    650       output->Write(2, 0b01);  // bit_depth.bits_per_sample = 10
    651     } else if (frame->bitdepth == 12) {
    652       output->Write(2, 0b10);  // bit_depth.bits_per_sample = 12
    653     } else {
    654       output->Write(2, 0b11);  // 1 + u(6)
    655       output->Write(6, frame->bitdepth - 1);
    656     }
    657     if (frame->bitdepth <= 14) {
    658       output->Write(1, 1);  // 16-bit-buffer sufficient
    659     } else {
    660       output->Write(1, 0);  // 16-bit-buffer NOT sufficient
    661     }
    662     if (have_alpha) {
    663       output->Write(2, 0b01);  // One extra channel
    664       output->Write(1, 1);     // ... all_default (i.e. 8-bit alpha)
    665     } else {
    666       output->Write(2, 0b00);  // No extra channel
    667     }
    668     output->Write(1, 0);  // Not XYB
    669     if (frame->nb_chans > 2) {
    670       output->Write(1, 1);  // color_encoding.all_default (sRGB)
    671     } else {
    672       output->Write(1, 0);     // color_encoding.all_default false
    673       output->Write(1, 0);     // color_encoding.want_icc false
    674       output->Write(2, 1);     // grayscale
    675       output->Write(2, 1);     // D65
    676       output->Write(1, 0);     // no gamma transfer function
    677       output->Write(2, 0b10);  // tf: 2 + u(4)
    678       output->Write(4, 11);    // tf of sRGB
    679       output->Write(2, 1);     // relative rendering intent
    680     }
    681     output->Write(2, 0b00);  // No extensions.
    682 
    683     output->Write(1, 1);  // all_default transform data
    684 
    685     // No ICC, no preview. Frame should start at byte boundary.
    686     output->ZeroPadToByte();
    687   }
    688 #else
    689   assert(!add_image_header);
    690 #endif
    691   // Handcrafted frame header.
    692   output->Write(1, 0);     // all_default
    693   output->Write(2, 0b00);  // regular frame
    694   output->Write(1, 1);     // modular
    695   output->Write(2, 0b00);  // default flags
    696   output->Write(1, 0);     // not YCbCr
    697   output->Write(2, 0b00);  // no upsampling
    698   if (have_alpha) {
    699     output->Write(2, 0b00);  // no alpha upsampling
    700   }
    701   output->Write(2, 0b01);  // default group size
    702   output->Write(2, 0b00);  // exactly one pass
    703   output->Write(1, 0);     // no custom size or origin
    704   output->Write(2, 0b00);  // kReplace blending mode
    705   if (have_alpha) {
    706     output->Write(2, 0b00);  // kReplace blending mode for alpha channel
    707   }
    708   output->Write(1, is_last);  // is_last
    709   if (!is_last) {
    710     output->Write(2, 0b00);  // cannot be saved as reference
    711   }
    712   output->Write(2, 0b00);  // a frame has no name
    713   output->Write(1, 0);     // loop filter is not all_default
    714   output->Write(1, 0);     // no gaborish
    715   output->Write(2, 0);     // 0 EPF iters
    716   output->Write(2, 0b00);  // No LF extensions
    717   output->Write(2, 0b00);  // No FH extensions
    718 
    719   output->Write(1, 0);      // No TOC permutation
    720   output->ZeroPadToByte();  // TOC is byte-aligned.
    721   assert(add_image_header || output->bytes_written <= kMaxFrameHeaderSize);
    722   for (size_t i = 0; i < frame->group_sizes.size(); i++) {
    723     size_t sz = frame->group_sizes[i];
    724     size_t bucket = TOCBucket(sz);
    725     output->Write(2, bucket);
    726     output->Write(kTOCBits[bucket] - 2, sz - kGroupSizeOffset[bucket]);
    727   }
    728   output->ZeroPadToByte();  // Groups are byte-aligned.
    729 }
    730 
    731 #if !FJXL_STANDALONE
    732 void JxlFastLosslessOutputAlignedSection(
    733     const BitWriter& bw, JxlEncoderOutputProcessorWrapper* output_processor) {
    734   assert(bw.bits_in_buffer == 0);
    735   const uint8_t* data = bw.data.get();
    736   size_t remaining_len = bw.bytes_written;
    737   while (remaining_len > 0) {
    738     auto retval = output_processor->GetBuffer(1, remaining_len);
    739     assert(retval.status());
    740     auto buffer = std::move(retval).value();
    741     size_t n = std::min(buffer.size(), remaining_len);
    742     if (n == 0) break;
    743     memcpy(buffer.data(), data, n);
    744     buffer.advance(n);
    745     data += n;
    746     remaining_len -= n;
    747   };
    748 }
    749 
    750 void JxlFastLosslessOutputHeaders(
    751     JxlFastLosslessFrameState* frame_state,
    752     JxlEncoderOutputProcessorWrapper* output_processor) {
    753   JxlFastLosslessOutputAlignedSection(frame_state->header, output_processor);
    754   JxlFastLosslessOutputAlignedSection(frame_state->group_data[0][0],
    755                                       output_processor);
    756 }
    757 #endif
    758 
    759 #if FJXL_ENABLE_AVX512
    760 __attribute__((target("avx512vbmi2"))) static size_t AppendBytesWithBitOffset(
    761     const uint8_t* data, size_t n, size_t bit_buffer_nbits,
    762     unsigned char* output, uint64_t& bit_buffer) {
    763   if (n < 128) {
    764     return 0;
    765   }
    766 
    767   size_t i = 0;
    768   __m512i shift = _mm512_set1_epi64(64 - bit_buffer_nbits);
    769   __m512i carry = _mm512_set1_epi64(bit_buffer << (64 - bit_buffer_nbits));
    770 
    771   for (; i + 64 <= n; i += 64) {
    772     __m512i current = _mm512_loadu_si512(data + i);
    773     __m512i previous_u64 = _mm512_alignr_epi64(current, carry, 7);
    774     carry = current;
    775     __m512i out = _mm512_shrdv_epi64(previous_u64, current, shift);
    776     _mm512_storeu_si512(output + i, out);
    777   }
    778 
    779   bit_buffer = data[i - 1] >> (8 - bit_buffer_nbits);
    780 
    781   return i;
    782 }
    783 #endif
    784 
    785 size_t JxlFastLosslessWriteOutput(JxlFastLosslessFrameState* frame,
    786                                   unsigned char* output, size_t output_size) {
    787   assert(output_size >= 32);
    788   unsigned char* initial_output = output;
    789   size_t (*append_bytes_with_bit_offset)(const uint8_t*, size_t, size_t,
    790                                          unsigned char*, uint64_t&) = nullptr;
    791 
    792 #if FJXL_ENABLE_AVX512
    793   if (__builtin_cpu_supports("avx512vbmi2")) {
    794     append_bytes_with_bit_offset = AppendBytesWithBitOffset;
    795   }
    796 #endif
    797 
    798   while (true) {
    799     size_t& cur = frame->current_bit_writer;
    800     size_t& bw_pos = frame->bit_writer_byte_pos;
    801     if (cur >= 1 + frame->group_data.size() * frame->nb_chans) {
    802       return output - initial_output;
    803     }
    804     if (output_size <= 9) {
    805       return output - initial_output;
    806     }
    807     size_t nbc = frame->nb_chans;
    808     const BitWriter& writer =
    809         cur == 0 ? frame->header
    810                  : frame->group_data[(cur - 1) / nbc][(cur - 1) % nbc];
    811     size_t full_byte_count =
    812         std::min(output_size - 9, writer.bytes_written - bw_pos);
    813     if (frame->bits_in_buffer == 0) {
    814       memcpy(output, writer.data.get() + bw_pos, full_byte_count);
    815     } else {
    816       size_t i = 0;
    817       if (append_bytes_with_bit_offset) {
    818         i += append_bytes_with_bit_offset(
    819             writer.data.get() + bw_pos, full_byte_count, frame->bits_in_buffer,
    820             output, frame->bit_buffer);
    821       }
    822 #if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
    823       // Copy 8 bytes at a time until we reach the border.
    824       for (; i + 8 < full_byte_count; i += 8) {
    825         uint64_t chunk;
    826         memcpy(&chunk, writer.data.get() + bw_pos + i, 8);
    827         uint64_t out = frame->bit_buffer | (chunk << frame->bits_in_buffer);
    828         memcpy(output + i, &out, 8);
    829         frame->bit_buffer = chunk >> (64 - frame->bits_in_buffer);
    830       }
    831 #endif
    832       for (; i < full_byte_count; i++) {
    833         AddBits(8, writer.data.get()[bw_pos + i], output + i,
    834                 frame->bits_in_buffer, frame->bit_buffer);
    835       }
    836     }
    837     output += full_byte_count;
    838     output_size -= full_byte_count;
    839     bw_pos += full_byte_count;
    840     if (bw_pos == writer.bytes_written) {
    841       auto write = [&](size_t num, uint64_t bits) {
    842         size_t n = AddBits(num, bits, output, frame->bits_in_buffer,
    843                            frame->bit_buffer);
    844         output += n;
    845         output_size -= n;
    846       };
    847       if (writer.bits_in_buffer) {
    848         write(writer.bits_in_buffer, writer.buffer);
    849       }
    850       bw_pos = 0;
    851       cur++;
    852       if ((cur - 1) % nbc == 0 && frame->bits_in_buffer != 0) {
    853         write(8 - frame->bits_in_buffer, 0);
    854       }
    855     }
    856   }
    857 }
    858 
    859 void JxlFastLosslessFreeFrameState(JxlFastLosslessFrameState* frame) {
    860   delete frame;
    861 }
    862 
    863 }  // extern "C"
    864 
    865 #endif
    866 
    867 #ifdef FJXL_SELF_INCLUDE
    868 
    869 namespace {
    870 
    871 template <typename T>
    872 struct VecPair {
    873   T low;
    874   T hi;
    875 };
    876 
    877 #ifdef FJXL_GENERIC_SIMD
    878 #undef FJXL_GENERIC_SIMD
    879 #endif
    880 
    881 #ifdef FJXL_AVX512
    882 #define FJXL_GENERIC_SIMD
    883 struct SIMDVec32;
    884 struct Mask32 {
    885   __mmask16 mask;
    886   SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false);
    887   size_t CountPrefix() const {
    888     return CtzNonZero(~uint64_t{_cvtmask16_u32(mask)});
    889   }
    890 };
    891 
    892 struct SIMDVec32 {
    893   __m512i vec;
    894 
    895   static constexpr size_t kLanes = 16;
    896 
    897   FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) {
    898     return SIMDVec32{_mm512_loadu_si512((__m512i*)data)};
    899   }
    900   FJXL_INLINE void Store(uint32_t* data) {
    901     _mm512_storeu_si512((__m512i*)data, vec);
    902   }
    903   FJXL_INLINE static SIMDVec32 Val(uint32_t v) {
    904     return SIMDVec32{_mm512_set1_epi32(v)};
    905   }
    906   FJXL_INLINE SIMDVec32 ValToToken() const {
    907     return SIMDVec32{
    908         _mm512_sub_epi32(_mm512_set1_epi32(32), _mm512_lzcnt_epi32(vec))};
    909   }
    910   FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const {
    911     return SIMDVec32{_mm512_sub_epi32(_mm512_max_epu32(vec, to_subtract.vec),
    912                                       to_subtract.vec)};
    913   }
    914   FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const {
    915     return SIMDVec32{_mm512_sub_epi32(vec, to_subtract.vec)};
    916   }
    917   FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const {
    918     return SIMDVec32{_mm512_add_epi32(vec, oth.vec)};
    919   }
    920   FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const {
    921     return SIMDVec32{_mm512_xor_epi32(vec, oth.vec)};
    922   }
    923   FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const {
    924     return Mask32{_mm512_cmpeq_epi32_mask(vec, oth.vec)};
    925   }
    926   FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const {
    927     return Mask32{_mm512_cmpgt_epi32_mask(vec, oth.vec)};
    928   }
    929   FJXL_INLINE SIMDVec32 Pow2() const {
    930     return SIMDVec32{_mm512_sllv_epi32(_mm512_set1_epi32(1), vec)};
    931   }
    932   template <size_t i>
    933   FJXL_INLINE SIMDVec32 SignedShiftRight() const {
    934     return SIMDVec32{_mm512_srai_epi32(vec, i)};
    935   }
    936 };
    937 
    938 struct SIMDVec16;
    939 
    940 struct Mask16 {
    941   __mmask32 mask;
    942   SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false);
    943   Mask16 And(const Mask16& oth) const {
    944     return Mask16{_kand_mask32(mask, oth.mask)};
    945   }
    946   size_t CountPrefix() const {
    947     return CtzNonZero(~uint64_t{_cvtmask32_u32(mask)});
    948   }
    949 };
    950 
    951 struct SIMDVec16 {
    952   __m512i vec;
    953 
    954   static constexpr size_t kLanes = 32;
    955 
    956   FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) {
    957     return SIMDVec16{_mm512_loadu_si512((__m512i*)data)};
    958   }
    959   FJXL_INLINE void Store(uint16_t* data) {
    960     _mm512_storeu_si512((__m512i*)data, vec);
    961   }
    962   FJXL_INLINE static SIMDVec16 Val(uint16_t v) {
    963     return SIMDVec16{_mm512_set1_epi16(v)};
    964   }
    965   FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo,
    966                                          const SIMDVec32& hi) {
    967     auto tmp = _mm512_packus_epi32(lo.vec, hi.vec);
    968     alignas(64) uint64_t perm[8] = {0, 2, 4, 6, 1, 3, 5, 7};
    969     return SIMDVec16{
    970         _mm512_permutex2var_epi64(tmp, _mm512_load_si512((__m512i*)perm), tmp)};
    971   }
    972 
    973   FJXL_INLINE SIMDVec16 ValToToken() const {
    974     auto c16 = _mm512_set1_epi32(16);
    975     auto c32 = _mm512_set1_epi32(32);
    976     auto low16bit = _mm512_set1_epi32(0x0000FFFF);
    977     auto lzhi =
    978         _mm512_sub_epi32(c16, _mm512_min_epu32(c16, _mm512_lzcnt_epi32(vec)));
    979     auto lzlo = _mm512_sub_epi32(
    980         c32, _mm512_lzcnt_epi32(_mm512_and_si512(low16bit, vec)));
    981     return SIMDVec16{_mm512_or_si512(lzlo, _mm512_slli_epi32(lzhi, 16))};
    982   }
    983 
    984   FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const {
    985     return SIMDVec16{_mm512_subs_epu16(vec, to_subtract.vec)};
    986   }
    987   FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const {
    988     return SIMDVec16{_mm512_sub_epi16(vec, to_subtract.vec)};
    989   }
    990   FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const {
    991     return SIMDVec16{_mm512_add_epi16(vec, oth.vec)};
    992   }
    993   FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const {
    994     return SIMDVec16{_mm512_min_epu16(vec, oth.vec)};
    995   }
    996   FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const {
    997     return Mask16{_mm512_cmpeq_epi16_mask(vec, oth.vec)};
    998   }
    999   FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const {
   1000     return Mask16{_mm512_cmpgt_epi16_mask(vec, oth.vec)};
   1001   }
   1002   FJXL_INLINE SIMDVec16 Pow2() const {
   1003     return SIMDVec16{_mm512_sllv_epi16(_mm512_set1_epi16(1), vec)};
   1004   }
   1005   FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const {
   1006     return SIMDVec16{_mm512_or_si512(vec, oth.vec)};
   1007   }
   1008   FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const {
   1009     return SIMDVec16{_mm512_xor_si512(vec, oth.vec)};
   1010   }
   1011   FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const {
   1012     return SIMDVec16{_mm512_and_si512(vec, oth.vec)};
   1013   }
   1014   FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const {
   1015     return SIMDVec16{_mm512_srai_epi16(_mm512_add_epi16(vec, oth.vec), 1)};
   1016   }
   1017   FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const {
   1018     return SIMDVec16{_mm512_or_si512(vec, _mm512_set1_epi16(0xFF00))};
   1019   }
   1020   FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const {
   1021     return SIMDVec16{_mm512_shuffle_epi8(
   1022         _mm512_broadcast_i32x4(_mm_loadu_si128((__m128i*)table)), vec)};
   1023   }
   1024   FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const {
   1025     auto lo = _mm512_unpacklo_epi16(low.vec, vec);
   1026     auto hi = _mm512_unpackhi_epi16(low.vec, vec);
   1027     alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11};
   1028     alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15};
   1029     return {SIMDVec16{_mm512_permutex2var_epi64(
   1030                 lo, _mm512_load_si512((__m512i*)perm1), hi)},
   1031             SIMDVec16{_mm512_permutex2var_epi64(
   1032                 lo, _mm512_load_si512((__m512i*)perm2), hi)}};
   1033   }
   1034   FJXL_INLINE VecPair<SIMDVec32> Upcast() const {
   1035     auto lo = _mm512_unpacklo_epi16(vec, _mm512_setzero_si512());
   1036     auto hi = _mm512_unpackhi_epi16(vec, _mm512_setzero_si512());
   1037     alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11};
   1038     alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15};
   1039     return {SIMDVec32{_mm512_permutex2var_epi64(
   1040                 lo, _mm512_load_si512((__m512i*)perm1), hi)},
   1041             SIMDVec32{_mm512_permutex2var_epi64(
   1042                 lo, _mm512_load_si512((__m512i*)perm2), hi)}};
   1043   }
   1044   template <size_t i>
   1045   FJXL_INLINE SIMDVec16 SignedShiftRight() const {
   1046     return SIMDVec16{_mm512_srai_epi16(vec, i)};
   1047   }
   1048 
   1049   static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) {
   1050     __m256i bytes = _mm256_loadu_si256((__m256i*)data);
   1051     return {SIMDVec16{_mm512_cvtepu8_epi16(bytes)}};
   1052   }
   1053   static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) {
   1054     return {Load((const uint16_t*)data)};
   1055   }
   1056 
   1057   static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) {
   1058     __m512i bytes = _mm512_loadu_si512((__m512i*)data);
   1059     __m512i gray = _mm512_and_si512(bytes, _mm512_set1_epi16(0xFF));
   1060     __m512i alpha = _mm512_srli_epi16(bytes, 8);
   1061     return {SIMDVec16{gray}, SIMDVec16{alpha}};
   1062   }
   1063   static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) {
   1064     __m512i bytes1 = _mm512_loadu_si512((__m512i*)data);
   1065     __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64));
   1066     __m512i g_mask = _mm512_set1_epi32(0xFFFF);
   1067     __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
   1068     __m512i g = _mm512_permutexvar_epi64(
   1069         permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, g_mask),
   1070                                         _mm512_and_si512(bytes2, g_mask)));
   1071     __m512i a = _mm512_permutexvar_epi64(
   1072         permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16),
   1073                                         _mm512_srli_epi32(bytes2, 16)));
   1074     return {SIMDVec16{g}, SIMDVec16{a}};
   1075   }
   1076 
   1077   static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) {
   1078     __m512i bytes0 = _mm512_loadu_si512((__m512i*)data);
   1079     __m512i bytes1 =
   1080         _mm512_zextsi256_si512(_mm256_loadu_si256((__m256i*)(data + 64)));
   1081 
   1082     // 0x7A = element of upper half of second vector = 0 after lookup; still in
   1083     // the upper half once we add 1 or 2.
   1084     uint8_t z = 0x7A;
   1085     __m512i ridx =
   1086         _mm512_set_epi8(z, 93, z, 90, z, 87, z, 84, z, 81, z, 78, z, 75, z, 72,
   1087                         z, 69, z, 66, z, 63, z, 60, z, 57, z, 54, z, 51, z, 48,
   1088                         z, 45, z, 42, z, 39, z, 36, z, 33, z, 30, z, 27, z, 24,
   1089                         z, 21, z, 18, z, 15, z, 12, z, 9, z, 6, z, 3, z, 0);
   1090     __m512i gidx = _mm512_add_epi8(ridx, _mm512_set1_epi8(1));
   1091     __m512i bidx = _mm512_add_epi8(gidx, _mm512_set1_epi8(1));
   1092     __m512i r = _mm512_permutex2var_epi8(bytes0, ridx, bytes1);
   1093     __m512i g = _mm512_permutex2var_epi8(bytes0, gidx, bytes1);
   1094     __m512i b = _mm512_permutex2var_epi8(bytes0, bidx, bytes1);
   1095     return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}};
   1096   }
   1097   static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) {
   1098     __m512i bytes0 = _mm512_loadu_si512((__m512i*)data);
   1099     __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64));
   1100     __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128));
   1101 
   1102     __m512i ridx_lo = _mm512_set_epi16(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 63, 60, 57,
   1103                                        54, 51, 48, 45, 42, 39, 36, 33, 30, 27,
   1104                                        24, 21, 18, 15, 12, 9, 6, 3, 0);
   1105     // -1 is such that when adding 1 or 2, we get the correct index for
   1106     // green/blue.
   1107     __m512i ridx_hi =
   1108         _mm512_set_epi16(29, 26, 23, 20, 17, 14, 11, 8, 5, 2, -1, 0, 0, 0, 0, 0,
   1109                          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
   1110     __m512i gidx_lo = _mm512_add_epi16(ridx_lo, _mm512_set1_epi16(1));
   1111     __m512i gidx_hi = _mm512_add_epi16(ridx_hi, _mm512_set1_epi16(1));
   1112     __m512i bidx_lo = _mm512_add_epi16(gidx_lo, _mm512_set1_epi16(1));
   1113     __m512i bidx_hi = _mm512_add_epi16(gidx_hi, _mm512_set1_epi16(1));
   1114 
   1115     __mmask32 rmask = _cvtu32_mask32(0b11111111110000000000000000000000);
   1116     __mmask32 gbmask = _cvtu32_mask32(0b11111111111000000000000000000000);
   1117 
   1118     __m512i rlo = _mm512_permutex2var_epi16(bytes0, ridx_lo, bytes1);
   1119     __m512i glo = _mm512_permutex2var_epi16(bytes0, gidx_lo, bytes1);
   1120     __m512i blo = _mm512_permutex2var_epi16(bytes0, bidx_lo, bytes1);
   1121     __m512i r = _mm512_mask_permutexvar_epi16(rlo, rmask, ridx_hi, bytes2);
   1122     __m512i g = _mm512_mask_permutexvar_epi16(glo, gbmask, gidx_hi, bytes2);
   1123     __m512i b = _mm512_mask_permutexvar_epi16(blo, gbmask, bidx_hi, bytes2);
   1124     return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}};
   1125   }
   1126 
   1127   static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) {
   1128     __m512i bytes1 = _mm512_loadu_si512((__m512i*)data);
   1129     __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64));
   1130     __m512i rg_mask = _mm512_set1_epi32(0xFFFF);
   1131     __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
   1132     __m512i rg = _mm512_permutexvar_epi64(
   1133         permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, rg_mask),
   1134                                         _mm512_and_si512(bytes2, rg_mask)));
   1135     __m512i ba = _mm512_permutexvar_epi64(
   1136         permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16),
   1137                                         _mm512_srli_epi32(bytes2, 16)));
   1138     __m512i r = _mm512_and_si512(rg, _mm512_set1_epi16(0xFF));
   1139     __m512i g = _mm512_srli_epi16(rg, 8);
   1140     __m512i b = _mm512_and_si512(ba, _mm512_set1_epi16(0xFF));
   1141     __m512i a = _mm512_srli_epi16(ba, 8);
   1142     return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
   1143   }
   1144   static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) {
   1145     __m512i bytes0 = _mm512_loadu_si512((__m512i*)data);
   1146     __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64));
   1147     __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128));
   1148     __m512i bytes3 = _mm512_loadu_si512((__m512i*)(data + 192));
   1149 
   1150     auto pack32 = [](__m512i a, __m512i b) {
   1151       __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
   1152       return _mm512_permutexvar_epi64(permuteidx, _mm512_packus_epi32(a, b));
   1153     };
   1154     auto packlow32 = [&pack32](__m512i a, __m512i b) {
   1155       __m512i mask = _mm512_set1_epi32(0xFFFF);
   1156       return pack32(_mm512_and_si512(a, mask), _mm512_and_si512(b, mask));
   1157     };
   1158     auto packhi32 = [&pack32](__m512i a, __m512i b) {
   1159       return pack32(_mm512_srli_epi32(a, 16), _mm512_srli_epi32(b, 16));
   1160     };
   1161 
   1162     __m512i rb0 = packlow32(bytes0, bytes1);
   1163     __m512i rb1 = packlow32(bytes2, bytes3);
   1164     __m512i ga0 = packhi32(bytes0, bytes1);
   1165     __m512i ga1 = packhi32(bytes2, bytes3);
   1166 
   1167     __m512i r = packlow32(rb0, rb1);
   1168     __m512i g = packlow32(ga0, ga1);
   1169     __m512i b = packhi32(rb0, rb1);
   1170     __m512i a = packhi32(ga0, ga1);
   1171     return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
   1172   }
   1173 
   1174   void SwapEndian() {
   1175     auto indices = _mm512_broadcast_i32x4(
   1176         _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14));
   1177     vec = _mm512_shuffle_epi8(vec, indices);
   1178   }
   1179 };
   1180 
   1181 SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true,
   1182                              const SIMDVec16& if_false) {
   1183   return SIMDVec16{_mm512_mask_blend_epi16(mask, if_false.vec, if_true.vec)};
   1184 }
   1185 
   1186 SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true,
   1187                              const SIMDVec32& if_false) {
   1188   return SIMDVec32{_mm512_mask_blend_epi32(mask, if_false.vec, if_true.vec)};
   1189 }
   1190 
   1191 struct Bits64 {
   1192   static constexpr size_t kLanes = 8;
   1193 
   1194   __m512i nbits;
   1195   __m512i bits;
   1196 
   1197   FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) {
   1198     _mm512_storeu_si512((__m512i*)nbits_out, nbits);
   1199     _mm512_storeu_si512((__m512i*)bits_out, bits);
   1200   }
   1201 };
   1202 
   1203 struct Bits32 {
   1204   __m512i nbits;
   1205   __m512i bits;
   1206 
   1207   static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) {
   1208     return Bits32{nbits.vec, bits.vec};
   1209   }
   1210 
   1211   Bits64 Merge() const {
   1212     auto nbits_hi32 = _mm512_srli_epi64(nbits, 32);
   1213     auto nbits_lo32 = _mm512_and_si512(nbits, _mm512_set1_epi64(0xFFFFFFFF));
   1214     auto bits_hi32 = _mm512_srli_epi64(bits, 32);
   1215     auto bits_lo32 = _mm512_and_si512(bits, _mm512_set1_epi64(0xFFFFFFFF));
   1216 
   1217     auto nbits64 = _mm512_add_epi64(nbits_hi32, nbits_lo32);
   1218     auto bits64 =
   1219         _mm512_or_si512(_mm512_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32);
   1220     return Bits64{nbits64, bits64};
   1221   }
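
          // For illustration: each 64-bit lane holds a (low, high) pair of
          // 32-bit accumulators. If the low half has nbits = 3, bits = 0b101 and
          // the high half has nbits = 2, bits = 0b10, Merge() produces a lane
          // with nbits = 5 and bits = (0b10 << 3) | 0b101 = 0b10101, so the low
          // half's bits are the ones written first by the LSB-first writer.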
   1222 
   1223   void Interleave(const Bits32& low) {
   1224     bits = _mm512_or_si512(_mm512_sllv_epi32(bits, low.nbits), low.bits);
   1225     nbits = _mm512_add_epi32(nbits, low.nbits);
   1226   }
   1227 
   1228   void ClipTo(size_t n) {
   1229     n = std::min<size_t>(n, 16);
   1230     constexpr uint32_t kMask[32] = {
   1231         ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
   1232         ~0u, ~0u, ~0u, ~0u, ~0u, 0,   0,   0,   0,   0,   0,
   1233         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
   1234     };
   1235     __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n));
   1236     nbits = _mm512_and_si512(mask, nbits);
   1237     bits = _mm512_and_si512(mask, bits);
   1238   }
   1239   void Skip(size_t n) {
   1240     n = std::min<size_t>(n, 16);
   1241     constexpr uint32_t kMask[32] = {
   1242         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
   1243         0,   0,   0,   0,   0,   ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
   1244         ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
   1245     };
   1246     __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n));
   1247     nbits = _mm512_and_si512(mask, nbits);
   1248     bits = _mm512_and_si512(mask, bits);
   1249   }
   1250 };
   1251 
   1252 struct Bits16 {
   1253   __m512i nbits;
   1254   __m512i bits;
   1255 
   1256   static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) {
   1257     return Bits16{nbits.vec, bits.vec};
   1258   }
   1259 
   1260   Bits32 Merge() const {
   1261     auto nbits_hi16 = _mm512_srli_epi32(nbits, 16);
   1262     auto nbits_lo16 = _mm512_and_si512(nbits, _mm512_set1_epi32(0xFFFF));
   1263     auto bits_hi16 = _mm512_srli_epi32(bits, 16);
   1264     auto bits_lo16 = _mm512_and_si512(bits, _mm512_set1_epi32(0xFFFF));
   1265 
   1266     auto nbits32 = _mm512_add_epi32(nbits_hi16, nbits_lo16);
   1267     auto bits32 =
   1268         _mm512_or_si512(_mm512_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16);
   1269     return Bits32{nbits32, bits32};
   1270   }
   1271 
   1272   void Interleave(const Bits16& low) {
   1273     bits = _mm512_or_si512(_mm512_sllv_epi16(bits, low.nbits), low.bits);
   1274     nbits = _mm512_add_epi16(nbits, low.nbits);
   1275   }
   1276 
   1277   void ClipTo(size_t n) {
   1278     n = std::min<size_t>(n, 32);
   1279     constexpr uint16_t kMask[64] = {
   1280         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1281         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1282         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1283         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1284         0,      0,      0,      0,      0,      0,      0,      0,
   1285         0,      0,      0,      0,      0,      0,      0,      0,
   1286         0,      0,      0,      0,      0,      0,      0,      0,
   1287         0,      0,      0,      0,      0,      0,      0,      0,
   1288     };
   1289     __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n));
   1290     nbits = _mm512_and_si512(mask, nbits);
   1291     bits = _mm512_and_si512(mask, bits);
   1292   }
   1293   void Skip(size_t n) {
   1294     n = std::min<size_t>(n, 32);
   1295     constexpr uint16_t kMask[64] = {
   1296         0,      0,      0,      0,      0,      0,      0,      0,
   1297         0,      0,      0,      0,      0,      0,      0,      0,
   1298         0,      0,      0,      0,      0,      0,      0,      0,
   1299         0,      0,      0,      0,      0,      0,      0,      0,
   1300         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1301         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1302         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1303         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1304     };
   1305     __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n));
   1306     nbits = _mm512_and_si512(mask, nbits);
   1307     bits = _mm512_and_si512(mask, bits);
   1308   }
   1309 };
   1310 
   1311 #endif
   1312 
   1313 #ifdef FJXL_AVX2
   1314 #define FJXL_GENERIC_SIMD
   1315 
   1316 struct SIMDVec32;
   1317 
   1318 struct Mask32 {
   1319   __m256i mask;
   1320   SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false);
   1321   size_t CountPrefix() const {
   1322     return CtzNonZero(~static_cast<uint64_t>(
   1323         static_cast<uint8_t>(_mm256_movemask_ps(_mm256_castsi256_ps(mask)))));
   1324   }
   1325 };
   1326 
   1327 struct SIMDVec32 {
   1328   __m256i vec;
   1329 
   1330   static constexpr size_t kLanes = 8;
   1331 
   1332   FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) {
   1333     return SIMDVec32{_mm256_loadu_si256((__m256i*)data)};
   1334   }
   1335   FJXL_INLINE void Store(uint32_t* data) {
   1336     _mm256_storeu_si256((__m256i*)data, vec);
   1337   }
   1338   FJXL_INLINE static SIMDVec32 Val(uint32_t v) {
   1339     return SIMDVec32{_mm256_set1_epi32(v)};
   1340   }
   1341   FJXL_INLINE SIMDVec32 ValToToken() const {
   1342     // we know that each value has at most 20 bits, so we just need 5 nibbles
   1343     // and don't need to mask the fifth. However we do need to set the higher
   1344     // bytes to 0xFF, which will make table lookups return 0.
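            // In scalar terms, the resulting token is the bit length of the value:
            // FloorLog2(v) + 1 for v > 0, and 0 for v == 0.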
   1345     auto nibble0 =
   1346         _mm256_or_si256(_mm256_and_si256(vec, _mm256_set1_epi32(0xF)),
   1347                         _mm256_set1_epi32(0xFFFFFF00));
   1348     auto nibble1 = _mm256_or_si256(
   1349         _mm256_and_si256(_mm256_srli_epi32(vec, 4), _mm256_set1_epi32(0xF)),
   1350         _mm256_set1_epi32(0xFFFFFF00));
   1351     auto nibble2 = _mm256_or_si256(
   1352         _mm256_and_si256(_mm256_srli_epi32(vec, 8), _mm256_set1_epi32(0xF)),
   1353         _mm256_set1_epi32(0xFFFFFF00));
   1354     auto nibble3 = _mm256_or_si256(
   1355         _mm256_and_si256(_mm256_srli_epi32(vec, 12), _mm256_set1_epi32(0xF)),
   1356         _mm256_set1_epi32(0xFFFFFF00));
   1357     auto nibble4 = _mm256_or_si256(_mm256_srli_epi32(vec, 16),
   1358                                    _mm256_set1_epi32(0xFFFFFF00));
   1359 
   1360     auto lut0 = _mm256_broadcastsi128_si256(
   1361         _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4));
   1362     auto lut1 = _mm256_broadcastsi128_si256(
   1363         _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8));
   1364     auto lut2 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1365         0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12));
   1366     auto lut3 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1367         0, 13, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16));
   1368     auto lut4 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1369         0, 17, 18, 18, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20));
   1370 
   1371     auto token0 = _mm256_shuffle_epi8(lut0, nibble0);
   1372     auto token1 = _mm256_shuffle_epi8(lut1, nibble1);
   1373     auto token2 = _mm256_shuffle_epi8(lut2, nibble2);
   1374     auto token3 = _mm256_shuffle_epi8(lut3, nibble3);
   1375     auto token4 = _mm256_shuffle_epi8(lut4, nibble4);
   1376 
   1377     auto token =
   1378         _mm256_max_epi32(_mm256_max_epi32(_mm256_max_epi32(token0, token1),
   1379                                           _mm256_max_epi32(token2, token3)),
   1380                          token4);
   1381     return SIMDVec32{token};
   1382   }
   1383   FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const {
   1384     return SIMDVec32{_mm256_sub_epi32(_mm256_max_epu32(vec, to_subtract.vec),
   1385                                       to_subtract.vec)};
   1386   }
   1387   FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const {
   1388     return SIMDVec32{_mm256_sub_epi32(vec, to_subtract.vec)};
   1389   }
   1390   FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const {
   1391     return SIMDVec32{_mm256_add_epi32(vec, oth.vec)};
   1392   }
   1393   FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const {
   1394     return SIMDVec32{_mm256_xor_si256(vec, oth.vec)};
   1395   }
   1396   FJXL_INLINE SIMDVec32 Pow2() const {
   1397     return SIMDVec32{_mm256_sllv_epi32(_mm256_set1_epi32(1), vec)};
   1398   }
   1399   FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const {
   1400     return Mask32{_mm256_cmpeq_epi32(vec, oth.vec)};
   1401   }
   1402   FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const {
   1403     return Mask32{_mm256_cmpgt_epi32(vec, oth.vec)};
   1404   }
   1405   template <size_t i>
   1406   FJXL_INLINE SIMDVec32 SignedShiftRight() const {
   1407     return SIMDVec32{_mm256_srai_epi32(vec, i)};
   1408   }
   1409 };
   1410 
   1411 struct SIMDVec16;
   1412 
   1413 struct Mask16 {
   1414   __m256i mask;
   1415   SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false);
   1416   Mask16 And(const Mask16& oth) const {
   1417     return Mask16{_mm256_and_si256(mask, oth.mask)};
   1418   }
   1419   size_t CountPrefix() const {
   1420     return CtzNonZero(~static_cast<uint64_t>(
   1421                static_cast<uint32_t>(_mm256_movemask_epi8(mask)))) /
   1422            2;
   1423   }
   1424 };
   1425 
   1426 struct SIMDVec16 {
   1427   __m256i vec;
   1428 
   1429   static constexpr size_t kLanes = 16;
   1430 
   1431   FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) {
   1432     return SIMDVec16{_mm256_loadu_si256((__m256i*)data)};
   1433   }
   1434   FJXL_INLINE void Store(uint16_t* data) {
   1435     _mm256_storeu_si256((__m256i*)data, vec);
   1436   }
   1437   FJXL_INLINE static SIMDVec16 Val(uint16_t v) {
   1438     return SIMDVec16{_mm256_set1_epi16(v)};
   1439   }
   1440   FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo,
   1441                                          const SIMDVec32& hi) {
   1442     auto tmp = _mm256_packus_epi32(lo.vec, hi.vec);
   1443     return SIMDVec16{_mm256_permute4x64_epi64(tmp, 0b11011000)};
   1444   }
   1445 
   1446   FJXL_INLINE SIMDVec16 ValToToken() const {
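            // Same nibble-LUT approach as SIMDVec32::ValToToken above; 16-bit
            // values only need 4 nibbles.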
   1447     auto nibble0 =
   1448         _mm256_or_si256(_mm256_and_si256(vec, _mm256_set1_epi16(0xF)),
   1449                         _mm256_set1_epi16(0xFF00));
   1450     auto nibble1 = _mm256_or_si256(
   1451         _mm256_and_si256(_mm256_srli_epi16(vec, 4), _mm256_set1_epi16(0xF)),
   1452         _mm256_set1_epi16(0xFF00));
   1453     auto nibble2 = _mm256_or_si256(
   1454         _mm256_and_si256(_mm256_srli_epi16(vec, 8), _mm256_set1_epi16(0xF)),
   1455         _mm256_set1_epi16(0xFF00));
   1456     auto nibble3 =
   1457         _mm256_or_si256(_mm256_srli_epi16(vec, 12), _mm256_set1_epi16(0xFF00));
   1458 
   1459     auto lut0 = _mm256_broadcastsi128_si256(
   1460         _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4));
   1461     auto lut1 = _mm256_broadcastsi128_si256(
   1462         _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8));
   1463     auto lut2 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1464         0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12));
   1465     auto lut3 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1466         0, 13, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16));
   1467 
   1468     auto token0 = _mm256_shuffle_epi8(lut0, nibble0);
   1469     auto token1 = _mm256_shuffle_epi8(lut1, nibble1);
   1470     auto token2 = _mm256_shuffle_epi8(lut2, nibble2);
   1471     auto token3 = _mm256_shuffle_epi8(lut3, nibble3);
   1472 
   1473     auto token = _mm256_max_epi16(_mm256_max_epi16(token0, token1),
   1474                                   _mm256_max_epi16(token2, token3));
   1475     return SIMDVec16{token};
   1476   }
   1477 
   1478   FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const {
   1479     return SIMDVec16{_mm256_subs_epu16(vec, to_subtract.vec)};
   1480   }
   1481   FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const {
   1482     return SIMDVec16{_mm256_sub_epi16(vec, to_subtract.vec)};
   1483   }
   1484   FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const {
   1485     return SIMDVec16{_mm256_add_epi16(vec, oth.vec)};
   1486   }
   1487   FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const {
   1488     return SIMDVec16{_mm256_min_epu16(vec, oth.vec)};
   1489   }
   1490   FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const {
   1491     return Mask16{_mm256_cmpeq_epi16(vec, oth.vec)};
   1492   }
   1493   FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const {
   1494     return Mask16{_mm256_cmpgt_epi16(vec, oth.vec)};
   1495   }
   1496   FJXL_INLINE SIMDVec16 Pow2() const {
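            // Computes 1 << v per lane with two byte-shuffle LUTs: the low result
            // byte covers shift amounts 0..7, the high byte covers 8..15.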
   1497     auto pow2_lo_lut = _mm256_broadcastsi128_si256(
   1498         _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6,
   1499                       1u << 7, 0, 0, 0, 0, 0, 0, 0, 0));
   1500     auto pow2_hi_lut = _mm256_broadcastsi128_si256(
   1501         _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 1 << 0, 1 << 1, 1 << 2, 1 << 3,
   1502                       1 << 4, 1 << 5, 1 << 6, 1u << 7));
   1503 
   1504     auto masked = _mm256_or_si256(vec, _mm256_set1_epi16(0xFF00));
   1505 
   1506     auto pow2_lo = _mm256_shuffle_epi8(pow2_lo_lut, masked);
   1507     auto pow2_hi = _mm256_shuffle_epi8(pow2_hi_lut, masked);
   1508 
   1509     auto pow2 = _mm256_or_si256(_mm256_slli_epi16(pow2_hi, 8), pow2_lo);
   1510     return SIMDVec16{pow2};
   1511   }
   1512   FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const {
   1513     return SIMDVec16{_mm256_or_si256(vec, oth.vec)};
   1514   }
   1515   FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const {
   1516     return SIMDVec16{_mm256_xor_si256(vec, oth.vec)};
   1517   }
   1518   FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const {
   1519     return SIMDVec16{_mm256_and_si256(vec, oth.vec)};
   1520   }
   1521   FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const {
   1522     return SIMDVec16{_mm256_srai_epi16(_mm256_add_epi16(vec, oth.vec), 1)};
   1523   }
   1524   FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const {
   1525     return SIMDVec16{_mm256_or_si256(vec, _mm256_set1_epi16(0xFF00))};
   1526   }
   1527   FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const {
   1528     return SIMDVec16{_mm256_shuffle_epi8(
   1529         _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)table)), vec)};
   1530   }
   1531   FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const {
   1532     auto v02 = _mm256_unpacklo_epi16(low.vec, vec);
   1533     auto v13 = _mm256_unpackhi_epi16(low.vec, vec);
   1534     return {SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x20)},
   1535             SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x31)}};
   1536   }
   1537   FJXL_INLINE VecPair<SIMDVec32> Upcast() const {
   1538     auto v02 = _mm256_unpacklo_epi16(vec, _mm256_setzero_si256());
   1539     auto v13 = _mm256_unpackhi_epi16(vec, _mm256_setzero_si256());
   1540     return {SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x20)},
   1541             SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x31)}};
   1542   }
   1543   template <size_t i>
   1544   FJXL_INLINE SIMDVec16 SignedShiftRight() const {
   1545     return SIMDVec16{_mm256_srai_epi16(vec, i)};
   1546   }
   1547 
   1548   static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) {
   1549     __m128i bytes = _mm_loadu_si128((__m128i*)data);
   1550     return {SIMDVec16{_mm256_cvtepu8_epi16(bytes)}};
   1551   }
   1552   static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) {
   1553     return {Load((const uint16_t*)data)};
   1554   }
   1555 
   1556   static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) {
   1557     __m256i bytes = _mm256_loadu_si256((__m256i*)data);
   1558     __m256i gray = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF));
   1559     __m256i alpha = _mm256_srli_epi16(bytes, 8);
   1560     return {SIMDVec16{gray}, SIMDVec16{alpha}};
   1561   }
   1562   static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) {
   1563     __m256i bytes1 = _mm256_loadu_si256((__m256i*)data);
   1564     __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32));
   1565     __m256i g_mask = _mm256_set1_epi32(0xFFFF);
   1566     __m256i g = _mm256_permute4x64_epi64(
   1567         _mm256_packus_epi32(_mm256_and_si256(bytes1, g_mask),
   1568                             _mm256_and_si256(bytes2, g_mask)),
   1569         0b11011000);
   1570     __m256i a = _mm256_permute4x64_epi64(
   1571         _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16),
   1572                             _mm256_srli_epi32(bytes2, 16)),
   1573         0b11011000);
   1574     return {SIMDVec16{g}, SIMDVec16{a}};
   1575   }
   1576 
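          // Deinterleaves 16 packed RGB8 pixels (48 bytes) into separate R, G and B
          // vectors widened to 16 bits. AVX2 has no 3-channel deinterleaving load
          // (unlike NEON's vld3), so this is done with byte shuffles and blends.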
   1577   static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) {
   1578     __m128i bytes0 = _mm_loadu_si128((__m128i*)data);
   1579     __m128i bytes1 = _mm_loadu_si128((__m128i*)(data + 16));
   1580     __m128i bytes2 = _mm_loadu_si128((__m128i*)(data + 32));
   1581 
   1582     __m128i idx =
   1583         _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13);
   1584 
   1585     __m128i r6b5g5_0 = _mm_shuffle_epi8(bytes0, idx);
   1586     __m128i g6r5b5_1 = _mm_shuffle_epi8(bytes1, idx);
   1587     __m128i b6g5r5_2 = _mm_shuffle_epi8(bytes2, idx);
   1588 
   1589     __m128i mask010 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF,
   1590                                     0xFF, 0, 0, 0, 0, 0);
   1591     __m128i mask001 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF,
   1592                                     0xFF, 0xFF, 0xFF);
   1593 
   1594     __m128i b2g2b1 = _mm_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001);
   1595     __m128i b2b0b1 = _mm_blendv_epi8(b2g2b1, r6b5g5_0, mask010);
   1596 
   1597     __m128i r0r1b1 = _mm_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010);
   1598     __m128i r0r1r2 = _mm_blendv_epi8(r0r1b1, b6g5r5_2, mask001);
   1599 
   1600     __m128i g1r1g0 = _mm_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001);
   1601     __m128i g1g2g0 = _mm_blendv_epi8(g1r1g0, b6g5r5_2, mask010);
   1602 
   1603     __m128i g0g1g2 = _mm_alignr_epi8(g1g2g0, g1g2g0, 11);
   1604     __m128i b0b1b2 = _mm_alignr_epi8(b2b0b1, b2b0b1, 6);
   1605 
   1606     return {SIMDVec16{_mm256_cvtepu8_epi16(r0r1r2)},
   1607             SIMDVec16{_mm256_cvtepu8_epi16(g0g1g2)},
   1608             SIMDVec16{_mm256_cvtepu8_epi16(b0b1b2)}};
   1609   }
   1610   static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) {
   1611     auto load_and_split_lohi = [](const unsigned char* data) {
   1612       // LHLHLH...
   1613       __m256i bytes = _mm256_loadu_si256((__m256i*)data);
   1614       // L0L0L0...
   1615       __m256i lo = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF));
   1616       // H0H0H0...
   1617       __m256i hi = _mm256_srli_epi16(bytes, 8);
   1618       // LLLLLLLLHHHHHHHHLLLLLLLLHHHHHHHH
   1619       __m256i packed = _mm256_packus_epi16(lo, hi);
   1620       return _mm256_permute4x64_epi64(packed, 0b11011000);
   1621     };
   1622     __m256i bytes0 = load_and_split_lohi(data);
   1623     __m256i bytes1 = load_and_split_lohi(data + 32);
   1624     __m256i bytes2 = load_and_split_lohi(data + 64);
   1625 
   1626     __m256i idx = _mm256_broadcastsi128_si256(
   1627         _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13));
   1628 
   1629     __m256i r6b5g5_0 = _mm256_shuffle_epi8(bytes0, idx);
   1630     __m256i g6r5b5_1 = _mm256_shuffle_epi8(bytes1, idx);
   1631     __m256i b6g5r5_2 = _mm256_shuffle_epi8(bytes2, idx);
   1632 
   1633     __m256i mask010 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1634         0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0));
   1635     __m256i mask001 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1636         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF));
   1637 
   1638     __m256i b2g2b1 = _mm256_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001);
   1639     __m256i b2b0b1 = _mm256_blendv_epi8(b2g2b1, r6b5g5_0, mask010);
   1640 
   1641     __m256i r0r1b1 = _mm256_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010);
   1642     __m256i r0r1r2 = _mm256_blendv_epi8(r0r1b1, b6g5r5_2, mask001);
   1643 
   1644     __m256i g1r1g0 = _mm256_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001);
   1645     __m256i g1g2g0 = _mm256_blendv_epi8(g1r1g0, b6g5r5_2, mask010);
   1646 
   1647     __m256i g0g1g2 = _mm256_alignr_epi8(g1g2g0, g1g2g0, 11);
   1648     __m256i b0b1b2 = _mm256_alignr_epi8(b2b0b1, b2b0b1, 6);
   1649 
   1650     // Now r0r1r2, g0g1g2, b0b1b2 have the low bytes of the RGB pixels in their
   1651     // lower half, and the high bytes in their upper half.
   1652 
   1653     auto combine_low_hi = [](__m256i v) {
   1654       __m128i low = _mm256_extracti128_si256(v, 0);
   1655       __m128i hi = _mm256_extracti128_si256(v, 1);
   1656       __m256i low16 = _mm256_cvtepu8_epi16(low);
   1657       __m256i hi16 = _mm256_cvtepu8_epi16(hi);
   1658       return _mm256_or_si256(_mm256_slli_epi16(hi16, 8), low16);
   1659     };
   1660 
   1661     return {SIMDVec16{combine_low_hi(r0r1r2)},
   1662             SIMDVec16{combine_low_hi(g0g1g2)},
   1663             SIMDVec16{combine_low_hi(b0b1b2)}};
   1664   }
   1665 
   1666   static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) {
   1667     __m256i bytes1 = _mm256_loadu_si256((__m256i*)data);
   1668     __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32));
   1669     __m256i rg_mask = _mm256_set1_epi32(0xFFFF);
   1670     __m256i rg = _mm256_permute4x64_epi64(
   1671         _mm256_packus_epi32(_mm256_and_si256(bytes1, rg_mask),
   1672                             _mm256_and_si256(bytes2, rg_mask)),
   1673         0b11011000);
   1674     __m256i ba = _mm256_permute4x64_epi64(
   1675         _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16),
   1676                             _mm256_srli_epi32(bytes2, 16)),
   1677         0b11011000);
   1678     __m256i r = _mm256_and_si256(rg, _mm256_set1_epi16(0xFF));
   1679     __m256i g = _mm256_srli_epi16(rg, 8);
   1680     __m256i b = _mm256_and_si256(ba, _mm256_set1_epi16(0xFF));
   1681     __m256i a = _mm256_srli_epi16(ba, 8);
   1682     return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
   1683   }
   1684   static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) {
   1685     __m256i bytes0 = _mm256_loadu_si256((__m256i*)data);
   1686     __m256i bytes1 = _mm256_loadu_si256((__m256i*)(data + 32));
   1687     __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 64));
   1688     __m256i bytes3 = _mm256_loadu_si256((__m256i*)(data + 96));
   1689 
   1690     auto pack32 = [](__m256i a, __m256i b) {
   1691       return _mm256_permute4x64_epi64(_mm256_packus_epi32(a, b), 0b11011000);
   1692     };
   1693     auto packlow32 = [&pack32](__m256i a, __m256i b) {
   1694       __m256i mask = _mm256_set1_epi32(0xFFFF);
   1695       return pack32(_mm256_and_si256(a, mask), _mm256_and_si256(b, mask));
   1696     };
   1697     auto packhi32 = [&pack32](__m256i a, __m256i b) {
   1698       return pack32(_mm256_srli_epi32(a, 16), _mm256_srli_epi32(b, 16));
   1699     };
   1700 
   1701     __m256i rb0 = packlow32(bytes0, bytes1);
   1702     __m256i rb1 = packlow32(bytes2, bytes3);
   1703     __m256i ga0 = packhi32(bytes0, bytes1);
   1704     __m256i ga1 = packhi32(bytes2, bytes3);
   1705 
   1706     __m256i r = packlow32(rb0, rb1);
   1707     __m256i g = packlow32(ga0, ga1);
   1708     __m256i b = packhi32(rb0, rb1);
   1709     __m256i a = packhi32(ga0, ga1);
   1710     return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
   1711   }
   1712 
   1713   void SwapEndian() {
   1714     auto indices = _mm256_broadcastsi128_si256(
   1715         _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14));
   1716     vec = _mm256_shuffle_epi8(vec, indices);
   1717   }
   1718 };
   1719 
   1720 SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true,
   1721                              const SIMDVec16& if_false) {
   1722   return SIMDVec16{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)};
   1723 }
   1724 
   1725 SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true,
   1726                              const SIMDVec32& if_false) {
   1727   return SIMDVec32{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)};
   1728 }
   1729 
   1730 struct Bits64 {
   1731   static constexpr size_t kLanes = 4;
   1732 
   1733   __m256i nbits;
   1734   __m256i bits;
   1735 
   1736   FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) {
   1737     _mm256_storeu_si256((__m256i*)nbits_out, nbits);
   1738     _mm256_storeu_si256((__m256i*)bits_out, bits);
   1739   }
   1740 };
   1741 
   1742 struct Bits32 {
   1743   __m256i nbits;
   1744   __m256i bits;
   1745 
   1746   static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) {
   1747     return Bits32{nbits.vec, bits.vec};
   1748   }
   1749 
   1750   Bits64 Merge() const {
   1751     auto nbits_hi32 = _mm256_srli_epi64(nbits, 32);
   1752     auto nbits_lo32 = _mm256_and_si256(nbits, _mm256_set1_epi64x(0xFFFFFFFF));
   1753     auto bits_hi32 = _mm256_srli_epi64(bits, 32);
   1754     auto bits_lo32 = _mm256_and_si256(bits, _mm256_set1_epi64x(0xFFFFFFFF));
   1755 
   1756     auto nbits64 = _mm256_add_epi64(nbits_hi32, nbits_lo32);
   1757     auto bits64 =
   1758         _mm256_or_si256(_mm256_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32);
   1759     return Bits64{nbits64, bits64};
   1760   }
   1761 
   1762   void Interleave(const Bits32& low) {
   1763     bits = _mm256_or_si256(_mm256_sllv_epi32(bits, low.nbits), low.bits);
   1764     nbits = _mm256_add_epi32(nbits, low.nbits);
   1765   }
   1766 
   1767   void ClipTo(size_t n) {
   1768     n = std::min<size_t>(n, 8);
   1769     constexpr uint32_t kMask[16] = {
   1770         ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0, 0,
   1771     };
   1772     __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n));
   1773     nbits = _mm256_and_si256(mask, nbits);
   1774     bits = _mm256_and_si256(mask, bits);
   1775   }
   1776   void Skip(size_t n) {
   1777     n = std::min<size_t>(n, 8);
   1778     constexpr uint32_t kMask[16] = {
   1779         0, 0, 0, 0, 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
   1780     };
   1781     __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n));
   1782     nbits = _mm256_and_si256(mask, nbits);
   1783     bits = _mm256_and_si256(mask, bits);
   1784   }
   1785 };
   1786 
   1787 struct Bits16 {
   1788   __m256i nbits;
   1789   __m256i bits;
   1790 
   1791   static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) {
   1792     return Bits16{nbits.vec, bits.vec};
   1793   }
   1794 
   1795   Bits32 Merge() const {
   1796     auto nbits_hi16 = _mm256_srli_epi32(nbits, 16);
   1797     auto nbits_lo16 = _mm256_and_si256(nbits, _mm256_set1_epi32(0xFFFF));
   1798     auto bits_hi16 = _mm256_srli_epi32(bits, 16);
   1799     auto bits_lo16 = _mm256_and_si256(bits, _mm256_set1_epi32(0xFFFF));
   1800 
   1801     auto nbits32 = _mm256_add_epi32(nbits_hi16, nbits_lo16);
   1802     auto bits32 =
   1803         _mm256_or_si256(_mm256_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16);
   1804     return Bits32{nbits32, bits32};
   1805   }
   1806 
   1807   void Interleave(const Bits16& low) {
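            // AVX2 has no per-lane variable shift for 16-bit elements, so the left
            // shift by low.nbits is done by multiplying with 2^nbits, looked up via
            // a byte shuffle; this relies on low.nbits being at most 7.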
   1808     auto pow2_lo_lut = _mm256_broadcastsi128_si256(
   1809         _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6,
   1810                       1u << 7, 0, 0, 0, 0, 0, 0, 0, 0));
   1811     auto low_nbits_masked =
   1812         _mm256_or_si256(low.nbits, _mm256_set1_epi16(0xFF00));
   1813 
   1814     auto bits_shifted = _mm256_mullo_epi16(
   1815         bits, _mm256_shuffle_epi8(pow2_lo_lut, low_nbits_masked));
   1816 
   1817     nbits = _mm256_add_epi16(nbits, low.nbits);
   1818     bits = _mm256_or_si256(bits_shifted, low.bits);
   1819   }
   1820 
   1821   void ClipTo(size_t n) {
   1822     n = std::min<size_t>(n, 16);
   1823     constexpr uint16_t kMask[32] = {
   1824         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1825         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1826         0,      0,      0,      0,      0,      0,      0,      0,
   1827         0,      0,      0,      0,      0,      0,      0,      0,
   1828     };
   1829     __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n));
   1830     nbits = _mm256_and_si256(mask, nbits);
   1831     bits = _mm256_and_si256(mask, bits);
   1832   }
   1833 
   1834   void Skip(size_t n) {
   1835     n = std::min<size_t>(n, 16);
   1836     constexpr uint16_t kMask[32] = {
   1837         0,      0,      0,      0,      0,      0,      0,      0,
   1838         0,      0,      0,      0,      0,      0,      0,      0,
   1839         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1840         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1841     };
   1842     __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n));
   1843     nbits = _mm256_and_si256(mask, nbits);
   1844     bits = _mm256_and_si256(mask, bits);
   1845   }
   1846 };
   1847 
   1848 #endif
   1849 
   1850 #ifdef FJXL_NEON
   1851 #define FJXL_GENERIC_SIMD
   1852 
   1853 struct SIMDVec32;
   1854 
   1855 struct Mask32 {
   1856   uint32x4_t mask;
   1857   SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false);
   1858   Mask32 And(const Mask32& oth) const {
   1859     return Mask32{vandq_u32(mask, oth.mask)};
   1860   }
   1861   size_t CountPrefix() const {
   1862     uint32_t val_unset[4] = {0, 1, 2, 3};
   1863     uint32_t val_set[4] = {4, 4, 4, 4};
   1864     uint32x4_t val = vbslq_u32(mask, vld1q_u32(val_set), vld1q_u32(val_unset));
   1865     return vminvq_u32(val);
   1866   }
   1867 };
   1868 
   1869 struct SIMDVec32 {
   1870   uint32x4_t vec;
   1871 
   1872   static constexpr size_t kLanes = 4;
   1873 
   1874   FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) {
   1875     return SIMDVec32{vld1q_u32(data)};
   1876   }
   1877   FJXL_INLINE void Store(uint32_t* data) { vst1q_u32(data, vec); }
   1878   FJXL_INLINE static SIMDVec32 Val(uint32_t v) {
   1879     return SIMDVec32{vdupq_n_u32(v)};
   1880   }
   1881   FJXL_INLINE SIMDVec32 ValToToken() const {
   1882     return SIMDVec32{vsubq_u32(vdupq_n_u32(32), vclzq_u32(vec))};
   1883   }
   1884   FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const {
   1885     return SIMDVec32{vqsubq_u32(vec, to_subtract.vec)};
   1886   }
   1887   FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const {
   1888     return SIMDVec32{vsubq_u32(vec, to_subtract.vec)};
   1889   }
   1890   FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const {
   1891     return SIMDVec32{vaddq_u32(vec, oth.vec)};
   1892   }
   1893   FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const {
   1894     return SIMDVec32{veorq_u32(vec, oth.vec)};
   1895   }
   1896   FJXL_INLINE SIMDVec32 Pow2() const {
   1897     return SIMDVec32{vshlq_u32(vdupq_n_u32(1), vreinterpretq_s32_u32(vec))};
   1898   }
   1899   FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const {
   1900     return Mask32{vceqq_u32(vec, oth.vec)};
   1901   }
   1902   FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const {
   1903     return Mask32{
   1904         vcgtq_s32(vreinterpretq_s32_u32(vec), vreinterpretq_s32_u32(oth.vec))};
   1905   }
   1906   template <size_t i>
   1907   FJXL_INLINE SIMDVec32 SignedShiftRight() const {
   1908     return SIMDVec32{
   1909         vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(vec), i))};
   1910   }
   1911 };
   1912 
   1913 struct SIMDVec16;
   1914 
   1915 struct Mask16 {
   1916   uint16x8_t mask;
   1917   SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false);
   1918   Mask16 And(const Mask16& oth) const {
   1919     return Mask16{vandq_u16(mask, oth.mask)};
   1920   }
   1921   size_t CountPrefix() const {
   1922     uint16_t val_unset[8] = {0, 1, 2, 3, 4, 5, 6, 7};
   1923     uint16_t val_set[8] = {8, 8, 8, 8, 8, 8, 8, 8};
   1924     uint16x8_t val = vbslq_u16(mask, vld1q_u16(val_set), vld1q_u16(val_unset));
   1925     return vminvq_u16(val);
   1926   }
   1927 };
   1928 
   1929 struct SIMDVec16 {
   1930   uint16x8_t vec;
   1931 
   1932   static constexpr size_t kLanes = 8;
   1933 
   1934   FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) {
   1935     return SIMDVec16{vld1q_u16(data)};
   1936   }
   1937   FJXL_INLINE void Store(uint16_t* data) { vst1q_u16(data, vec); }
   1938   FJXL_INLINE static SIMDVec16 Val(uint16_t v) {
   1939     return SIMDVec16{vdupq_n_u16(v)};
   1940   }
   1941   FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo,
   1942                                          const SIMDVec32& hi) {
   1943     return SIMDVec16{vmovn_high_u32(vmovn_u32(lo.vec), hi.vec)};
   1944   }
   1945 
   1946   FJXL_INLINE SIMDVec16 ValToToken() const {
   1947     return SIMDVec16{vsubq_u16(vdupq_n_u16(16), vclzq_u16(vec))};
   1948   }
   1949   FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const {
   1950     return SIMDVec16{vqsubq_u16(vec, to_subtract.vec)};
   1951   }
   1952   FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const {
   1953     return SIMDVec16{vsubq_u16(vec, to_subtract.vec)};
   1954   }
   1955   FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const {
   1956     return SIMDVec16{vaddq_u16(vec, oth.vec)};
   1957   }
   1958   FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const {
   1959     return SIMDVec16{vminq_u16(vec, oth.vec)};
   1960   }
   1961   FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const {
   1962     return Mask16{vceqq_u16(vec, oth.vec)};
   1963   }
   1964   FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const {
   1965     return Mask16{
   1966         vcgtq_s16(vreinterpretq_s16_u16(vec), vreinterpretq_s16_u16(oth.vec))};
   1967   }
   1968   FJXL_INLINE SIMDVec16 Pow2() const {
   1969     return SIMDVec16{vshlq_u16(vdupq_n_u16(1), vreinterpretq_s16_u16(vec))};
   1970   }
   1971   FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const {
   1972     return SIMDVec16{vorrq_u16(vec, oth.vec)};
   1973   }
   1974   FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const {
   1975     return SIMDVec16{veorq_u16(vec, oth.vec)};
   1976   }
   1977   FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const {
   1978     return SIMDVec16{vandq_u16(vec, oth.vec)};
   1979   }
   1980   FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const {
   1981     return SIMDVec16{vhaddq_u16(vec, oth.vec)};
   1982   }
   1983   FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const {
   1984     return SIMDVec16{vorrq_u16(vec, vdupq_n_u16(0xFF00))};
   1985   }
   1986   FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const {
   1987     uint8x16_t tbl = vld1q_u8(table);
   1988     uint8x16_t indices = vreinterpretq_u8_u16(vec);
   1989     return SIMDVec16{vreinterpretq_u16_u8(vqtbl1q_u8(tbl, indices))};
   1990   }
   1991   FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const {
   1992     return {SIMDVec16{vzip1q_u16(low.vec, vec)},
   1993             SIMDVec16{vzip2q_u16(low.vec, vec)}};
   1994   }
   1995   FJXL_INLINE VecPair<SIMDVec32> Upcast() const {
   1996     uint32x4_t lo = vmovl_u16(vget_low_u16(vec));
   1997     uint32x4_t hi = vmovl_high_u16(vec);
   1998     return {SIMDVec32{lo}, SIMDVec32{hi}};
   1999   }
   2000   template <size_t i>
   2001   FJXL_INLINE SIMDVec16 SignedShiftRight() const {
   2002     return SIMDVec16{
   2003         vreinterpretq_u16_s16(vshrq_n_s16(vreinterpretq_s16_u16(vec), i))};
   2004   }
   2005 
   2006   static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) {
   2007     uint8x8_t v = vld1_u8(data);
   2008     return {SIMDVec16{vmovl_u8(v)}};
   2009   }
   2010   static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) {
   2011     return {Load((const uint16_t*)data)};
   2012   }
   2013 
   2014   static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) {
   2015     uint8x8x2_t v = vld2_u8(data);
   2016     return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])}};
   2017   }
   2018   static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) {
   2019     uint16x8x2_t v = vld2q_u16((const uint16_t*)data);
   2020     return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}};
   2021   }
   2022 
   2023   static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) {
   2024     uint8x8x3_t v = vld3_u8(data);
   2025     return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])},
   2026             SIMDVec16{vmovl_u8(v.val[2])}};
   2027   }
   2028   static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) {
   2029     uint16x8x3_t v = vld3q_u16((const uint16_t*)data);
   2030     return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]}};
   2031   }
   2032 
   2033   static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) {
   2034     uint8x8x4_t v = vld4_u8(data);
   2035     return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])},
   2036             SIMDVec16{vmovl_u8(v.val[2])}, SIMDVec16{vmovl_u8(v.val[3])}};
   2037   }
   2038   static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) {
   2039     uint16x8x4_t v = vld4q_u16((const uint16_t*)data);
   2040     return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]},
   2041             SIMDVec16{v.val[3]}};
   2042   }
   2043 
   2044   void SwapEndian() {
   2045     vec = vreinterpretq_u16_u8(vrev16q_u8(vreinterpretq_u8_u16(vec)));
   2046   }
   2047 };
   2048 
   2049 SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true,
   2050                              const SIMDVec16& if_false) {
   2051   return SIMDVec16{vbslq_u16(mask, if_true.vec, if_false.vec)};
   2052 }
   2053 
   2054 SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true,
   2055                              const SIMDVec32& if_false) {
   2056   return SIMDVec32{vbslq_u32(mask, if_true.vec, if_false.vec)};
   2057 }
   2058 
   2059 struct Bits64 {
   2060   static constexpr size_t kLanes = 2;
   2061 
   2062   uint64x2_t nbits;
   2063   uint64x2_t bits;
   2064 
   2065   FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) {
   2066     vst1q_u64(nbits_out, nbits);
   2067     vst1q_u64(bits_out, bits);
   2068   }
   2069 };
   2070 
   2071 struct Bits32 {
   2072   uint32x4_t nbits;
   2073   uint32x4_t bits;
   2074 
   2075   static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) {
   2076     return Bits32{nbits.vec, bits.vec};
   2077   }
   2078 
   2079   Bits64 Merge() const {
   2080     // TODO(veluca): can probably be optimized.
   2081     uint64x2_t nbits_lo32 =
   2082         vandq_u64(vreinterpretq_u64_u32(nbits), vdupq_n_u64(0xFFFFFFFF));
   2083     uint64x2_t bits_hi32 =
   2084         vshlq_u64(vshrq_n_u64(vreinterpretq_u64_u32(bits), 32),
   2085                   vreinterpretq_s64_u64(nbits_lo32));
   2086     uint64x2_t bits_lo32 =
   2087         vandq_u64(vreinterpretq_u64_u32(bits), vdupq_n_u64(0xFFFFFFFF));
   2088     uint64x2_t nbits64 =
   2089         vsraq_n_u64(nbits_lo32, vreinterpretq_u64_u32(nbits), 32);
   2090     uint64x2_t bits64 = vorrq_u64(bits_hi32, bits_lo32);
   2091     return Bits64{nbits64, bits64};
   2092   }
   2093 
   2094   void Interleave(const Bits32& low) {
   2095     bits =
   2096         vorrq_u32(vshlq_u32(bits, vreinterpretq_s32_u32(low.nbits)), low.bits);
   2097     nbits = vaddq_u32(nbits, low.nbits);
   2098   }
   2099 
   2100   void ClipTo(size_t n) {
   2101     n = std::min<size_t>(n, 4);
   2102     constexpr uint32_t kMask[8] = {
   2103         ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0,
   2104     };
   2105     uint32x4_t mask = vld1q_u32(kMask + 4 - n);
   2106     nbits = vandq_u32(mask, nbits);
   2107     bits = vandq_u32(mask, bits);
   2108   }
   2109   void Skip(size_t n) {
   2110     n = std::min<size_t>(n, 4);
   2111     constexpr uint32_t kMask[8] = {
   2112         0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u,
   2113     };
   2114     uint32x4_t mask = vld1q_u32(kMask + 4 - n);
   2115     nbits = vandq_u32(mask, nbits);
   2116     bits = vandq_u32(mask, bits);
   2117   }
   2118 };
   2119 
   2120 struct Bits16 {
   2121   uint16x8_t nbits;
   2122   uint16x8_t bits;
   2123 
   2124   static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) {
   2125     return Bits16{nbits.vec, bits.vec};
   2126   }
   2127 
   2128   Bits32 Merge() const {
   2129     // TODO(veluca): can probably be optimized.
   2130     uint32x4_t nbits_lo16 =
   2131         vandq_u32(vreinterpretq_u32_u16(nbits), vdupq_n_u32(0xFFFF));
   2132     uint32x4_t bits_hi16 =
   2133         vshlq_u32(vshrq_n_u32(vreinterpretq_u32_u16(bits), 16),
   2134                   vreinterpretq_s32_u32(nbits_lo16));
   2135     uint32x4_t bits_lo16 =
   2136         vandq_u32(vreinterpretq_u32_u16(bits), vdupq_n_u32(0xFFFF));
   2137     uint32x4_t nbits32 =
   2138         vsraq_n_u32(nbits_lo16, vreinterpretq_u32_u16(nbits), 16);
   2139     uint32x4_t bits32 = vorrq_u32(bits_hi16, bits_lo16);
   2140     return Bits32{nbits32, bits32};
   2141   }
   2142 
   2143   void Interleave(const Bits16& low) {
   2144     bits =
   2145         vorrq_u16(vshlq_u16(bits, vreinterpretq_s16_u16(low.nbits)), low.bits);
   2146     nbits = vaddq_u16(nbits, low.nbits);
   2147   }
   2148 
   2149   void ClipTo(size_t n) {
   2150     n = std::min<size_t>(n, 8);
   2151     constexpr uint16_t kMask[16] = {
   2152         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   2153         0,      0,      0,      0,      0,      0,      0,      0,
   2154     };
   2155     uint16x8_t mask = vld1q_u16(kMask + 8 - n);
   2156     nbits = vandq_u16(mask, nbits);
   2157     bits = vandq_u16(mask, bits);
   2158   }
   2159   void Skip(size_t n) {
   2160     n = std::min<size_t>(n, 8);
   2161     constexpr uint16_t kMask[16] = {
   2162         0,      0,      0,      0,      0,      0,      0,      0,
   2163         0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   2164     };
   2165     uint16x8_t mask = vld1q_u16(kMask + 8 - n);
   2166     nbits = vandq_u16(mask, nbits);
   2167     bits = vandq_u16(mask, bits);
   2168   }
   2169 };
   2170 
   2171 #endif
   2172 
   2173 #ifdef FJXL_GENERIC_SIMD
   2174 constexpr size_t SIMDVec32::kLanes;
   2175 constexpr size_t SIMDVec16::kLanes;
   2176 
   2177 //  Each of these functions will process SIMDVec16::kLanes worth of values.
   2178 
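        // Vectorized equivalent of EncodeHybridUint000 (defined further below):
        // token is the bit length of the residual, nbits = max(token - 1, 0), and
        // bits is the residual with its highest set bit removed.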
   2179 FJXL_INLINE void TokenizeSIMD(const uint16_t* residuals, uint16_t* token_out,
   2180                               uint16_t* nbits_out, uint16_t* bits_out) {
   2181   SIMDVec16 res = SIMDVec16::Load(residuals);
   2182   SIMDVec16 token = res.ValToToken();
   2183   SIMDVec16 nbits = token.SatSubU(SIMDVec16::Val(1));
   2184   SIMDVec16 bits = res.SatSubU(nbits.Pow2());
   2185   token.Store(token_out);
   2186   nbits.Store(nbits_out);
   2187   bits.Store(bits_out);
   2188 }
   2189 
   2190 FJXL_INLINE void TokenizeSIMD(const uint32_t* residuals, uint16_t* token_out,
   2191                               uint32_t* nbits_out, uint32_t* bits_out) {
   2192   static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, "");
   2193   SIMDVec32 res_lo = SIMDVec32::Load(residuals);
   2194   SIMDVec32 res_hi = SIMDVec32::Load(residuals + SIMDVec32::kLanes);
   2195   SIMDVec32 token_lo = res_lo.ValToToken();
   2196   SIMDVec32 token_hi = res_hi.ValToToken();
   2197   SIMDVec32 nbits_lo = token_lo.SatSubU(SIMDVec32::Val(1));
   2198   SIMDVec32 nbits_hi = token_hi.SatSubU(SIMDVec32::Val(1));
   2199   SIMDVec32 bits_lo = res_lo.SatSubU(nbits_lo.Pow2());
   2200   SIMDVec32 bits_hi = res_hi.SatSubU(nbits_hi.Pow2());
   2201   SIMDVec16 token = SIMDVec16::FromTwo32(token_lo, token_hi);
   2202   token.Store(token_out);
   2203   nbits_lo.Store(nbits_out);
   2204   nbits_hi.Store(nbits_out + SIMDVec32::kLanes);
   2205   bits_lo.Store(bits_out);
   2206   bits_hi.Store(bits_out + SIMDVec32::kLanes);
   2207 }
   2208 
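        // Looks up the Huffman (nbits, bits) pair for each token with a single
        // 16-entry byte-table lookup; this path therefore requires tokens <= 15.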
   2209 FJXL_INLINE void HuffmanSIMDUpTo13(const uint16_t* tokens,
   2210                                    const uint8_t* raw_nbits_simd,
   2211                                    const uint8_t* raw_bits_simd,
   2212                                    uint16_t* nbits_out, uint16_t* bits_out) {
   2213   SIMDVec16 tok = SIMDVec16::Load(tokens).PrepareForU8Lookup();
   2214   tok.U8Lookup(raw_nbits_simd).Store(nbits_out);
   2215   tok.U8Lookup(raw_bits_simd).Store(bits_out);
   2216 }
   2217 
   2218 FJXL_INLINE void HuffmanSIMD14(const uint16_t* tokens,
   2219                                const uint8_t* raw_nbits_simd,
   2220                                const uint8_t* raw_bits_simd,
   2221                                uint16_t* nbits_out, uint16_t* bits_out) {
   2222   SIMDVec16 token_cap = SIMDVec16::Val(15);
   2223   SIMDVec16 tok = SIMDVec16::Load(tokens);
   2224   SIMDVec16 tok_index = tok.Min(token_cap).PrepareForU8Lookup();
   2225   SIMDVec16 huff_bits_pre = tok_index.U8Lookup(raw_bits_simd);
   2226   // Set the highest bit when token == 16; the Huffman code is constructed in
   2227   // such a way that the code for token 15 is the same as the code for 16,
   2228   // except for the highest bit.
   2229   Mask16 needs_high_bit = tok.Eq(SIMDVec16::Val(16));
   2230   SIMDVec16 huff_bits = needs_high_bit.IfThenElse(
   2231       huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre);
   2232   huff_bits.Store(bits_out);
   2233   tok_index.U8Lookup(raw_nbits_simd).Store(nbits_out);
   2234 }
   2235 
   2236 FJXL_INLINE void HuffmanSIMDAbove14(const uint16_t* tokens,
   2237                                     const uint8_t* raw_nbits_simd,
   2238                                     const uint8_t* raw_bits_simd,
   2239                                     uint16_t* nbits_out, uint16_t* bits_out) {
   2240   SIMDVec16 tok = SIMDVec16::Load(tokens);
   2241   // We assume `tok` fits in a *signed* 16-bit integer.
   2242   Mask16 above = tok.Gt(SIMDVec16::Val(12));
   2243   // 13, 14 -> 13
   2244   // 15, 16 -> 14
   2245   // 17, 18 -> 15
   2246   SIMDVec16 remap_tok = above.IfThenElse(tok.HAdd(SIMDVec16::Val(13)), tok);
   2247   SIMDVec16 tok_index = remap_tok.PrepareForU8Lookup();
   2248   SIMDVec16 huff_bits_pre = tok_index.U8Lookup(raw_bits_simd);
   2249   // Set the highest bit when token == 14, 16, 18.
   2250   Mask16 needs_high_bit = above.And(tok.Eq(tok.And(SIMDVec16::Val(0xFFFE))));
   2251   SIMDVec16 huff_bits = needs_high_bit.IfThenElse(
   2252       huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre);
   2253   huff_bits.Store(bits_out);
   2254   tok_index.U8Lookup(raw_nbits_simd).Store(nbits_out);
   2255 }
   2256 
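        // Places each value's extra bits above its Huffman code, keeps only lanes
        // in [skip, n), and merges pairs of u16 lanes into u32 lanes.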
   2257 FJXL_INLINE void StoreSIMDUpTo8(const uint16_t* nbits_tok,
   2258                                 const uint16_t* bits_tok,
   2259                                 const uint16_t* nbits_huff,
   2260                                 const uint16_t* bits_huff, size_t n,
   2261                                 size_t skip, Bits32* bits_out) {
   2262   Bits16 bits =
   2263       Bits16::FromRaw(SIMDVec16::Load(nbits_tok), SIMDVec16::Load(bits_tok));
   2264   Bits16 huff_bits =
   2265       Bits16::FromRaw(SIMDVec16::Load(nbits_huff), SIMDVec16::Load(bits_huff));
   2266   bits.Interleave(huff_bits);
   2267   bits.ClipTo(n);
   2268   bits.Skip(skip);
   2269   bits_out[0] = bits.Merge();
   2270 }
   2271 
   2272 // Huffman and raw bits don't necessarily fit in a single u16 here.
   2273 FJXL_INLINE void StoreSIMDUpTo14(const uint16_t* nbits_tok,
   2274                                  const uint16_t* bits_tok,
   2275                                  const uint16_t* nbits_huff,
   2276                                  const uint16_t* bits_huff, size_t n,
   2277                                  size_t skip, Bits32* bits_out) {
   2278   VecPair<SIMDVec16> bits =
   2279       SIMDVec16::Load(bits_tok).Interleave(SIMDVec16::Load(bits_huff));
   2280   VecPair<SIMDVec16> nbits =
   2281       SIMDVec16::Load(nbits_tok).Interleave(SIMDVec16::Load(nbits_huff));
   2282   Bits16 low = Bits16::FromRaw(nbits.low, bits.low);
   2283   Bits16 hi = Bits16::FromRaw(nbits.hi, bits.hi);
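          // After interleaving, each value occupies two adjacent u16 lanes (its
          // Huffman code and its raw bits), hence the factor of 2 below.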
   2284   low.ClipTo(2 * n);
   2285   low.Skip(2 * skip);
   2286   hi.ClipTo(std::max(2 * n, SIMDVec16::kLanes) - SIMDVec16::kLanes);
   2287   hi.Skip(std::max(2 * skip, SIMDVec16::kLanes) - SIMDVec16::kLanes);
   2288 
   2289   bits_out[0] = low.Merge();
   2290   bits_out[1] = hi.Merge();
   2291 }
   2292 
   2293 FJXL_INLINE void StoreSIMDAbove14(const uint32_t* nbits_tok,
   2294                                   const uint32_t* bits_tok,
   2295                                   const uint16_t* nbits_huff,
   2296                                   const uint16_t* bits_huff, size_t n,
   2297                                   size_t skip, Bits32* bits_out) {
   2298   static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, "");
   2299   Bits32 bits_low =
   2300       Bits32::FromRaw(SIMDVec32::Load(nbits_tok), SIMDVec32::Load(bits_tok));
   2301   Bits32 bits_hi =
   2302       Bits32::FromRaw(SIMDVec32::Load(nbits_tok + SIMDVec32::kLanes),
   2303                       SIMDVec32::Load(bits_tok + SIMDVec32::kLanes));
   2304 
   2305   VecPair<SIMDVec32> huff_bits = SIMDVec16::Load(bits_huff).Upcast();
   2306   VecPair<SIMDVec32> huff_nbits = SIMDVec16::Load(nbits_huff).Upcast();
   2307 
   2308   Bits32 huff_low = Bits32::FromRaw(huff_nbits.low, huff_bits.low);
   2309   Bits32 huff_hi = Bits32::FromRaw(huff_nbits.hi, huff_bits.hi);
   2310 
   2311   bits_low.Interleave(huff_low);
   2312   bits_low.ClipTo(n);
   2313   bits_low.Skip(skip);
   2314   bits_out[0] = bits_low;
   2315   bits_hi.Interleave(huff_hi);
   2316   bits_hi.ClipTo(std::max(n, SIMDVec32::kLanes) - SIMDVec32::kLanes);
   2317   bits_hi.Skip(std::max(skip, SIMDVec32::kLanes) - SIMDVec32::kLanes);
   2318   bits_out[1] = bits_hi;
   2319 }
   2320 
   2321 #ifdef FJXL_AVX512
   2322 FJXL_INLINE void StoreToWriterAVX512(const Bits32& bits32, BitWriter& output) {
   2323   __m512i bits = bits32.bits;
   2324   __m512i nbits = bits32.nbits;
   2325 
   2326   // Insert the leftover bits from the bit buffer at the bottom of the vector
   2327   // and extract the top of the vector.
   2328   uint64_t trail_bits =
   2329       _mm512_cvtsi512_si32(_mm512_alignr_epi32(bits, bits, 15));
   2330   uint64_t trail_nbits =
   2331       _mm512_cvtsi512_si32(_mm512_alignr_epi32(nbits, nbits, 15));
   2332   __m512i lead_bits = _mm512_set1_epi32(output.buffer);
   2333   __m512i lead_nbits = _mm512_set1_epi32(output.bits_in_buffer);
   2334   bits = _mm512_alignr_epi32(bits, lead_bits, 15);
   2335   nbits = _mm512_alignr_epi32(nbits, lead_nbits, 15);
   2336 
   2337   // Merge 32 -> 64 bits.
   2338   Bits32 b{nbits, bits};
   2339   Bits64 b64 = b.Merge();
   2340   bits = b64.bits;
   2341   nbits = b64.nbits;
   2342 
   2343   __m512i zero = _mm512_setzero_si512();
   2344 
   2345   auto sh1 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 7); };
   2346   auto sh2 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 6); };
   2347   auto sh4 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 4); };
   2348 
   2349   // Compute first-past-end-bit-position.
   2350   __m512i end_interm0 = _mm512_add_epi64(nbits, sh1(nbits));
   2351   __m512i end_interm1 = _mm512_add_epi64(end_interm0, sh2(end_interm0));
   2352   __m512i end = _mm512_add_epi64(end_interm1, sh4(end_interm1));
   2353 
   2354   uint64_t simd_nbits = _mm512_cvtsi512_si32(_mm512_alignr_epi64(end, end, 7));
   2355 
   2356   // Compute begin-bit-position.
   2357   __m512i begin = _mm512_sub_epi64(end, nbits);
   2358 
   2359   // Index of the last bit in the chunk, or the end bit if nbits==0.
   2360   __m512i last = _mm512_mask_sub_epi64(
   2361       end, _mm512_cmpneq_epi64_mask(nbits, zero), end, _mm512_set1_epi64(1));
   2362 
   2363   __m512i lane_offset_mask = _mm512_set1_epi64(63);
   2364 
   2365   // Starting position of the chunk that each lane will ultimately belong to.
   2366   __m512i chunk_start = _mm512_andnot_si512(lane_offset_mask, last);
   2367 
   2368   // For all lanes that contain bits belonging to two different 64-bit chunks,
   2369   // compute the number of bits that belong to the first chunk.
   2370   // The total number of bits fits in a u16, so we can use satsub_u16 here.
   2371   __m512i first_chunk_nbits = _mm512_subs_epu16(chunk_start, begin);
   2372 
   2373   // Move all the previous-chunk-bits to the previous lane.
   2374   __m512i negnbits = _mm512_sub_epi64(_mm512_set1_epi64(64), first_chunk_nbits);
   2375   __m512i first_chunk_bits =
   2376       _mm512_srlv_epi64(_mm512_sllv_epi64(bits, negnbits), negnbits);
   2377   __m512i first_chunk_bits_down =
   2378       _mm512_alignr_epi32(zero, first_chunk_bits, 2);
   2379   bits = _mm512_srlv_epi64(bits, first_chunk_nbits);
   2380   nbits = _mm512_sub_epi64(nbits, first_chunk_nbits);
   2381   bits = _mm512_or_si512(bits, _mm512_sllv_epi64(first_chunk_bits_down, nbits));
   2382   begin = _mm512_add_epi64(begin, first_chunk_nbits);
   2383 
   2384   // We now know that every lane should give bits to only one chunk. We can
   2385   // shift the bits and then horizontally-or-reduce them within the same chunk.
   2386   __m512i offset = _mm512_and_si512(begin, lane_offset_mask);
   2387   __m512i aligned_bits = _mm512_sllv_epi64(bits, offset);
   2388   // h-or-reduce within same chunk
   2389   __m512i red0 = _mm512_mask_or_epi64(
   2390       aligned_bits, _mm512_cmpeq_epi64_mask(sh1(chunk_start), chunk_start),
   2391       sh1(aligned_bits), aligned_bits);
   2392   __m512i red1 = _mm512_mask_or_epi64(
   2393       red0, _mm512_cmpeq_epi64_mask(sh2(chunk_start), chunk_start), sh2(red0),
   2394       red0);
   2395   __m512i reduced = _mm512_mask_or_epi64(
   2396       red1, _mm512_cmpeq_epi64_mask(sh4(chunk_start), chunk_start), sh4(red1),
   2397       red1);
   2398   // Extract the highest lane that belongs to each chunk (the lane that ends up
   2399   // with the OR-ed value of all the other lanes of that chunk).
   2400   __m512i next_chunk_start =
   2401       _mm512_alignr_epi32(_mm512_set1_epi64(~0), chunk_start, 2);
   2402   __m512i result = _mm512_maskz_compress_epi64(
   2403       _mm512_cmpneq_epi64_mask(chunk_start, next_chunk_start), reduced);
   2404 
   2405   _mm512_storeu_si512((__m512i*)(output.data.get() + output.bytes_written),
   2406                       result);
   2407 
   2408   // Update the bit writer and add the last 32-bit lane.
   2409   // Note that since trail_nbits was at most 32 to begin with, operating on
   2410   // trail_bits does not risk overflowing.
   2411   output.bytes_written += simd_nbits / 8;
   2412   // Here we are implicitly relying on the fact that simd_nbits < 512 to know
   2413   // that the byte of bit writer data we access is initialized. This is
   2414   // guaranteed because the remaining bits in the bit writer's buffer are at
   2415   // most 7, so simd_nbits <= 505 always.
   2416   trail_bits = (trail_bits << (simd_nbits % 8)) +
   2417                output.data.get()[output.bytes_written];
   2418   trail_nbits += simd_nbits % 8;
   2419   StoreLE64(output.data.get() + output.bytes_written, trail_bits);
   2420   size_t trail_bytes = trail_nbits / 8;
   2421   output.bits_in_buffer = trail_nbits % 8;
   2422   output.buffer = trail_bits >> (trail_bytes * 8);
   2423   output.bytes_written += trail_bytes;
   2424 }
   2425 
   2426 #endif
   2427 
   2428 template <size_t n>
   2429 FJXL_INLINE void StoreToWriter(const Bits32* bits, BitWriter& output) {
   2430 #ifdef FJXL_AVX512
   2431   static_assert(n <= 2, "");
   2432   StoreToWriterAVX512(bits[0], output);
   2433   if (n == 2) {
   2434     StoreToWriterAVX512(bits[1], output);
   2435   }
   2436   return;
   2437 #endif
   2438   static_assert(n <= 4, "");
   2439   alignas(64) uint64_t nbits64[Bits64::kLanes * n];
   2440   alignas(64) uint64_t bits64[Bits64::kLanes * n];
   2441   bits[0].Merge().Store(nbits64, bits64);
   2442   if (n > 1) {
   2443     bits[1].Merge().Store(nbits64 + Bits64::kLanes, bits64 + Bits64::kLanes);
   2444   }
   2445   if (n > 2) {
   2446     bits[2].Merge().Store(nbits64 + 2 * Bits64::kLanes,
   2447                           bits64 + 2 * Bits64::kLanes);
   2448   }
   2449   if (n > 3) {
   2450     bits[3].Merge().Store(nbits64 + 3 * Bits64::kLanes,
   2451                           bits64 + 3 * Bits64::kLanes);
   2452   }
   2453   output.WriteMultiple(nbits64, bits64, Bits64::kLanes * n);
   2454 }
   2455 
   2456 namespace detail {
   2457 template <typename T>
   2458 struct IntegerTypes;
   2459 
   2460 template <>
   2461 struct IntegerTypes<SIMDVec16> {
   2462   using signed_ = int16_t;
   2463   using unsigned_ = uint16_t;
   2464 };
   2465 
   2466 template <>
   2467 struct IntegerTypes<SIMDVec32> {
   2468   using signed_ = int32_t;
   2469   using unsigned_ = uint32_t;
   2470 };
   2471 
   2472 template <typename T>
   2473 struct SIMDType;
   2474 
   2475 template <>
   2476 struct SIMDType<int16_t> {
   2477   using type = SIMDVec16;
   2478 };
   2479 
   2480 template <>
   2481 struct SIMDType<int32_t> {
   2482   using type = SIMDVec32;
   2483 };
   2484 
   2485 }  // namespace detail
   2486 
   2487 template <typename T>
   2488 using signed_t = typename detail::IntegerTypes<T>::signed_;
   2489 
   2490 template <typename T>
   2491 using unsigned_t = typename detail::IntegerTypes<T>::unsigned_;
   2492 
   2493 template <typename T>
   2494 using simd_t = typename detail::SIMDType<T>::type;
   2495 
   2496 // This function will process exactly one vector worth of pixels.
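        // It predicts each lane with the ClampedGradient predictor (left + top -
        // topleft, clamped to [min(left, top), max(left, top)]), maps the signed
        // residual to an unsigned value (v >= 0 -> 2v, v < 0 -> -2v - 1), and
        // returns the number of leading lanes with a zero residual, which the
        // caller can use to detect runs.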
   2497 
   2498 template <typename T>
   2499 size_t PredictPixels(const signed_t<T>* pixels, const signed_t<T>* pixels_left,
   2500                      const signed_t<T>* pixels_top,
   2501                      const signed_t<T>* pixels_topleft,
   2502                      unsigned_t<T>* residuals) {
   2503   T px = T::Load((unsigned_t<T>*)pixels);
   2504   T left = T::Load((unsigned_t<T>*)pixels_left);
   2505   T top = T::Load((unsigned_t<T>*)pixels_top);
   2506   T topleft = T::Load((unsigned_t<T>*)pixels_topleft);
   2507   T ac = left.Sub(topleft);
   2508   T ab = left.Sub(top);
   2509   T bc = top.Sub(topleft);
   2510   T grad = ac.Add(top);
   2511   T d = ab.Xor(bc);
   2512   T zero = T::Val(0);
   2513   T clamp = zero.Gt(d).IfThenElse(top, left);
   2514   T s = ac.Xor(bc);
   2515   T pred = zero.Gt(s).IfThenElse(grad, clamp);
   2516   T res = px.Sub(pred);
   2517   T res_times_2 = res.Add(res);
   2518   res = zero.Gt(res).IfThenElse(T::Val(-1).Sub(res_times_2), res_times_2);
   2519   res.Store(residuals);
   2520   return res.Eq(T::Val(0)).CountPrefix();
   2521 }
   2522 
   2523 #endif
   2524 
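        // For example, a value of 5 (0b101) encodes as token = 3, nbits = 2,
        // bits = 0b01.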
   2525 void EncodeHybridUint000(uint32_t value, uint32_t* token, uint32_t* nbits,
   2526                          uint32_t* bits) {
   2527   uint32_t n = FloorLog2(value);
   2528   *token = value ? n + 1 : 0;
   2529   *nbits = value ? n : 0;
   2530   *bits = value ? value - (1 << n) : 0;
   2531 }
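        // Worked examples of the 000 hybrid-uint encoding above (derived from
        // the code, for illustration only):
        //   value 0  -> token 0, nbits 0, bits 0
        //   value 1  -> token 1, nbits 0, bits 0
        //   value 5  -> token 3, nbits 2, bits 1    (5  = 0b101)
        //   value 12 -> token 4, nbits 3, bits 4    (12 = 0b1100)
        // i.e. the token is 1 + floor(log2(value)) and the bits below the top
        // one are emitted verbatim.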
   2532 
   2533 #ifdef FJXL_AVX512
   2534 constexpr static size_t kLogChunkSize = 5;
   2535 #elif defined(FJXL_AVX2) || defined(FJXL_NEON)
   2536 // Even though NEON only has 128-bit lanes, it is still significantly (~1.3x)
   2537 // faster to process two vectors at a time.
   2538 constexpr static size_t kLogChunkSize = 4;
   2539 #else
   2540 constexpr static size_t kLogChunkSize = 3;
   2541 #endif
   2542 
   2543 constexpr static size_t kChunkSize = 1 << kLogChunkSize;
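        // With the definitions above, a chunk is 32 pixels when AVX-512 is
        // enabled, 16 pixels with AVX2 or NEON, and 8 pixels in the scalar
        // fallback.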
   2544 
   2545 template <typename Residual>
   2546 void GenericEncodeChunk(const Residual* residuals, size_t n, size_t skip,
   2547                         const PrefixCode& code, BitWriter& output) {
   2548   for (size_t ix = skip; ix < n; ix++) {
   2549     unsigned token, nbits, bits;
   2550     EncodeHybridUint000(residuals[ix], &token, &nbits, &bits);
   2551     output.Write(code.raw_nbits[token] + nbits,
   2552                  code.raw_bits[token] | bits << code.raw_nbits[token]);
   2553   }
   2554 }
   2555 
   2556 struct UpTo8Bits {
   2557   size_t bitdepth;
   2558   explicit UpTo8Bits(size_t bitdepth) : bitdepth(bitdepth) {
   2559     assert(bitdepth <= 8);
   2560   }
   2561   // Here we can fit up to 9 extra bits + 7 Huffman bits in a u16; for all other
   2562   // symbols, we could actually go up to 8 Huffman bits, as we have at most 8
   2563   // extra bits; however, the SIMD bit-merging logic for AVX2 assumes that no
   2564   // Huffman length is 8 or more, so we cap raw lengths at 7 anyway (see
   2565   // kMaxRawLength). The last symbol is used for LZ77 lengths and has no
   2566   // limitation other than that the code must still represent 32 symbols in total.
   2567   static constexpr uint8_t kMinRawLength[12] = {};
   2568   static constexpr uint8_t kMaxRawLength[12] = {
   2569       7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 10,
   2570   };
   2571   static size_t MaxEncodedBitsPerSample() { return 16; }
   2572   static constexpr size_t kInputBytes = 1;
   2573   using pixel_t = int16_t;
   2574   using upixel_t = uint16_t;
   2575 
   2576   static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
   2577                              size_t n, uint8_t* nbits_simd,
   2578                              uint8_t* bits_simd) {
   2579     assert(n <= 16);
   2580     memcpy(nbits_simd, nbits, 16);
   2581     memcpy(bits_simd, bits, 16);
   2582   }
   2583 
   2584 #ifdef FJXL_GENERIC_SIMD
   2585   static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip,
   2586                               const uint8_t* raw_nbits_simd,
   2587                               const uint8_t* raw_bits_simd, BitWriter& output) {
   2588     Bits32 bits32[kChunkSize / SIMDVec16::kLanes];
   2589     alignas(64) uint16_t bits[SIMDVec16::kLanes];
   2590     alignas(64) uint16_t nbits[SIMDVec16::kLanes];
   2591     alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
   2592     alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
   2593     alignas(64) uint16_t token[SIMDVec16::kLanes];
   2594     for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
   2595       TokenizeSIMD(residuals + i, token, nbits, bits);
   2596       HuffmanSIMDUpTo13(token, raw_nbits_simd, raw_bits_simd, nbits_huff,
   2597                         bits_huff);
   2598       StoreSIMDUpTo8(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
   2599                      std::max(skip, i) - i, bits32 + i / SIMDVec16::kLanes);
   2600     }
   2601     StoreToWriter<kChunkSize / SIMDVec16::kLanes>(bits32, output);
   2602   }
   2603 #endif
   2604 
   2605   size_t NumSymbols(bool doing_ycocg_or_large_palette) const {
   2606     // values gain 1 bit for YCoCg, 1 bit for prediction.
   2607     // Maximum symbol is 1 + effective bit depth of residuals.
   2608     if (doing_ycocg_or_large_palette) {
   2609       return bitdepth + 3;
   2610     } else {
   2611       return bitdepth + 2;
   2612     }
   2613   }
   2614 };
   2615 constexpr uint8_t UpTo8Bits::kMinRawLength[];
   2616 constexpr uint8_t UpTo8Bits::kMaxRawLength[];
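        // Bit-budget sanity check for UpTo8Bits, restating the comment inside
        // the struct: raw Huffman lengths are capped at 7 bits (kMaxRawLength)
        // and the largest raw token carries at most 9 extra bits, so a single
        // sample needs at most 7 + 9 = 16 bits, which is what
        // MaxEncodedBitsPerSample() returns.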
   2617 
   2618 struct From9To13Bits {
   2619   size_t bitdepth;
   2620   explicit From9To13Bits(size_t bitdepth) : bitdepth(bitdepth) {
   2621     assert(bitdepth <= 13 && bitdepth >= 9);
   2622   }
   2623   // The last symbol is used for LZ77 lengths and has no limitation other than
   2624   // that the code must still represent 32 symbols in total.
   2625   // We cannot fit all the bits in a u16, so we do not even try; instead we use
   2626   // up to 8 bits per raw symbol.
   2627   // There are at most 16 raw symbols, so Huffman coding can be SIMDfied without
   2628   // any special tricks.
   2629   static constexpr uint8_t kMinRawLength[17] = {};
   2630   static constexpr uint8_t kMaxRawLength[17] = {
   2631       8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 10,
   2632   };
   2633   static size_t MaxEncodedBitsPerSample() { return 21; }
   2634   static constexpr size_t kInputBytes = 2;
   2635   using pixel_t = int16_t;
   2636   using upixel_t = uint16_t;
   2637 
   2638   static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
   2639                              size_t n, uint8_t* nbits_simd,
   2640                              uint8_t* bits_simd) {
   2641     assert(n <= 16);
   2642     memcpy(nbits_simd, nbits, 16);
   2643     memcpy(bits_simd, bits, 16);
   2644   }
   2645 
   2646 #ifdef FJXL_GENERIC_SIMD
   2647   static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip,
   2648                               const uint8_t* raw_nbits_simd,
   2649                               const uint8_t* raw_bits_simd, BitWriter& output) {
   2650     Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes];
   2651     alignas(64) uint16_t bits[SIMDVec16::kLanes];
   2652     alignas(64) uint16_t nbits[SIMDVec16::kLanes];
   2653     alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
   2654     alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
   2655     alignas(64) uint16_t token[SIMDVec16::kLanes];
   2656     for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
   2657       TokenizeSIMD(residuals + i, token, nbits, bits);
   2658       HuffmanSIMDUpTo13(token, raw_nbits_simd, raw_bits_simd, nbits_huff,
   2659                         bits_huff);
   2660       StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
   2661                       std::max(skip, i) - i,
   2662                       bits32 + 2 * i / SIMDVec16::kLanes);
   2663     }
   2664     StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output);
   2665   }
   2666 #endif
   2667 
   2668   size_t NumSymbols(bool doing_ycocg_or_large_palette) const {
   2669     // values gain 1 bit for YCoCg, 1 bit for prediction.
   2670     // Maximum symbol is 1 + effective bit depth of residuals.
   2671     if (doing_ycocg_or_large_palette) {
   2672       return bitdepth + 3;
   2673     } else {
   2674       return bitdepth + 2;
   2675     }
   2676   }
   2677 };
   2678 constexpr uint8_t From9To13Bits::kMinRawLength[];
   2679 constexpr uint8_t From9To13Bits::kMaxRawLength[];
   2680 
   2681 void CheckHuffmanBitsSIMD(int bits1, int nbits1, int bits2, int nbits2) {
   2682   assert(nbits1 == 8);
   2683   assert(nbits2 == 8);
   2684   assert(bits2 == (bits1 | 128));
   2685 }
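        // Illustration of the check above: both symbols must have 8-bit codes
        // that differ only in bit 7 (0x80), e.g. bits1 = 0b01011011 forces
        // bits2 = 0b11011011. This matches the struct comments below: paired
        // symbols get representations identical up to one bit, so the 16-entry
        // SIMD Huffman tables built in PrepareForSimd only need to store one
        // entry per pair.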
   2686 
   2687 struct Exactly14Bits {
   2688   explicit Exactly14Bits(size_t bitdepth) { assert(bitdepth == 14); }
   2689   // Force LZ77 symbols to have at least 8 bits, and raw symbols 15 and 16 to
   2690   // have exactly 8, and no other symbol to have 8 or more. This ensures that
   2691   // the representation for 15 and 16 is identical up to one bit.
   2692   static constexpr uint8_t kMinRawLength[18] = {
   2693       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 7,
   2694   };
   2695   static constexpr uint8_t kMaxRawLength[18] = {
   2696       7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 10,
   2697   };
   2698   static constexpr size_t bitdepth = 14;
   2699   static size_t MaxEncodedBitsPerSample() { return 22; }
   2700   static constexpr size_t kInputBytes = 2;
   2701   using pixel_t = int16_t;
   2702   using upixel_t = uint16_t;
   2703 
   2704   static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
   2705                              size_t n, uint8_t* nbits_simd,
   2706                              uint8_t* bits_simd) {
   2707     assert(n == 17);
   2708     CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]);
   2709     memcpy(nbits_simd, nbits, 16);
   2710     memcpy(bits_simd, bits, 16);
   2711   }
   2712 
   2713 #ifdef FJXL_GENERIC_SIMD
   2714   static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip,
   2715                               const uint8_t* raw_nbits_simd,
   2716                               const uint8_t* raw_bits_simd, BitWriter& output) {
   2717     Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes];
   2718     alignas(64) uint16_t bits[SIMDVec16::kLanes];
   2719     alignas(64) uint16_t nbits[SIMDVec16::kLanes];
   2720     alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
   2721     alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
   2722     alignas(64) uint16_t token[SIMDVec16::kLanes];
   2723     for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
   2724       TokenizeSIMD(residuals + i, token, nbits, bits);
   2725       HuffmanSIMD14(token, raw_nbits_simd, raw_bits_simd, nbits_huff,
   2726                     bits_huff);
   2727       StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
   2728                       std::max(skip, i) - i,
   2729                       bits32 + 2 * i / SIMDVec16::kLanes);
   2730     }
   2731     StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output);
   2732   }
   2733 #endif
   2734 
   2735   size_t NumSymbols(bool) const { return 17; }
   2736 };
   2737 constexpr uint8_t Exactly14Bits::kMinRawLength[];
   2738 constexpr uint8_t Exactly14Bits::kMaxRawLength[];
   2739 
   2740 struct MoreThan14Bits {
   2741   size_t bitdepth;
   2742   explicit MoreThan14Bits(size_t bitdepth) : bitdepth(bitdepth) {
   2743     assert(bitdepth > 14);
   2744     assert(bitdepth <= 16);
   2745   }
   2746   // Force LZ77 symbols to have at least 8 bits, and raw symbols 13 to 18 to
   2747   // have exactly 8, and no other symbol to have 8 or more. This ensures that
   2748   // the representation for (13, 14), (15, 16), (17, 18) is identical up to one
   2749   // bit.
   2750   static constexpr uint8_t kMinRawLength[20] = {
   2751       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 7,
   2752   };
   2753   static constexpr uint8_t kMaxRawLength[20] = {
   2754       7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 10,
   2755   };
   2756   static size_t MaxEncodedBitsPerSample() { return 24; }
   2757   static constexpr size_t kInputBytes = 2;
   2758   using pixel_t = int32_t;
   2759   using upixel_t = uint32_t;
   2760 
   2761   static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
   2762                              size_t n, uint8_t* nbits_simd,
   2763                              uint8_t* bits_simd) {
   2764     assert(n == 19);
   2765     CheckHuffmanBitsSIMD(bits[13], nbits[13], bits[14], nbits[14]);
   2766     CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]);
   2767     CheckHuffmanBitsSIMD(bits[17], nbits[17], bits[18], nbits[18]);
   2768     for (size_t i = 0; i < 14; i++) {
   2769       nbits_simd[i] = nbits[i];
   2770       bits_simd[i] = bits[i];
   2771     }
   2772     nbits_simd[14] = nbits[15];
   2773     bits_simd[14] = bits[15];
   2774     nbits_simd[15] = nbits[17];
   2775     bits_simd[15] = bits[17];
   2776   }
   2777 
   2778 #ifdef FJXL_GENERIC_SIMD
   2779   static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip,
   2780                               const uint8_t* raw_nbits_simd,
   2781                               const uint8_t* raw_bits_simd, BitWriter& output) {
   2782     Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes];
   2783     alignas(64) uint32_t bits[SIMDVec16::kLanes];
   2784     alignas(64) uint32_t nbits[SIMDVec16::kLanes];
   2785     alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
   2786     alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
   2787     alignas(64) uint16_t token[SIMDVec16::kLanes];
   2788     for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
   2789       TokenizeSIMD(residuals + i, token, nbits, bits);
   2790       HuffmanSIMDAbove14(token, raw_nbits_simd, raw_bits_simd, nbits_huff,
   2791                          bits_huff);
   2792       StoreSIMDAbove14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
   2793                        std::max(skip, i) - i,
   2794                        bits32 + 2 * i / SIMDVec16::kLanes);
   2795     }
   2796     StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output);
   2797   }
   2798 #endif
   2799   size_t NumSymbols(bool) const { return 19; }
   2800 };
   2801 constexpr uint8_t MoreThan14Bits::kMinRawLength[];
   2802 constexpr uint8_t MoreThan14Bits::kMaxRawLength[];
   2803 
   2804 void PrepareDCGlobalCommon(bool is_single_group, size_t width, size_t height,
   2805                            const PrefixCode code[4], BitWriter* output) {
   2806   output->Allocate(100000 + (is_single_group ? width * height * 16 : 0));
   2807   // No patches, splines, or noise.
   2808   output->Write(1, 1);  // default DC dequantization factors (?)
   2809   output->Write(1, 1);  // use global tree / histograms
   2810   output->Write(1, 0);  // no lz77 for the tree
   2811 
   2812   output->Write(1, 1);         // simple code for the tree's context map
   2813   output->Write(2, 0);         // all contexts clustered together
   2814   output->Write(1, 1);         // use prefix code for tree
   2815   output->Write(4, 0);         // 000 hybrid uint
   2816   output->Write(6, 0b100011);  // Alphabet size is 4 (var16)
   2817   output->Write(2, 1);         // simple prefix code
   2818   output->Write(2, 3);         // with 4 symbols
   2819   output->Write(2, 0);
   2820   output->Write(2, 1);
   2821   output->Write(2, 2);
   2822   output->Write(2, 3);
   2823   output->Write(1, 0);  // First tree encoding option
   2824 
   2825   // Huffman table + extra bits for the tree.
   2826   uint8_t symbol_bits[6] = {0b00, 0b10, 0b001, 0b101, 0b0011, 0b0111};
   2827   uint8_t symbol_nbits[6] = {2, 2, 3, 3, 4, 4};
   2828   // Write a tree with a leaf per channel, and gradient predictor for every
   2829   // leaf.
   2830   for (auto v : {1, 2, 1, 4, 1, 0, 0, 5, 0, 0, 0, 0, 5,
   2831                  0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0}) {
   2832     output->Write(symbol_nbits[v], symbol_bits[v]);
   2833   }
   2834 
   2835   output->Write(1, 1);     // Enable lz77 for the main bitstream
   2836   output->Write(2, 0b00);  // lz77 offset 224
   2837   static_assert(kLZ77Offset == 224, "");
   2838   output->Write(4, 0b1010);  // lz77 min length 7
   2839   // 400 hybrid uint config for lz77
   2840   output->Write(4, 4);
   2841   output->Write(3, 0);
   2842   output->Write(3, 0);
   2843 
   2844   output->Write(1, 1);  // simple code for the context map
   2845   output->Write(2, 3);  // 3 bits per entry
   2846   output->Write(3, 4);  // channel 3
   2847   output->Write(3, 3);  // channel 2
   2848   output->Write(3, 2);  // channel 1
   2849   output->Write(3, 1);  // channel 0
   2850   output->Write(3, 0);  // distance histogram first
   2851 
   2852   output->Write(1, 1);  // use prefix codes
   2853   output->Write(4, 0);  // 000 hybrid uint config for distances (only need 0)
   2854   for (size_t i = 0; i < 4; i++) {
   2855     output->Write(4, 0);  // 000 hybrid uint config for symbols (only <= 10)
   2856   }
   2857 
   2858   // Distance alphabet size:
   2859   output->Write(5, 0b00001);  // 2: just need 1 for RLE (i.e. distance 1)
   2860   // Symbol + LZ77 alphabet size:
   2861   for (size_t i = 0; i < 4; i++) {
   2862     output->Write(1, 1);    // > 1
   2863     output->Write(4, 8);    // <= 512
   2864     output->Write(8, 256);  // == 512
   2865   }
   2866 
   2867   // Distance histogram:
   2868   output->Write(2, 1);  // simple prefix code
   2869   output->Write(2, 0);  // with one symbol
   2870   output->Write(1, 1);  // 1
   2871 
   2872   // Symbol + lz77 histogram:
   2873   for (size_t i = 0; i < 4; i++) {
   2874     code[i].WriteTo(output);
   2875   }
   2876 
   2877   // Group header for global modular image.
   2878   output->Write(1, 1);  // Global tree
   2879   output->Write(1, 1);  // All default wp
   2880 }
   2881 
   2882 void PrepareDCGlobal(bool is_single_group, size_t width, size_t height,
   2883                      size_t nb_chans, const PrefixCode code[4],
   2884                      BitWriter* output) {
   2885   PrepareDCGlobalCommon(is_single_group, width, height, code, output);
   2886   if (nb_chans > 2) {
   2887     output->Write(2, 0b01);     // 1 transform
   2888     output->Write(2, 0b00);     // RCT
   2889     output->Write(5, 0b00000);  // Starting from ch 0
   2890     output->Write(2, 0b00);     // YCoCg
   2891   } else {
   2892     output->Write(2, 0b00);  // no transforms
   2893   }
   2894   if (!is_single_group) {
   2895     output->ZeroPadToByte();
   2896   }
   2897 }
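        // Note: the RCT signalled above corresponds to the forward YCoCg
        // conversion applied per row by StoreYCoCg() / FillRowRGB8 /
        // FillRowRGB16 further down; the decoder inverts it, so RGB samples
        // round-trip exactly.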
   2898 
   2899 template <typename BitDepth>
   2900 struct ChunkEncoder {
   2901   void PrepareForSimd() {
   2902     BitDepth::PrepareForSimd(code->raw_nbits, code->raw_bits, code->numraw,
   2903                              raw_nbits_simd, raw_bits_simd);
   2904   }
   2905   FJXL_INLINE static void EncodeRle(size_t count, const PrefixCode& code,
   2906                                     BitWriter& output) {
   2907     if (count == 0) return;
   2908     count -= kLZ77MinLength + 1;
   2909     if (count < kLZ77CacheSize) {
   2910       output.Write(code.lz77_cache_nbits[count], code.lz77_cache_bits[count]);
   2911     } else {
   2912       unsigned token, nbits, bits;
   2913       EncodeHybridUintLZ77(count, &token, &nbits, &bits);
   2914       uint64_t wbits = bits;
   2915       wbits = (wbits << code.lz77_nbits[token]) | code.lz77_bits[token];
   2916       wbits = (wbits << code.raw_nbits[0]) | code.raw_bits[0];
   2917       output.Write(code.lz77_nbits[token] + nbits + code.raw_nbits[0], wbits);
   2918     }
   2919   }
   2920 
   2921   FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals,
   2922                          size_t skip, size_t n) {
   2923     EncodeRle(run, *code, *output);
   2924 #ifdef FJXL_GENERIC_SIMD
   2925     BitDepth::EncodeChunkSimd(residuals, n, skip, raw_nbits_simd, raw_bits_simd,
   2926                               *output);
   2927 #else
   2928     GenericEncodeChunk(residuals, n, skip, *code, *output);
   2929 #endif
   2930   }
   2931 
   2932   inline void Finalize(size_t run) { EncodeRle(run, *code, *output); }
   2933 
   2934   const PrefixCode* code;
   2935   BitWriter* output;
   2936   alignas(64) uint8_t raw_nbits_simd[16] = {};
   2937   alignas(64) uint8_t raw_bits_simd[16] = {};
   2938 };
   2939 
   2940 template <typename BitDepth>
   2941 struct ChunkSampleCollector {
   2942   FJXL_INLINE void Rle(size_t count, uint64_t* lz77_counts) {
   2943     if (count == 0) return;
   2944     raw_counts[0] += 1;
   2945     count -= kLZ77MinLength + 1;
   2946     unsigned token, nbits, bits;
   2947     EncodeHybridUintLZ77(count, &token, &nbits, &bits);
   2948     lz77_counts[token]++;
   2949   }
   2950 
   2951   FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals,
   2952                          size_t skip, size_t n) {
   2953     // Run is broken. Count the run and count the individual residuals.
   2954     Rle(run, lz77_counts);
   2955     for (size_t ix = skip; ix < n; ix++) {
   2956       unsigned token, nbits, bits;
   2957       EncodeHybridUint000(residuals[ix], &token, &nbits, &bits);
   2958       raw_counts[token]++;
   2959     }
   2960   }
   2961 
   2962   // don't count final run since we don't know how long it really is
   2963   void Finalize(size_t run) {}
   2964 
   2965   uint64_t* raw_counts;
   2966   uint64_t* lz77_counts;
   2967 };
   2968 
   2969 constexpr uint32_t PackSigned(int32_t value) {
   2970   return (static_cast<uint32_t>(value) << 1) ^
   2971          ((static_cast<uint32_t>(~value) >> 31) - 1);
   2972 }
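        // PackSigned is the usual zigzag mapping; a few values for reference:
        //   0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4, ...
        // i.e. non-negative v maps to 2 * v and negative v maps to -2 * v - 1,
        // the same mapping applied inside PredictPixels above.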
   2973 
   2974 template <typename T, typename BitDepth>
   2975 struct ChannelRowProcessor {
   2976   using upixel_t = typename BitDepth::upixel_t;
   2977   using pixel_t = typename BitDepth::pixel_t;
   2978   T* t;
   2979   void ProcessChunk(const pixel_t* row, const pixel_t* row_left,
   2980                     const pixel_t* row_top, const pixel_t* row_topleft,
   2981                     size_t n) {
   2982     alignas(64) upixel_t residuals[kChunkSize] = {};
   2983     size_t prefix_size = 0;
   2984     size_t required_prefix_size = 0;
   2985 #ifdef FJXL_GENERIC_SIMD
   2986     constexpr size_t kNum =
   2987         sizeof(pixel_t) == 2 ? SIMDVec16::kLanes : SIMDVec32::kLanes;
   2988     for (size_t ix = 0; ix < kChunkSize; ix += kNum) {
   2989       size_t c =
   2990           PredictPixels<simd_t<pixel_t>>(row + ix, row_left + ix, row_top + ix,
   2991                                          row_topleft + ix, residuals + ix);
   2992       prefix_size =
   2993           prefix_size == required_prefix_size ? prefix_size + c : prefix_size;
   2994       required_prefix_size += kNum;
   2995     }
   2996 #else
   2997     for (size_t ix = 0; ix < kChunkSize; ix++) {
   2998       pixel_t px = row[ix];
   2999       pixel_t left = row_left[ix];
   3000       pixel_t top = row_top[ix];
   3001       pixel_t topleft = row_topleft[ix];
   3002       pixel_t ac = left - topleft;
   3003       pixel_t ab = left - top;
   3004       pixel_t bc = top - topleft;
   3005       pixel_t grad = static_cast<pixel_t>(static_cast<upixel_t>(ac) +
   3006                                           static_cast<upixel_t>(top));
   3007       pixel_t d = ab ^ bc;
   3008       pixel_t clamp = d < 0 ? top : left;
   3009       pixel_t s = ac ^ bc;
   3010       pixel_t pred = s < 0 ? grad : clamp;
   3011       residuals[ix] = PackSigned(px - pred);
   3012       prefix_size = prefix_size == required_prefix_size
   3013                         ? prefix_size + (residuals[ix] == 0)
   3014                         : prefix_size;
   3015       required_prefix_size += 1;
   3016     }
   3017 #endif
   3018     prefix_size = std::min(n, prefix_size);
   3019     if (prefix_size == n && (run > 0 || prefix_size > kLZ77MinLength)) {
   3020       // Run continues, nothing to do.
   3021       run += prefix_size;
   3022     } else if (prefix_size + run > kLZ77MinLength) {
   3023       // Run is broken. Encode the run and encode the individual vector.
   3024       t->Chunk(run + prefix_size, residuals, prefix_size, n);
   3025       run = 0;
   3026     } else {
   3027       // There was no run to begin with.
   3028       t->Chunk(0, residuals, 0, n);
   3029     }
   3030   }
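          // A sketch of the run handling above (kLZ77MinLength is the minimum
          // LZ77 length signalled in PrepareDCGlobalCommon): if the whole chunk
          // is zero residuals and a run is already open (or the prefix alone is
          // long enough to start one), the chunk just extends the run and
          // nothing is emitted yet; otherwise, if run + prefix exceeds
          // kLZ77MinLength, the accumulated run is emitted as a single RLE
          // match (LZ77 with distance 1) and the remaining residuals are coded
          // individually; otherwise the whole chunk is coded symbol by symbol
          // with no run.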
   3031 
   3032   void ProcessRow(const pixel_t* row, const pixel_t* row_left,
   3033                   const pixel_t* row_top, const pixel_t* row_topleft,
   3034                   size_t xs) {
   3035     for (size_t x = 0; x < xs; x += kChunkSize) {
   3036       ProcessChunk(row + x, row_left + x, row_top + x, row_topleft + x,
   3037                    std::min(kChunkSize, xs - x));
   3038     }
   3039   }
   3040 
   3041   void Finalize() { t->Finalize(run); }
   3042   // Invariant: run == 0 or run > kLZ77MinLength.
   3043   size_t run = 0;
   3044 };
   3045 
   3046 uint16_t LoadLE16(const unsigned char* ptr) {
   3047   return uint16_t{ptr[0]} | (uint16_t{ptr[1]} << 8);
   3048 }
   3049 
   3050 uint16_t SwapEndian(uint16_t in) { return (in >> 8) | (in << 8); }
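        // E.g. LoadLE16 on bytes {0x34, 0x12} yields 0x1234, and
        // SwapEndian(0x1234) == 0x3412; the latter is used for big-endian
        // 16-bit input below.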
   3051 
   3052 #ifdef FJXL_GENERIC_SIMD
   3053 void StorePixels(SIMDVec16 p, int16_t* dest) { p.Store((uint16_t*)dest); }
   3054 
   3055 void StorePixels(SIMDVec16 p, int32_t* dest) {
   3056   VecPair<SIMDVec32> p_up = p.Upcast();
   3057   p_up.low.Store((uint32_t*)dest);
   3058   p_up.hi.Store((uint32_t*)dest + SIMDVec32::kLanes);
   3059 }
   3060 #endif
   3061 
   3062 template <typename pixel_t>
   3063 void FillRowG8(const unsigned char* rgba, size_t oxs, pixel_t* luma) {
   3064   size_t x = 0;
   3065 #ifdef FJXL_GENERIC_SIMD
   3066   for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3067     auto rgb = SIMDVec16::LoadG8(rgba + x);
   3068     StorePixels(rgb[0], luma + x);
   3069   }
   3070 #endif
   3071   for (; x < oxs; x++) {
   3072     luma[x] = rgba[x];
   3073   }
   3074 }
   3075 
   3076 template <bool big_endian, typename pixel_t>
   3077 void FillRowG16(const unsigned char* rgba, size_t oxs, pixel_t* luma) {
   3078   size_t x = 0;
   3079 #ifdef FJXL_GENERIC_SIMD
   3080   for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3081     auto rgb = SIMDVec16::LoadG16(rgba + 2 * x);
   3082     if (big_endian) {
   3083       rgb[0].SwapEndian();
   3084     }
   3085     StorePixels(rgb[0], luma + x);
   3086   }
   3087 #endif
   3088   for (; x < oxs; x++) {
   3089     uint16_t val = LoadLE16(rgba + 2 * x);
   3090     if (big_endian) {
   3091       val = SwapEndian(val);
   3092     }
   3093     luma[x] = val;
   3094   }
   3095 }
   3096 
   3097 template <typename pixel_t>
   3098 void FillRowGA8(const unsigned char* rgba, size_t oxs, pixel_t* luma,
   3099                 pixel_t* alpha) {
   3100   size_t x = 0;
   3101 #ifdef FJXL_GENERIC_SIMD
   3102   for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3103     auto rgb = SIMDVec16::LoadGA8(rgba + 2 * x);
   3104     StorePixels(rgb[0], luma + x);
   3105     StorePixels(rgb[1], alpha + x);
   3106   }
   3107 #endif
   3108   for (; x < oxs; x++) {
   3109     luma[x] = rgba[2 * x];
   3110     alpha[x] = rgba[2 * x + 1];
   3111   }
   3112 }
   3113 
   3114 template <bool big_endian, typename pixel_t>
   3115 void FillRowGA16(const unsigned char* rgba, size_t oxs, pixel_t* luma,
   3116                  pixel_t* alpha) {
   3117   size_t x = 0;
   3118 #ifdef FJXL_GENERIC_SIMD
   3119   for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3120     auto rgb = SIMDVec16::LoadGA16(rgba + 4 * x);
   3121     if (big_endian) {
   3122       rgb[0].SwapEndian();
   3123       rgb[1].SwapEndian();
   3124     }
   3125     StorePixels(rgb[0], luma + x);
   3126     StorePixels(rgb[1], alpha + x);
   3127   }
   3128 #endif
   3129   for (; x < oxs; x++) {
   3130     uint16_t l = LoadLE16(rgba + 4 * x);
   3131     uint16_t a = LoadLE16(rgba + 4 * x + 2);
   3132     if (big_endian) {
   3133       l = SwapEndian(l);
   3134       a = SwapEndian(a);
   3135     }
   3136     luma[x] = l;
   3137     alpha[x] = a;
   3138   }
   3139 }
   3140 
   3141 template <typename pixel_t>
   3142 void StoreYCoCg(pixel_t r, pixel_t g, pixel_t b, pixel_t* y, pixel_t* co,
   3143                 pixel_t* cg) {
   3144   *co = r - b;
   3145   pixel_t tmp = b + (*co >> 1);
   3146   *cg = g - tmp;
   3147   *y = tmp + (*cg >> 1);
   3148 }
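        // This is the reversible (lifting-based) YCoCg transform. As a quick
        // sanity check, a gray pixel r == g == b == v maps to
        // (y, co, cg) == (v, 0, 0): co = 0, tmp = v, cg = 0, y = v.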
   3149 
   3150 #ifdef FJXL_GENERIC_SIMD
   3151 void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int16_t* y, int16_t* co,
   3152                 int16_t* cg) {
   3153   SIMDVec16 co_v = r.Sub(b);
   3154   SIMDVec16 tmp = b.Add(co_v.SignedShiftRight<1>());
   3155   SIMDVec16 cg_v = g.Sub(tmp);
   3156   SIMDVec16 y_v = tmp.Add(cg_v.SignedShiftRight<1>());
   3157   y_v.Store(reinterpret_cast<uint16_t*>(y));
   3158   co_v.Store(reinterpret_cast<uint16_t*>(co));
   3159   cg_v.Store(reinterpret_cast<uint16_t*>(cg));
   3160 }
   3161 
   3162 void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int32_t* y, int32_t* co,
   3163                 int32_t* cg) {
   3164   VecPair<SIMDVec32> r_up = r.Upcast();
   3165   VecPair<SIMDVec32> g_up = g.Upcast();
   3166   VecPair<SIMDVec32> b_up = b.Upcast();
   3167   SIMDVec32 co_lo_v = r_up.low.Sub(b_up.low);
   3168   SIMDVec32 tmp_lo = b_up.low.Add(co_lo_v.SignedShiftRight<1>());
   3169   SIMDVec32 cg_lo_v = g_up.low.Sub(tmp_lo);
   3170   SIMDVec32 y_lo_v = tmp_lo.Add(cg_lo_v.SignedShiftRight<1>());
   3171   SIMDVec32 co_hi_v = r_up.hi.Sub(b_up.hi);
   3172   SIMDVec32 tmp_hi = b_up.hi.Add(co_hi_v.SignedShiftRight<1>());
   3173   SIMDVec32 cg_hi_v = g_up.hi.Sub(tmp_hi);
   3174   SIMDVec32 y_hi_v = tmp_hi.Add(cg_hi_v.SignedShiftRight<1>());
   3175   y_lo_v.Store(reinterpret_cast<uint32_t*>(y));
   3176   co_lo_v.Store(reinterpret_cast<uint32_t*>(co));
   3177   cg_lo_v.Store(reinterpret_cast<uint32_t*>(cg));
   3178   y_hi_v.Store(reinterpret_cast<uint32_t*>(y) + SIMDVec32::kLanes);
   3179   co_hi_v.Store(reinterpret_cast<uint32_t*>(co) + SIMDVec32::kLanes);
   3180   cg_hi_v.Store(reinterpret_cast<uint32_t*>(cg) + SIMDVec32::kLanes);
   3181 }
   3182 #endif
   3183 
   3184 template <typename pixel_t>
   3185 void FillRowRGB8(const unsigned char* rgba, size_t oxs, pixel_t* y, pixel_t* co,
   3186                  pixel_t* cg) {
   3187   size_t x = 0;
   3188 #ifdef FJXL_GENERIC_SIMD
   3189   for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3190     auto rgb = SIMDVec16::LoadRGB8(rgba + 3 * x);
   3191     StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
   3192   }
   3193 #endif
   3194   for (; x < oxs; x++) {
   3195     uint16_t r = rgba[3 * x];
   3196     uint16_t g = rgba[3 * x + 1];
   3197     uint16_t b = rgba[3 * x + 2];
   3198     StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
   3199   }
   3200 }
   3201 
   3202 template <bool big_endian, typename pixel_t>
   3203 void FillRowRGB16(const unsigned char* rgba, size_t oxs, pixel_t* y,
   3204                   pixel_t* co, pixel_t* cg) {
   3205   size_t x = 0;
   3206 #ifdef FJXL_GENERIC_SIMD
   3207   for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3208     auto rgb = SIMDVec16::LoadRGB16(rgba + 6 * x);
   3209     if (big_endian) {
   3210       rgb[0].SwapEndian();
   3211       rgb[1].SwapEndian();
   3212       rgb[2].SwapEndian();
   3213     }
   3214     StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
   3215   }
   3216 #endif
   3217   for (; x < oxs; x++) {
   3218     uint16_t r = LoadLE16(rgba + 6 * x);
   3219     uint16_t g = LoadLE16(rgba + 6 * x + 2);
   3220     uint16_t b = LoadLE16(rgba + 6 * x + 4);
   3221     if (big_endian) {
   3222       r = SwapEndian(r);
   3223       g = SwapEndian(g);
   3224       b = SwapEndian(b);
   3225     }
   3226     StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
   3227   }
   3228 }
   3229 
   3230 template <typename pixel_t>
   3231 void FillRowRGBA8(const unsigned char* rgba, size_t oxs, pixel_t* y,
   3232                   pixel_t* co, pixel_t* cg, pixel_t* alpha) {
   3233   size_t x = 0;
   3234 #ifdef FJXL_GENERIC_SIMD
   3235   for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3236     auto rgb = SIMDVec16::LoadRGBA8(rgba + 4 * x);
   3237     StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
   3238     StorePixels(rgb[3], alpha + x);
   3239   }
   3240 #endif
   3241   for (; x < oxs; x++) {
   3242     uint16_t r = rgba[4 * x];
   3243     uint16_t g = rgba[4 * x + 1];
   3244     uint16_t b = rgba[4 * x + 2];
   3245     uint16_t a = rgba[4 * x + 3];
   3246     StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
   3247     alpha[x] = a;
   3248   }
   3249 }
   3250 
   3251 template <bool big_endian, typename pixel_t>
   3252 void FillRowRGBA16(const unsigned char* rgba, size_t oxs, pixel_t* y,
   3253                    pixel_t* co, pixel_t* cg, pixel_t* alpha) {
   3254   size_t x = 0;
   3255 #ifdef FJXL_GENERIC_SIMD
   3256   for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3257     auto rgb = SIMDVec16::LoadRGBA16(rgba + 8 * x);
   3258     if (big_endian) {
   3259       rgb[0].SwapEndian();
   3260       rgb[1].SwapEndian();
   3261       rgb[2].SwapEndian();
   3262       rgb[3].SwapEndian();
   3263     }
   3264     StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
   3265     StorePixels(rgb[3], alpha + x);
   3266   }
   3267 #endif
   3268   for (; x < oxs; x++) {
   3269     uint16_t r = LoadLE16(rgba + 8 * x);
   3270     uint16_t g = LoadLE16(rgba + 8 * x + 2);
   3271     uint16_t b = LoadLE16(rgba + 8 * x + 4);
   3272     uint16_t a = LoadLE16(rgba + 8 * x + 6);
   3273     if (big_endian) {
   3274       r = SwapEndian(r);
   3275       g = SwapEndian(g);
   3276       b = SwapEndian(b);
   3277       a = SwapEndian(a);
   3278     }
   3279     StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
   3280     alpha[x] = a;
   3281   }
   3282 }
   3283 
   3284 template <typename Processor, typename BitDepth>
   3285 void ProcessImageArea(const unsigned char* rgba, size_t x0, size_t y0,
   3286                       size_t xs, size_t yskip, size_t ys, size_t row_stride,
   3287                       BitDepth bitdepth, size_t nb_chans, bool big_endian,
   3288                       Processor* processors) {
   3289   constexpr size_t kPadding = 32;
   3290 
   3291   using pixel_t = typename BitDepth::pixel_t;
   3292 
   3293   constexpr size_t kAlign = 64;
   3294   constexpr size_t kAlignPixels = kAlign / sizeof(pixel_t);
   3295 
   3296   auto align = [=](pixel_t* ptr) {
   3297     size_t offset = reinterpret_cast<uintptr_t>(ptr) % kAlign;
   3298     if (offset) {
   3299       ptr += offset / sizeof(pixel_t);
   3300     }
   3301     return ptr;
   3302   };
   3303 
   3304   constexpr size_t kNumPx =
   3305       (256 + kPadding * 2 + kAlignPixels + kAlignPixels - 1) / kAlignPixels *
   3306       kAlignPixels;
   3307 
   3308   std::vector<std::array<std::array<pixel_t, kNumPx>, 2>> group_data(nb_chans);
   3309 
   3310   for (size_t y = 0; y < ys; y++) {
   3311     const auto rgba_row =
   3312         rgba + row_stride * (y0 + y) + x0 * nb_chans * BitDepth::kInputBytes;
   3313     pixel_t* crow[4] = {};
   3314     pixel_t* prow[4] = {};
   3315     for (size_t i = 0; i < nb_chans; i++) {
   3316       crow[i] = align(&group_data[i][y & 1][kPadding]);
   3317       prow[i] = align(&group_data[i][(y - 1) & 1][kPadding]);
   3318     }
   3319 
   3320     // Pre-fill rows with YCoCg converted pixels.
   3321     if (nb_chans == 1) {
   3322       if (BitDepth::kInputBytes == 1) {
   3323         FillRowG8(rgba_row, xs, crow[0]);
   3324       } else if (big_endian) {
   3325         FillRowG16</*big_endian=*/true>(rgba_row, xs, crow[0]);
   3326       } else {
   3327         FillRowG16</*big_endian=*/false>(rgba_row, xs, crow[0]);
   3328       }
   3329     } else if (nb_chans == 2) {
   3330       if (BitDepth::kInputBytes == 1) {
   3331         FillRowGA8(rgba_row, xs, crow[0], crow[1]);
   3332       } else if (big_endian) {
   3333         FillRowGA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1]);
   3334       } else {
   3335         FillRowGA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1]);
   3336       }
   3337     } else if (nb_chans == 3) {
   3338       if (BitDepth::kInputBytes == 1) {
   3339         FillRowRGB8(rgba_row, xs, crow[0], crow[1], crow[2]);
   3340       } else if (big_endian) {
   3341         FillRowRGB16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1],
   3342                                           crow[2]);
   3343       } else {
   3344         FillRowRGB16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1],
   3345                                            crow[2]);
   3346       }
   3347     } else {
   3348       if (BitDepth::kInputBytes == 1) {
   3349         FillRowRGBA8(rgba_row, xs, crow[0], crow[1], crow[2], crow[3]);
   3350       } else if (big_endian) {
   3351         FillRowRGBA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1],
   3352                                            crow[2], crow[3]);
   3353       } else {
   3354         FillRowRGBA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1],
   3355                                             crow[2], crow[3]);
   3356       }
   3357     }
   3358     // Deal with x == 0.
   3359     for (size_t c = 0; c < nb_chans; c++) {
   3360       *(crow[c] - 1) = y > 0 ? *(prow[c]) : 0;
   3361       // Fix topleft.
   3362       *(prow[c] - 1) = y > 0 ? *(prow[c]) : 0;
   3363     }
   3364     if (y < yskip) continue;
   3365     for (size_t c = 0; c < nb_chans; c++) {
   3366       // Get pointers to px/left/top/topleft data to speedup loop.
   3367       const pixel_t* row = crow[c];
   3368       const pixel_t* row_left = crow[c] - 1;
   3369       const pixel_t* row_top = y == 0 ? row_left : prow[c];
   3370       const pixel_t* row_topleft = y == 0 ? row_left : prow[c] - 1;
   3371 
   3372       processors[c].ProcessRow(row, row_left, row_top, row_topleft, xs);
   3373     }
   3374   }
   3375   for (size_t c = 0; c < nb_chans; c++) {
   3376     processors[c].Finalize();
   3377   }
   3378 }
   3379 
   3380 template <typename BitDepth>
   3381 void WriteACSection(const unsigned char* rgba, size_t x0, size_t y0, size_t xs,
   3382                     size_t ys, size_t row_stride, bool is_single_group,
   3383                     BitDepth bitdepth, size_t nb_chans, bool big_endian,
   3384                     const PrefixCode code[4],
   3385                     std::array<BitWriter, 4>& output) {
   3386   for (size_t i = 0; i < nb_chans; i++) {
   3387     if (is_single_group && i == 0) continue;
   3388     output[i].Allocate(xs * ys * bitdepth.MaxEncodedBitsPerSample() + 4);
   3389   }
   3390   if (!is_single_group) {
   3391     // Group header for modular image.
   3392     // When the image is single-group, the global modular image is the one
   3393     // that contains the pixel data, and there is no group header.
   3394     output[0].Write(1, 1);     // Global tree
   3395     output[0].Write(1, 1);     // All default wp
   3396     output[0].Write(2, 0b00);  // 0 transforms
   3397   }
   3398 
   3399   ChunkEncoder<BitDepth> encoders[4];
   3400   ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth> row_encoders[4];
   3401   for (size_t c = 0; c < nb_chans; c++) {
   3402     row_encoders[c].t = &encoders[c];
   3403     encoders[c].output = &output[c];
   3404     encoders[c].code = &code[c];
   3405     encoders[c].PrepareForSimd();
   3406   }
   3407   ProcessImageArea<ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth>>(
   3408       rgba, x0, y0, xs, 0, ys, row_stride, bitdepth, nb_chans, big_endian,
   3409       row_encoders);
   3410 }
   3411 
   3412 constexpr int kHashExp = 16;
   3413 constexpr uint32_t kHashSize = 1 << kHashExp;
   3414 constexpr uint32_t kHashMultiplier = 2654435761;
   3415 constexpr int kMaxColors = 512;
   3416 
   3417 // Can be any function that returns a value in 0 .. kHashSize-1;
   3418 // it has to map 0 to 0.
   3419 inline uint32_t pixel_hash(uint32_t p) {
   3420   return (p * kHashMultiplier) >> (32 - kHashExp);
   3421 }
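        // A multiplicative (Knuth-style) hash; 2654435761 is the classic
        // 32-bit golden-ratio constant. The requirement above is satisfied
        // because 0 * kHashMultiplier == 0, so an all-zero pixel always hashes
        // to slot 0.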
   3422 
   3423 template <size_t nb_chans>
   3424 void FillRowPalette(const unsigned char* inrow, size_t xs,
   3425                     const int16_t* lookup, int16_t* out) {
   3426   for (size_t x = 0; x < xs; x++) {
   3427     uint32_t p = 0;
   3428     memcpy(&p, inrow + x * nb_chans, nb_chans);
   3429     out[x] = lookup[pixel_hash(p)];
   3430   }
   3431 }
   3432 
   3433 template <typename Processor>
   3434 void ProcessImageAreaPalette(const unsigned char* rgba, size_t x0, size_t y0,
   3435                              size_t xs, size_t yskip, size_t ys,
   3436                              size_t row_stride, const int16_t* lookup,
   3437                              size_t nb_chans, Processor* processors) {
   3438   constexpr size_t kPadding = 32;
   3439 
   3440   std::vector<std::array<int16_t, 256 + kPadding * 2>> group_data(2);
   3441   Processor& row_encoder = processors[0];
   3442 
   3443   for (size_t y = 0; y < ys; y++) {
   3444     // Pre-fill rows with palette converted pixels.
   3445     const unsigned char* inrow = rgba + row_stride * (y0 + y) + x0 * nb_chans;
   3446     int16_t* outrow = &group_data[y & 1][kPadding];
   3447     if (nb_chans == 1) {
   3448       FillRowPalette<1>(inrow, xs, lookup, outrow);
   3449     } else if (nb_chans == 2) {
   3450       FillRowPalette<2>(inrow, xs, lookup, outrow);
   3451     } else if (nb_chans == 3) {
   3452       FillRowPalette<3>(inrow, xs, lookup, outrow);
   3453     } else if (nb_chans == 4) {
   3454       FillRowPalette<4>(inrow, xs, lookup, outrow);
   3455     }
   3456     // Deal with x == 0.
   3457     group_data[y & 1][kPadding - 1] =
   3458         y > 0 ? group_data[(y - 1) & 1][kPadding] : 0;
   3459     // Fix topleft.
   3460     group_data[(y - 1) & 1][kPadding - 1] =
   3461         y > 0 ? group_data[(y - 1) & 1][kPadding] : 0;
   3462     // Get pointers to px/left/top/topleft data to speedup loop.
   3463     const int16_t* row = &group_data[y & 1][kPadding];
   3464     const int16_t* row_left = &group_data[y & 1][kPadding - 1];
   3465     const int16_t* row_top =
   3466         y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding];
   3467     const int16_t* row_topleft =
   3468         y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding - 1];
   3469 
   3470     row_encoder.ProcessRow(row, row_left, row_top, row_topleft, xs);
   3471   }
   3472   row_encoder.Finalize();
   3473 }
   3474 
   3475 void WriteACSectionPalette(const unsigned char* rgba, size_t x0, size_t y0,
   3476                            size_t xs, size_t ys, size_t row_stride,
   3477                            bool is_single_group, const PrefixCode code[4],
   3478                            const int16_t* lookup, size_t nb_chans,
   3479                            BitWriter& output) {
   3480   if (!is_single_group) {
   3481     output.Allocate(16 * xs * ys + 4);
   3482     // Group header for modular image.
   3483     // When the image is single-group, the global modular image is the one
   3484     // that contains the pixel data, and there is no group header.
   3485     output.Write(1, 1);     // Global tree
   3486     output.Write(1, 1);     // All default wp
   3487     output.Write(2, 0b00);  // 0 transforms
   3488   }
   3489 
   3490   ChunkEncoder<UpTo8Bits> encoder;
   3491   ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder;
   3492 
   3493   row_encoder.t = &encoder;
   3494   encoder.output = &output;
   3495   encoder.code = &code[is_single_group ? 1 : 0];
   3496   encoder.PrepareForSimd();
   3497   ProcessImageAreaPalette<
   3498       ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits>>(
   3499       rgba, x0, y0, xs, 0, ys, row_stride, lookup, nb_chans, &row_encoder);
   3500 }
   3501 
   3502 template <typename BitDepth>
   3503 void CollectSamples(const unsigned char* rgba, size_t x0, size_t y0, size_t xs,
   3504                     size_t row_stride, size_t row_count,
   3505                     uint64_t raw_counts[4][kNumRawSymbols],
   3506                     uint64_t lz77_counts[4][kNumLZ77], bool is_single_group,
   3507                     bool palette, BitDepth bitdepth, size_t nb_chans,
   3508                     bool big_endian, const int16_t* lookup) {
   3509   if (palette) {
   3510     ChunkSampleCollector<UpTo8Bits> sample_collectors[4];
   3511     ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits>
   3512         row_sample_collectors[4];
   3513     for (size_t c = 0; c < nb_chans; c++) {
   3514       row_sample_collectors[c].t = &sample_collectors[c];
   3515       sample_collectors[c].raw_counts = raw_counts[is_single_group ? 1 : 0];
   3516       sample_collectors[c].lz77_counts = lz77_counts[is_single_group ? 1 : 0];
   3517     }
   3518     ProcessImageAreaPalette<
   3519         ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits>>(
   3520         rgba, x0, y0, xs, 1, 1 + row_count, row_stride, lookup, nb_chans,
   3521         row_sample_collectors);
   3522   } else {
   3523     ChunkSampleCollector<BitDepth> sample_collectors[4];
   3524     ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth>
   3525         row_sample_collectors[4];
   3526     for (size_t c = 0; c < nb_chans; c++) {
   3527       row_sample_collectors[c].t = &sample_collectors[c];
   3528       sample_collectors[c].raw_counts = raw_counts[c];
   3529       sample_collectors[c].lz77_counts = lz77_counts[c];
   3530     }
   3531     ProcessImageArea<
   3532         ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth>>(
   3533         rgba, x0, y0, xs, 1, 1 + row_count, row_stride, bitdepth, nb_chans,
   3534         big_endian, row_sample_collectors);
   3535   }
   3536 }
   3537 
   3538 void PrepareDCGlobalPalette(bool is_single_group, size_t width, size_t height,
   3539                             size_t nb_chans, const PrefixCode code[4],
   3540                             const std::vector<uint32_t>& palette,
   3541                             size_t pcolors, BitWriter* output) {
   3542   PrepareDCGlobalCommon(is_single_group, width, height, code, output);
   3543   output->Write(2, 0b01);     // 1 transform
   3544   output->Write(2, 0b01);     // Palette
   3545   output->Write(5, 0b00000);  // Starting from ch 0
   3546   if (nb_chans == 1) {
   3547     output->Write(2, 0b00);  // 1-channel palette (Gray)
   3548   } else if (nb_chans == 3) {
   3549     output->Write(2, 0b01);  // 3-channel palette (RGB)
   3550   } else if (nb_chans == 4) {
   3551     output->Write(2, 0b10);  // 4-channel palette (RGBA)
   3552   } else {
   3553     output->Write(2, 0b11);
   3554     output->Write(13, nb_chans - 1);
   3555   }
   3556   // pcolors <= kMaxColors + kChunkSize - 1
   3557   static_assert(kMaxColors + kChunkSize < 1281,
   3558                 "add code to signal larger palette sizes");
   3559   if (pcolors < 256) {
   3560     output->Write(2, 0b00);
   3561     output->Write(8, pcolors);
   3562   } else {
   3563     output->Write(2, 0b01);
   3564     output->Write(10, pcolors - 256);
   3565   }
   3566 
   3567   output->Write(2, 0b00);  // nb_deltas == 0
   3568   output->Write(4, 0);     // Zero predictor for delta palette
   3569   // Encode palette
   3570   ChunkEncoder<UpTo8Bits> encoder;
   3571   ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder;
   3572   row_encoder.t = &encoder;
   3573   encoder.output = output;
   3574   encoder.code = &code[0];
   3575   encoder.PrepareForSimd();
   3576   int16_t p[4][32 + 1024] = {};
   3577   uint8_t prgba[4];
   3578   size_t i = 0;
   3579   size_t have_zero = 1;
   3580   for (; i < pcolors; i++) {
   3581     memcpy(prgba, &palette[i], 4);
   3582     p[0][16 + i + have_zero] = prgba[0];
   3583     p[1][16 + i + have_zero] = prgba[1];
   3584     p[2][16 + i + have_zero] = prgba[2];
   3585     p[3][16 + i + have_zero] = prgba[3];
   3586   }
   3587   p[0][15] = 0;
   3588   row_encoder.ProcessRow(p[0] + 16, p[0] + 15, p[0] + 15, p[0] + 15, pcolors);
   3589   p[1][15] = p[0][16];
   3590   p[0][15] = p[0][16];
   3591   if (nb_chans > 1) {
   3592     row_encoder.ProcessRow(p[1] + 16, p[1] + 15, p[0] + 16, p[0] + 15, pcolors);
   3593   }
   3594   p[2][15] = p[1][16];
   3595   p[1][15] = p[1][16];
   3596   if (nb_chans > 2) {
   3597     row_encoder.ProcessRow(p[2] + 16, p[2] + 15, p[1] + 16, p[1] + 15, pcolors);
   3598   }
   3599   p[3][15] = p[2][16];
   3600   p[2][15] = p[2][16];
   3601   if (nb_chans > 3) {
   3602     row_encoder.ProcessRow(p[3] + 16, p[3] + 15, p[2] + 16, p[2] + 15, pcolors);
   3603   }
   3604   row_encoder.Finalize();
   3605 
   3606   if (!is_single_group) {
   3607     output->ZeroPadToByte();
   3608   }
   3609 }
   3610 
   3611 template <size_t nb_chans>
   3612 bool detect_palette(const unsigned char* r, size_t width,
   3613                     std::vector<uint32_t>& palette) {
   3614   size_t x = 0;
   3615   bool collided = false;
   3616   // this is just an unrolling of the next loop
   3617   for (; x + 7 < width; x += 8) {
   3618     uint32_t p[8] = {}, index[8];
   3619     for (int i = 0; i < 8; i++) memcpy(&p[i], r + (x + i) * nb_chans, 4);
   3620     for (int i = 0; i < 8; i++) p[i] &= ((1llu << (8 * nb_chans)) - 1);
   3621     for (int i = 0; i < 8; i++) index[i] = pixel_hash(p[i]);
   3622     for (int i = 0; i < 8; i++) {
   3623       collided |= (palette[index[i]] != 0 && p[i] != palette[index[i]]);
   3624     }
   3625     for (int i = 0; i < 8; i++) palette[index[i]] = p[i];
   3626   }
   3627   for (; x < width; x++) {
   3628     uint32_t p = 0;
   3629     memcpy(&p, r + x * nb_chans, nb_chans);
   3630     uint32_t index = pixel_hash(p);
   3631     collided |= (palette[index] != 0 && p != palette[index]);
   3632     palette[index] = p;
   3633   }
   3634   return collided;
   3635 }
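        // The return value is true as soon as two distinct pixel values land
        // in the same hash bucket; LLPrepare below then abandons palette mode
        // (it also gives up when too many distinct colors are found; see
        // kMaxColors), so a hash collision can only cost compression, never
        // correctness.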
   3636 
   3637 template <typename BitDepth>
   3638 JxlFastLosslessFrameState* LLPrepare(JxlChunkedFrameInputSource input,
   3639                                      size_t width, size_t height,
   3640                                      BitDepth bitdepth, size_t nb_chans,
   3641                                      bool big_endian, int effort, int oneshot) {
   3642   assert(width != 0);
   3643   assert(height != 0);
   3644 
   3645   // Count distinct colors to decide whether to try the palette transform.
   3646   std::vector<uint32_t> palette(kHashSize);
   3647   std::vector<int16_t> lookup(kHashSize);
   3648   lookup[0] = 0;
   3649   int pcolors = 0;
   3650   bool collided = effort < 2 || bitdepth.bitdepth != 8 || !oneshot;
   3651   for (size_t y0 = 0; y0 < height && !collided; y0 += 256) {
   3652     size_t ys = std::min<size_t>(height - y0, 256);
   3653     for (size_t x0 = 0; x0 < width && !collided; x0 += 256) {
   3654       size_t xs = std::min<size_t>(width - x0, 256);
   3655       size_t stride;
   3656       // TODO(szabadka): Add RAII wrapper around this.
   3657       const void* buffer = input.get_color_channel_data_at(input.opaque, x0, y0,
   3658                                                            xs, ys, &stride);
   3659       auto rgba = reinterpret_cast<const unsigned char*>(buffer);
   3660       for (size_t y = 0; y < ys && !collided; y++) {
   3661         const unsigned char* r = rgba + stride * y;
   3662         if (nb_chans == 1) collided = detect_palette<1>(r, xs, palette);
   3663         if (nb_chans == 2) collided = detect_palette<2>(r, xs, palette);
   3664         if (nb_chans == 3) collided = detect_palette<3>(r, xs, palette);
   3665         if (nb_chans == 4) collided = detect_palette<4>(r, xs, palette);
   3666       }
   3667       input.release_buffer(input.opaque, buffer);
   3668     }
   3669   }
   3670   int nb_entries = 0;
   3671   if (!collided) {
   3672     pcolors = 1;  // always have all-zero as a palette color
   3673     bool have_color = false;
   3674     uint8_t minG = 255, maxG = 0;
   3675     for (uint32_t k = 0; k < kHashSize; k++) {
   3676       if (palette[k] == 0) continue;
   3677       uint8_t p[4];
   3678       memcpy(p, &palette[k], 4);
   3679       // move entries to front so sort has less work
   3680       palette[nb_entries] = palette[k];
   3681       if (p[0] != p[1] || p[0] != p[2]) have_color = true;
   3682       if (p[1] < minG) minG = p[1];
   3683       if (p[1] > maxG) maxG = p[1];
   3684       nb_entries++;
   3685       // don't do palette if too many colors are needed
   3686       if (nb_entries + pcolors > kMaxColors) {
   3687         collided = true;
   3688         break;
   3689       }
   3690     }
   3691     if (!have_color) {
   3692       // don't do palette if it's just grayscale without many holes
   3693       if (maxG - minG < nb_entries * 1.4f) collided = true;
   3694     }
   3695   }
   3696   if (!collided) {
   3697     std::sort(
   3698         palette.begin(), palette.begin() + nb_entries,
   3699         [&nb_chans](uint32_t ap, uint32_t bp) {
   3700           if (ap == 0) return false;
   3701           if (bp == 0) return true;
   3702           uint8_t a[4], b[4];
   3703           memcpy(a, &ap, 4);
   3704           memcpy(b, &bp, 4);
   3705           float ay, by;
   3706           if (nb_chans == 4) {
   3707             ay = (0.299f * a[0] + 0.587f * a[1] + 0.114f * a[2] + 0.01f) * a[3];
   3708             by = (0.299f * b[0] + 0.587f * b[1] + 0.114f * b[2] + 0.01f) * b[3];
   3709           } else {
   3710             ay = (0.299f * a[0] + 0.587f * a[1] + 0.114f * a[2] + 0.01f);
   3711             by = (0.299f * b[0] + 0.587f * b[1] + 0.114f * b[2] + 0.01f);
   3712           }
   3713           return ay < by;  // sort on alpha*luma
   3714         });
   3715     for (int k = 0; k < nb_entries; k++) {
   3716       if (palette[k] == 0) break;
   3717       lookup[pixel_hash(palette[k])] = pcolors++;
   3718     }
   3719   }
   3720 
   3721   size_t num_groups_x = (width + 255) / 256;
   3722   size_t num_groups_y = (height + 255) / 256;
   3723   size_t num_dc_groups_x = (width + 2047) / 2048;
   3724   size_t num_dc_groups_y = (height + 2047) / 2048;
   3725 
   3726   uint64_t raw_counts[4][kNumRawSymbols] = {};
   3727   uint64_t lz77_counts[4][kNumLZ77] = {};
   3728 
   3729   bool onegroup = num_groups_x == 1 && num_groups_y == 1;
   3730 
   3731   auto sample_rows = [&](size_t xg, size_t yg, size_t num_rows) {
   3732     size_t y0 = yg * 256;
   3733     size_t x0 = xg * 256;
   3734     size_t ys = std::min<size_t>(height - y0, 256);
   3735     size_t xs = std::min<size_t>(width - x0, 256);
   3736     size_t stride;
   3737     const void* buffer =
   3738         input.get_color_channel_data_at(input.opaque, x0, y0, xs, ys, &stride);
   3739     auto rgba = reinterpret_cast<const unsigned char*>(buffer);
   3740     int y_begin_group =
   3741         std::max<ssize_t>(
   3742             0, static_cast<ssize_t>(ys) - static_cast<ssize_t>(num_rows)) /
   3743         2;
   3744     int y_count = std::min<int>(num_rows, ys - y_begin_group);
   3745     int x_max = xs / kChunkSize * kChunkSize;
   3746     CollectSamples(rgba, 0, y_begin_group, x_max, stride, y_count, raw_counts,
   3747                    lz77_counts, onegroup, !collided, bitdepth, nb_chans,
   3748                    big_endian, lookup.data());
   3749     input.release_buffer(input.opaque, buffer);
   3750   };
   3751 
   3752   // TODO(veluca): that `64` is an arbitrary constant, meant to correspond to
   3753   // the point where the number of processed rows is large enough that loading
   3754   // the entire image is cost-effective.
   3755   if (oneshot || effort >= 64) {
   3756     for (size_t g = 0; g < num_groups_y * num_groups_x; g++) {
   3757       size_t xg = g % num_groups_x;
   3758       size_t yg = g / num_groups_x;
   3759       size_t y0 = yg * 256;
   3760       size_t ys = std::min<size_t>(height - y0, 256);
   3761       size_t num_rows = 2 * effort * ys / 256;
   3762       sample_rows(xg, yg, num_rows);
   3763     }
   3764   } else {
   3765     // sample the middle (effort * 2 * num_groups) rows of the center group
   3766     // (possibly all of them).
   3767     sample_rows((num_groups_x - 1) / 2, (num_groups_y - 1) / 2,
   3768                 2 * effort * num_groups_x * num_groups_y);
   3769   }
   3770 
   3771   // TODO(veluca): can probably improve this and make it bitdepth-dependent.
   3772   uint64_t base_raw_counts[kNumRawSymbols] = {
   3773       3843, 852, 1270, 1214, 1014, 727, 481, 300, 159, 51,
   3774       5,    1,   1,    1,    1,    1,   1,   1,   1};
   3775 
   3776   bool doing_ycocg = nb_chans > 2 && collided;
   3777   bool large_palette = !collided || pcolors >= 256;
   3778   for (size_t i = bitdepth.NumSymbols(doing_ycocg || large_palette);
   3779        i < kNumRawSymbols; i++) {
   3780     base_raw_counts[i] = 0;
   3781   }
   3782 
   3783   for (size_t c = 0; c < 4; c++) {
   3784     for (size_t i = 0; i < kNumRawSymbols; i++) {
   3785       raw_counts[c][i] = (raw_counts[c][i] << 8) + base_raw_counts[i];
   3786     }
   3787   }
   3788 
   3789   if (!collided) {
   3790     unsigned token, nbits, bits;
   3791     EncodeHybridUint000(PackSigned(pcolors - 1), &token, &nbits, &bits);
   3792     // ensure all palette indices can actually be encoded
   3793     for (size_t i = 0; i < token + 1; i++)
   3794       raw_counts[0][i] = std::max<uint64_t>(raw_counts[0][i], 1);
   3795     // These tokens are only used for encoding the palette itself, so a
   3796     // suboptimal code for them is acceptable.
   3797     for (size_t i = token + 1; i < 10; i++) raw_counts[0][i] = 1;
   3798   }
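          // Sketch of the bound used above, assuming the PackSigned /
          // EncodeHybridUint000 helpers defined earlier in this file (zigzag
          // mapping followed by the split-exponent-0 hybrid uint): a palette
          // of 200 colors gives PackSigned(199) == 398, hence
          // token == FloorLog2(398) + 1 == 9, so every token that a palette
          // index can produce receives a nonzero count and thus a usable
          // prefix code.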
   3799 
   3800   uint64_t base_lz77_counts[kNumLZ77] = {
   3801       29, 27, 25,  23, 21, 21, 19, 18, 21, 17, 16, 15, 15, 14,
   3802       13, 13, 137, 98, 61, 34, 1,  1,  1,  1,  1,  1,  1,  1,
   3803   };
   3804 
   3805   for (size_t c = 0; c < 4; c++) {
   3806     for (size_t i = 0; i < kNumLZ77; i++) {
   3807       lz77_counts[c][i] = (lz77_counts[c][i] << 8) + base_lz77_counts[i];
   3808     }
   3809   }
   3810 
   3811   JxlFastLosslessFrameState* frame_state = new JxlFastLosslessFrameState();
   3812   for (size_t i = 0; i < 4; i++) {
   3813     frame_state->hcode[i] = PrefixCode(bitdepth, raw_counts[i], lz77_counts[i]);
   3814   }
   3815 
   3816   size_t num_dc_groups = num_dc_groups_x * num_dc_groups_y;
   3817   size_t num_ac_groups = num_groups_x * num_groups_y;
   3818   size_t num_groups = onegroup ? 1 : (2 + num_dc_groups + num_ac_groups);
   3819   frame_state->input = input;
   3820   frame_state->width = width;
   3821   frame_state->height = height;
   3822   frame_state->num_groups_x = num_groups_x;
   3823   frame_state->num_groups_y = num_groups_y;
   3824   frame_state->num_dc_groups_x = num_dc_groups_x;
   3825   frame_state->num_dc_groups_y = num_dc_groups_y;
   3826   frame_state->nb_chans = nb_chans;
   3827   frame_state->bitdepth = bitdepth.bitdepth;
   3828   frame_state->big_endian = big_endian;
   3829   frame_state->effort = effort;
   3830   frame_state->collided = collided;
   3831   frame_state->lookup = lookup;
   3832 
   3833   frame_state->group_data = std::vector<std::array<BitWriter, 4>>(num_groups);
   3834   frame_state->group_sizes.resize(num_groups);
   3835   if (collided) {
   3836     PrepareDCGlobal(onegroup, width, height, nb_chans, frame_state->hcode,
   3837                     &frame_state->group_data[0][0]);
   3838   } else {
   3839     PrepareDCGlobalPalette(onegroup, width, height, nb_chans,
   3840                            frame_state->hcode, palette, pcolors,
   3841                            &frame_state->group_data[0][0]);
   3842   }
   3843   frame_state->group_sizes[0] = SectionSize(frame_state->group_data[0]);
   3844   if (!onegroup) {
   3845     ComputeAcGroupDataOffset(frame_state->group_sizes[0], num_dc_groups,
   3846                              num_ac_groups, frame_state->min_dc_global_size,
   3847                              frame_state->ac_group_data_offset);
   3848   }
   3849 
   3850   return frame_state;
   3851 }
   3852 
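        // Writes the AC group data for a frame prepared by LLPrepare. In the
        // streaming path (chunked output with more than one group) the AC
        // groups are written first, starting at ac_group_data_offset and in
        // batches of at most kMaxLocalGroups; the writer then seeks back to
        // emit the headers and a DC-global section that is zero-padded to
        // exactly fill the reserved offset.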
   3853 template <typename BitDepth>
   3854 void LLProcess(JxlFastLosslessFrameState* frame_state, bool is_last,
   3855                BitDepth bitdepth, void* runner_opaque,
   3856                FJxlParallelRunner runner,
   3857                JxlEncoderOutputProcessorWrapper* output_processor) {
   3858 #if !FJXL_STANDALONE
   3859   if (frame_state->process_done) {
   3860     JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/0, is_last);
   3861     if (output_processor) {
   3862       JxlFastLosslessOutputFrame(frame_state, output_processor);
   3863     }
   3864     return;
   3865   }
   3866 #endif
   3867   // The maximum number of groups that we process concurrently here.
   3868   // TODO(szabadka) Use the number of threads or some outside parameter for the
   3869   // maximum memory usage instead.
   3870   constexpr size_t kMaxLocalGroups = 16;
   3871   bool onegroup = frame_state->group_sizes.size() == 1;
   3872   bool streaming = !onegroup && output_processor;
   3873   size_t total_groups = frame_state->num_groups_x * frame_state->num_groups_y;
   3874   size_t max_groups = streaming ? kMaxLocalGroups : total_groups;
   3875 #if !FJXL_STANDALONE
   3876   size_t start_pos = 0;
   3877   if (streaming) {
   3878     start_pos = output_processor->CurrentPosition();
   3879     output_processor->Seek(start_pos + frame_state->ac_group_data_offset);
   3880   }
   3881 #endif
   3882   for (size_t offset = 0; offset < total_groups; offset += max_groups) {
   3883     size_t num_groups = std::min(max_groups, total_groups - offset);
   3884     JxlFastLosslessFrameState local_frame_state;
   3885     if (streaming) {
   3886       local_frame_state.group_data =
   3887           std::vector<std::array<BitWriter, 4>>(num_groups);
   3888     }
   3889     auto run_one = [&](size_t i) {
   3890       size_t g = offset + i;
   3891       size_t xg = g % frame_state->num_groups_x;
   3892       size_t yg = g / frame_state->num_groups_x;
   3893       size_t num_dc_groups =
   3894           frame_state->num_dc_groups_x * frame_state->num_dc_groups_y;
   3895       size_t group_id = onegroup ? 0 : (2 + num_dc_groups + g);
   3896       size_t xs = std::min<size_t>(frame_state->width - xg * 256, 256);
   3897       size_t ys = std::min<size_t>(frame_state->height - yg * 256, 256);
   3898       size_t x0 = xg * 256;
   3899       size_t y0 = yg * 256;
   3900       size_t stride;
   3901       JxlChunkedFrameInputSource input = frame_state->input;
   3902       const void* buffer = input.get_color_channel_data_at(input.opaque, x0, y0,
   3903                                                            xs, ys, &stride);
   3904       const unsigned char* rgba =
   3905           reinterpret_cast<const unsigned char*>(buffer);
   3906 
   3907       auto& gd = streaming ? local_frame_state.group_data[i]
   3908                            : frame_state->group_data[group_id];
   3909       if (frame_state->collided) {
   3910         WriteACSection(rgba, 0, 0, xs, ys, stride, onegroup, bitdepth,
   3911                        frame_state->nb_chans, frame_state->big_endian,
   3912                        frame_state->hcode, gd);
   3913       } else {
   3914         WriteACSectionPalette(rgba, 0, 0, xs, ys, stride, onegroup,
   3915                               frame_state->hcode, frame_state->lookup.data(),
   3916                               frame_state->nb_chans, gd[0]);
   3917       }
   3918       frame_state->group_sizes[group_id] = SectionSize(gd);
   3919       input.release_buffer(input.opaque, buffer);
   3920     };
   3921     runner(
   3922         runner_opaque, &run_one,
   3923         +[](void* r, size_t i) {
   3924           (*reinterpret_cast<decltype(&run_one)>(r))(i);
   3925         },
   3926         num_groups);
   3927 #if !FJXL_STANDALONE
   3928     if (streaming) {
   3929       local_frame_state.nb_chans = frame_state->nb_chans;
   3930       local_frame_state.current_bit_writer = 1;
   3931       JxlFastLosslessOutputFrame(&local_frame_state, output_processor);
   3932     }
   3933 #endif
   3934   }
   3935 #if !FJXL_STANDALONE
   3936   if (streaming) {
   3937     size_t end_pos = output_processor->CurrentPosition();
   3938     output_processor->Seek(start_pos);
   3939     frame_state->group_data.resize(1);
   3940     bool have_alpha = frame_state->nb_chans == 2 || frame_state->nb_chans == 4;
   3941     size_t padding = ComputeDcGlobalPadding(
   3942         frame_state->group_sizes, frame_state->ac_group_data_offset,
   3943         frame_state->min_dc_global_size, have_alpha, is_last);
   3944 
   3945     for (size_t i = 0; i < padding; ++i) {
   3946       frame_state->group_data[0][0].Write(8, 0);
   3947     }
   3948     frame_state->group_sizes[0] += padding;
   3949     JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/0, is_last);
   3950     assert(frame_state->ac_group_data_offset ==
   3951            JxlFastLosslessOutputSize(frame_state));
   3952     JxlFastLosslessOutputHeaders(frame_state, output_processor);
   3953     output_processor->Seek(end_pos);
   3954   } else if (output_processor) {
   3955     assert(onegroup);
   3956     JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/0, is_last);
   3957     if (output_processor) {
   3958       JxlFastLosslessOutputFrame(frame_state, output_processor);
   3959     }
   3960   }
   3961   frame_state->process_done = true;
   3962 #endif
   3963 }
   3964 
   3965 JxlFastLosslessFrameState* JxlFastLosslessPrepareImpl(
   3966     JxlChunkedFrameInputSource input, size_t width, size_t height,
   3967     size_t nb_chans, size_t bitdepth, bool big_endian, int effort,
   3968     int oneshot) {
   3969   assert(bitdepth > 0);
   3970   assert(nb_chans <= 4);
   3971   assert(nb_chans != 0);
   3972   if (bitdepth <= 8) {
   3973     return LLPrepare(input, width, height, UpTo8Bits(bitdepth), nb_chans,
   3974                      big_endian, effort, oneshot);
   3975   }
   3976   if (bitdepth <= 13) {
   3977     return LLPrepare(input, width, height, From9To13Bits(bitdepth), nb_chans,
   3978                      big_endian, effort, oneshot);
   3979   }
   3980   if (bitdepth == 14) {
   3981     return LLPrepare(input, width, height, Exactly14Bits(bitdepth), nb_chans,
   3982                      big_endian, effort, oneshot);
   3983   }
   3984   return LLPrepare(input, width, height, MoreThan14Bits(bitdepth), nb_chans,
   3985                    big_endian, effort, oneshot);
   3986 }
   3987 
   3988 void JxlFastLosslessProcessFrameImpl(
   3989     JxlFastLosslessFrameState* frame_state, bool is_last, void* runner_opaque,
   3990     FJxlParallelRunner runner,
   3991     JxlEncoderOutputProcessorWrapper* output_processor) {
   3992   const size_t bitdepth = frame_state->bitdepth;
   3993   if (bitdepth <= 8) {
   3994     LLProcess(frame_state, is_last, UpTo8Bits(bitdepth), runner_opaque, runner,
   3995               output_processor);
   3996   } else if (bitdepth <= 13) {
   3997     LLProcess(frame_state, is_last, From9To13Bits(bitdepth), runner_opaque,
   3998               runner, output_processor);
   3999   } else if (bitdepth == 14) {
   4000     LLProcess(frame_state, is_last, Exactly14Bits(bitdepth), runner_opaque,
   4001               runner, output_processor);
   4002   } else {
   4003     LLProcess(frame_state, is_last, MoreThan14Bits(bitdepth), runner_opaque,
   4004               runner, output_processor);
   4005   }
   4006 }
   4007 
   4008 }  // namespace
   4009 
   4010 #endif  // FJXL_SELF_INCLUDE
   4011 
   4012 #ifndef FJXL_SELF_INCLUDE
   4013 
   4014 #define FJXL_SELF_INCLUDE
   4015 
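        // Multi-target compilation: the rest of this file includes itself once
        // per SIMD flavor, each time inside its own namespace and (for AVX2 /
        // AVX512) under the matching target pragmas. The extern "C" entry
        // points below then pick an implementation at runtime via
        // __builtin_cpu_supports().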
   4016 // If we have NEON enabled, it is the default target.
   4017 #if FJXL_ENABLE_NEON
   4018 
   4019 namespace default_implementation {
   4020 #define FJXL_NEON
   4021 #include "lib/jxl/enc_fast_lossless.cc"
   4022 #undef FJXL_NEON
   4023 }  // namespace default_implementation
   4024 
   4025 #else  // FJXL_ENABLE_NEON
   4026 
   4027 namespace default_implementation {
   4028 #include "lib/jxl/enc_fast_lossless.cc"  // NOLINT
   4029 }
   4030 
   4031 #if FJXL_ENABLE_AVX2
   4032 #ifdef __clang__
   4033 #pragma clang attribute push(__attribute__((target("avx,avx2"))), \
   4034                              apply_to = function)
   4035 // Causes spurious warnings on clang5.
   4036 #pragma clang diagnostic push
   4037 #pragma clang diagnostic ignored "-Wmissing-braces"
   4038 #elif defined(__GNUC__)
   4039 #pragma GCC push_options
   4040 // Seems to cause spurious errors on GCC8.
   4041 #pragma GCC diagnostic ignored "-Wpsabi"
   4042 #pragma GCC target "avx,avx2"
   4043 #endif
   4044 
   4045 namespace AVX2 {
   4046 #define FJXL_AVX2
   4047 #include "lib/jxl/enc_fast_lossless.cc"  // NOLINT
   4048 #undef FJXL_AVX2
   4049 }  // namespace AVX2
   4050 
   4051 #ifdef __clang__
   4052 #pragma clang attribute pop
   4053 #pragma clang diagnostic pop
   4054 #elif defined(__GNUC__)
   4055 #pragma GCC pop_options
   4056 #endif
   4057 #endif  // FJXL_ENABLE_AVX2
   4058 
   4059 #if FJXL_ENABLE_AVX512
   4060 #ifdef __clang__
   4061 #pragma clang attribute push(                                                 \
   4062     __attribute__((target("avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi"))), \
   4063     apply_to = function)
   4064 #elif defined(__GNUC__)
   4065 #pragma GCC push_options
   4066 #pragma GCC target "avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi"
   4067 #endif
   4068 
   4069 namespace AVX512 {
   4070 #define FJXL_AVX512
   4071 #include "lib/jxl/enc_fast_lossless.cc"
   4072 #undef FJXL_AVX512
   4073 }  // namespace AVX512
   4074 
   4075 #ifdef __clang__
   4076 #pragma clang attribute pop
   4077 #elif defined(__GNUC__)
   4078 #pragma GCC pop_options
   4079 #endif
   4080 #endif  // FJXL_ENABLE_AVX512
   4081 
   4082 #endif
   4083 
   4084 extern "C" {
   4085 
   4086 #if FJXL_STANDALONE
   4087 class FJxlFrameInput {
   4088  public:
   4089   FJxlFrameInput(const unsigned char* rgba, size_t row_stride, size_t nb_chans,
   4090                  size_t bitdepth)
   4091       : rgba_(rgba),
   4092         row_stride_(row_stride),
   4093         bytes_per_pixel_(bitdepth <= 8 ? nb_chans : 2 * nb_chans) {}
   4094 
   4095   JxlChunkedFrameInputSource GetInputSource() {
   4096     return JxlChunkedFrameInputSource{this, GetDataAt,
   4097                                       [](void*, const void*) {}};
   4098   }
   4099 
   4100  private:
   4101   static const void* GetDataAt(void* opaque, size_t xpos, size_t ypos,
   4102                                size_t xsize, size_t ysize, size_t* row_offset) {
   4103     FJxlFrameInput* self = static_cast<FJxlFrameInput*>(opaque);
   4104     *row_offset = self->row_stride_;
   4105     return self->rgba_ + ypos * (*row_offset) + xpos * self->bytes_per_pixel_;
   4106   }
   4107 
   4108   const uint8_t* rgba_;
   4109   size_t row_stride_;
   4110   size_t bytes_per_pixel_;
   4111 };
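        // GetDataAt treats the whole frame as one chunk: the reported row
        // offset is simply the caller-provided row stride, and the returned
        // pointer is rgba_ + ypos * row_stride_ + xpos * bytes_per_pixel_.
        // For example, a 16-bit RGB buffer has bytes_per_pixel_ == 2 * 3 == 6,
        // so the rectangle starting at (xpos=8, ypos=2) with a 4096-byte
        // stride begins at rgba_ + 2 * 4096 + 8 * 6.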
   4112 
   4113 size_t JxlFastLosslessEncode(const unsigned char* rgba, size_t width,
   4114                              size_t row_stride, size_t height, size_t nb_chans,
   4115                              size_t bitdepth, int big_endian, int effort,
   4116                              unsigned char** output, void* runner_opaque,
   4117                              FJxlParallelRunner runner) {
   4118   FJxlFrameInput input(rgba, row_stride, nb_chans, bitdepth);
   4119   auto frame_state = JxlFastLosslessPrepareFrame(
   4120       input.GetInputSource(), width, height, nb_chans, bitdepth, big_endian,
   4121       effort, /*oneshot=*/true);
   4122   JxlFastLosslessProcessFrame(frame_state, /*is_last=*/true, runner_opaque,
   4123                               runner, nullptr);
   4124   JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/1,
   4125                                /*is_last=*/1);
   4126   size_t output_size = JxlFastLosslessMaxRequiredOutput(frame_state);
   4127   *output = (unsigned char*)malloc(output_size);
   4128   size_t written = 0;
   4129   size_t total = 0;
   4130   while ((written = JxlFastLosslessWriteOutput(frame_state, *output + total,
   4131                                                output_size - total)) != 0) {
   4132     total += written;
   4133   }
   4134   JxlFastLosslessFreeFrameState(frame_state);
   4135   return total;
   4136 }
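        // Example (sketch, not part of the library): encoding a tightly packed
        // 8-bit RGBA buffer with JxlFastLosslessEncode above. Dimensions and
        // effort are illustrative; passing a null runner makes
        // JxlFastLosslessProcessFrame fall back to its serial runner, and the
        // caller owns the malloc()ed output buffer.
        //
        //   size_t EncodeRgbaExample(const unsigned char* rgba, size_t width,
        //                            size_t height, unsigned char** output) {
        //     return JxlFastLosslessEncode(rgba, width,
        //                                  /*row_stride=*/width * 4, height,
        //                                  /*nb_chans=*/4, /*bitdepth=*/8,
        //                                  /*big_endian=*/0, /*effort=*/2,
        //                                  output, /*runner_opaque=*/nullptr,
        //                                  /*runner=*/nullptr);
        //   }
        //
        // The returned size is the number of bytes written to *output; the
        // caller should free(*output) when done.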
   4137 #endif
   4138 
   4139 JxlFastLosslessFrameState* JxlFastLosslessPrepareFrame(
   4140     JxlChunkedFrameInputSource input, size_t width, size_t height,
   4141     size_t nb_chans, size_t bitdepth, int big_endian, int effort, int oneshot) {
   4142 #if FJXL_ENABLE_AVX512
   4143   if (__builtin_cpu_supports("avx512cd") &&
   4144       __builtin_cpu_supports("avx512vbmi") &&
   4145       __builtin_cpu_supports("avx512bw") && __builtin_cpu_supports("avx512f") &&
   4146       __builtin_cpu_supports("avx512vl")) {
   4147     return AVX512::JxlFastLosslessPrepareImpl(
   4148         input, width, height, nb_chans, bitdepth, big_endian, effort, oneshot);
   4149   }
   4150 #endif
   4151 #if FJXL_ENABLE_AVX2
   4152   if (__builtin_cpu_supports("avx2")) {
   4153     return AVX2::JxlFastLosslessPrepareImpl(
   4154         input, width, height, nb_chans, bitdepth, big_endian, effort, oneshot);
   4155   }
   4156 #endif
   4157 
   4158   return default_implementation::JxlFastLosslessPrepareImpl(
   4159       input, width, height, nb_chans, bitdepth, big_endian, effort, oneshot);
   4160 }
   4161 
   4162 void JxlFastLosslessProcessFrame(
   4163     JxlFastLosslessFrameState* frame_state, bool is_last, void* runner_opaque,
   4164     FJxlParallelRunner runner,
   4165     JxlEncoderOutputProcessorWrapper* output_processor) {
   4166   auto trivial_runner =
   4167       +[](void*, void* opaque, void fun(void*, size_t), size_t count) {
   4168         for (size_t i = 0; i < count; i++) {
   4169           fun(opaque, i);
   4170         }
   4171       };
   4172 
   4173   if (runner == nullptr) {
   4174     runner = trivial_runner;
   4175   }
   4176 
   4177 #if FJXL_ENABLE_AVX512
   4178   if (__builtin_cpu_supports("avx512cd") &&
   4179       __builtin_cpu_supports("avx512vbmi") &&
   4180       __builtin_cpu_supports("avx512bw") && __builtin_cpu_supports("avx512f") &&
   4181       __builtin_cpu_supports("avx512vl")) {
   4182     AVX512::JxlFastLosslessProcessFrameImpl(frame_state, is_last, runner_opaque,
   4183                                             runner, output_processor);
   4184     return;
   4185   }
   4186 #endif
   4187 #if FJXL_ENABLE_AVX2
   4188   if (__builtin_cpu_supports("avx2")) {
   4189     AVX2::JxlFastLosslessProcessFrameImpl(frame_state, is_last, runner_opaque,
   4190                                           runner, output_processor);
   4191     return;
   4192   }
   4193 #endif
   4194 
   4195   default_implementation::JxlFastLosslessProcessFrameImpl(
   4196       frame_state, is_last, runner_opaque, runner, output_processor);
   4197 }
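        // Example (sketch, not part of the library): a caller-provided
        // FJxlParallelRunner with the same shape as trivial_runner above,
        // spreading the per-group callback over std::threads. The helper name
        // and the static interleaving are illustrative only.
        //
        //   #include <algorithm>
        //   #include <thread>
        //   #include <vector>
        //
        //   void ThreadedRunner(void* /*runner_opaque*/, void* opaque,
        //                       void fun(void*, size_t), size_t count) {
        //     size_t num_threads = std::min<size_t>(
        //         count, std::max(1u, std::thread::hardware_concurrency()));
        //     std::vector<std::thread> workers;
        //     for (size_t t = 0; t < num_threads; t++) {
        //       workers.emplace_back([=] {
        //         // Thread t handles tasks t, t + num_threads, ...
        //         for (size_t i = t; i < count; i += num_threads) {
        //           fun(opaque, i);
        //         }
        //       });
        //     }
        //     for (auto& w : workers) w.join();
        //   }
        //
        // Passed as the `runner` argument of JxlFastLosslessProcessFrame (or
        // JxlFastLosslessEncode in the standalone build), with any per-caller
        // state in `runner_opaque`.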
   4198 
   4199 }  // extern "C"
   4200 
   4201 #if !FJXL_STANDALONE
   4202 void JxlFastLosslessOutputFrame(
   4203     JxlFastLosslessFrameState* frame_state,
   4204     JxlEncoderOutputProcessorWrapper* output_processor) {
   4205   size_t fl_size = JxlFastLosslessOutputSize(frame_state);
   4206   size_t written = 0;
   4207   while (written < fl_size) {
   4208     auto retval = output_processor->GetBuffer(32, fl_size - written);
   4209     assert(retval.status());
   4210     auto buffer = std::move(retval).value();
   4211     size_t n =
   4212         JxlFastLosslessWriteOutput(frame_state, buffer.data(), buffer.size());
   4213     if (n == 0) break;
   4214     buffer.advance(n);
   4215     written += n;
   4216   }
   4217 }
   4218 #endif
   4219 
   4220 #endif  // FJXL_SELF_INCLUDE