libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git

entropy_coding.cc (29923B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jpegli/entropy_coding.h"
      7 
      8 #include <vector>
      9 
     10 #include "lib/jpegli/encode_internal.h"
     11 #include "lib/jpegli/error.h"
     12 #include "lib/jpegli/huffman.h"
     13 #include "lib/jxl/base/bits.h"
     14 
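        // Highway's foreach_target.h re-includes this file once per enabled SIMD
        // target, compiling the HWY_NAMESPACE block below for each of them; the
        // HWY_ONCE section further down is compiled only once.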
     15 #undef HWY_TARGET_INCLUDE
     16 #define HWY_TARGET_INCLUDE "lib/jpegli/entropy_coding.cc"
     17 #include <hwy/foreach_target.h>
     18 #include <hwy/highway.h>
     19 
     20 #include "lib/jpegli/entropy_coding-inl.h"
     21 
     22 HWY_BEFORE_NAMESPACE();
     23 namespace jpegli {
     24 namespace HWY_NAMESPACE {
     25 
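        // Per-SIMD-target entry point; the actual block tokenization lives in
        // entropy_coding-inl.h and is selected at runtime via HWY_DYNAMIC_DISPATCH
        // below.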
     26 void ComputeTokensSequential(const coeff_t* block, int last_dc, int dc_ctx,
     27                              int ac_ctx, Token** tokens_ptr) {
     28   ComputeTokensForBlock<coeff_t, true>(block, last_dc, dc_ctx, ac_ctx,
     29                                        tokens_ptr);
     30 }
     31 
     32 // NOLINTNEXTLINE(google-readability-namespace-comments)
     33 }  // namespace HWY_NAMESPACE
     34 }  // namespace jpegli
     35 HWY_AFTER_NAMESPACE();
     36 
     37 #if HWY_ONCE
     38 namespace jpegli {
     39 
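        // Upper bound on the number of tokens in one MCU row of a sequential scan:
        // each of the kDCTBlockSize coefficients of a block produces at most one
        // token.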
     40 size_t MaxNumTokensPerMCURow(j_compress_ptr cinfo) {
     41   int MCUs_per_row = DivCeil(cinfo->image_width, 8 * cinfo->max_h_samp_factor);
     42   size_t blocks_per_mcu = 0;
     43   for (int c = 0; c < cinfo->num_components; ++c) {
     44     jpeg_component_info* comp = &cinfo->comp_info[c];
     45     blocks_per_mcu += comp->h_samp_factor * comp->v_samp_factor;
     46   }
     47   return kDCTBlockSize * blocks_per_mcu * MCUs_per_row;
     48 }
     49 
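        // Estimates how many tokens to allocate for the remaining MCU rows of a
        // scan. The first row simply reserves 16 rows' worth; afterwards the
        // average number of tokens per row seen so far is extrapolated with a 4/3
        // safety margin, clamped between one row's worth and the theoretical
        // maximum for the rows that are left.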
     50 size_t EstimateNumTokens(j_compress_ptr cinfo, size_t mcu_y, size_t ysize_mcus,
     51                          size_t num_tokens, size_t max_per_row) {
     52   size_t estimate;
     53   if (mcu_y == 0) {
     54     estimate = 16 * max_per_row;
     55   } else {
     56     estimate = (4 * ysize_mcus * num_tokens) / (3 * mcu_y);
     57   }
     58   size_t mcus_left = ysize_mcus - mcu_y;
     59   return std::min(mcus_left * max_per_row,
     60                   std::max(max_per_row, estimate - num_tokens));
     61 }
     62 
     63 namespace {
     64 HWY_EXPORT(ComputeTokensSequential);
     65 
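        // Emits the single DC token for the first pass of a progressive DC scan:
        // the difference between the current and previous DC value (both shifted
        // down by Al), stored as an (nbits, extra bits) pair with the usual JPEG
        // convention that negative values are encoded as value - 1 in the low bits.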
     66 void TokenizeProgressiveDC(const coeff_t* coeffs, int context, int Al,
     67                            coeff_t* last_dc_coeff, Token** next_token) {
     68   coeff_t temp2;
     69   coeff_t temp;
     70   temp2 = coeffs[0] >> Al;
     71   temp = temp2 - *last_dc_coeff;
     72   *last_dc_coeff = temp2;
     73   temp2 = temp;
     74   if (temp < 0) {
     75     temp = -temp;
     76     temp2--;
     77   }
     78   int nbits = (temp == 0) ? 0 : (jxl::FloorLog2Nonzero<uint32_t>(temp) + 1);
     79   int bits = temp2 & ((1 << nbits) - 1);
     80   *(*next_token)++ = Token(context, nbits, bits);
     81 }
     82 
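        // Tokenizes the first pass (Ah == 0) of a progressive AC scan for one
        // component: nonzero coefficients in the [Ss, Se] band become (run, size)
        // tokens, while trailing zeros at the end of a block's band are folded into
        // end-of-band (EOBn) run tokens that may span many consecutive blocks.
        // Token storage is grown on demand, one block row at a time.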
     83 void TokenizeACProgressiveScan(j_compress_ptr cinfo, int scan_index,
     84                                int context, ScanTokenInfo* sti) {
     85   jpeg_comp_master* m = cinfo->master;
     86   const jpeg_scan_info* scan_info = &cinfo->scan_info[scan_index];
     87   const int comp_idx = scan_info->component_index[0];
     88   const jpeg_component_info* comp = &cinfo->comp_info[comp_idx];
     89   const int Al = scan_info->Al;
     90   const int Ss = scan_info->Ss;
     91   const int Se = scan_info->Se;
     92   const size_t restart_interval = sti->restart_interval;
     93   int restarts_to_go = restart_interval;
     94   size_t num_blocks = comp->height_in_blocks * comp->width_in_blocks;
     95   size_t num_restarts =
     96       restart_interval > 0 ? DivCeil(num_blocks, restart_interval) : 1;
     97   size_t restart_idx = 0;
     98   int eob_run = 0;
     99   TokenArray* ta = &m->token_arrays[m->cur_token_array];
    100   sti->token_offset = m->total_num_tokens + ta->num_tokens;
    101   sti->restarts = Allocate<size_t>(cinfo, num_restarts, JPOOL_IMAGE);
    102   const auto emit_eob_run = [&]() {
    103     int nbits = jxl::FloorLog2Nonzero<uint32_t>(eob_run);
    104     int symbol = nbits << 4u;
    105     *m->next_token++ = Token(context, symbol, eob_run & ((1 << nbits) - 1));
    106     eob_run = 0;
    107   };
    108   for (JDIMENSION by = 0; by < comp->height_in_blocks; ++by) {
    109     JBLOCKARRAY ba = (*cinfo->mem->access_virt_barray)(
    110         reinterpret_cast<j_common_ptr>(cinfo), m->coeff_buffers[comp_idx], by,
    111         1, FALSE);
    112     // Each coefficient can appear in at most one token, but we have to reserve
    113     // one extra EOBrun token that was rolled over from the previous block-row
    114     // and has to be flushed at the end.
    115     int max_tokens_per_row = 1 + comp->width_in_blocks * (Se - Ss + 1);
    116     if (ta->num_tokens + max_tokens_per_row > m->num_tokens) {
    117       if (ta->tokens) {
    118         m->total_num_tokens += ta->num_tokens;
    119         ++m->cur_token_array;
    120         ta = &m->token_arrays[m->cur_token_array];
    121       }
    122       m->num_tokens =
    123           EstimateNumTokens(cinfo, by, comp->height_in_blocks,
    124                             m->total_num_tokens, max_tokens_per_row);
    125       ta->tokens = Allocate<Token>(cinfo, m->num_tokens, JPOOL_IMAGE);
    126       m->next_token = ta->tokens;
    127     }
    128     for (JDIMENSION bx = 0; bx < comp->width_in_blocks; ++bx) {
    129       if (restart_interval > 0 && restarts_to_go == 0) {
    130         if (eob_run > 0) emit_eob_run();
    131         ta->num_tokens = m->next_token - ta->tokens;
    132         sti->restarts[restart_idx++] = m->total_num_tokens + ta->num_tokens;
    133         restarts_to_go = restart_interval;
    134       }
    135       const coeff_t* block = &ba[0][bx][0];
    136       coeff_t temp2;
    137       coeff_t temp;
    138       int r = 0;
    139       int num_nzeros = 0;
    140       int num_future_nzeros = 0;
    141       for (int k = Ss; k <= Se; ++k) {
    142         temp = block[k];
    143         if (temp == 0) {
    144           r++;
    145           continue;
    146         }
    147         if (temp < 0) {
    148           temp = -temp;
    149           temp >>= Al;
    150           temp2 = ~temp;
    151         } else {
    152           temp >>= Al;
    153           temp2 = temp;
    154         }
    155         if (temp == 0) {
    156           r++;
    157           num_future_nzeros++;
    158           continue;
    159         }
    160         if (eob_run > 0) emit_eob_run();
    161         while (r > 15) {
    162           *m->next_token++ = Token(context, 0xf0, 0);
    163           r -= 16;
    164         }
    165         int nbits = jxl::FloorLog2Nonzero<uint32_t>(temp) + 1;
    166         int symbol = (r << 4u) + nbits;
    167         *m->next_token++ = Token(context, symbol, temp2 & ((1 << nbits) - 1));
    168         ++num_nzeros;
    169         r = 0;
    170       }
    171       if (r > 0) {
    172         ++eob_run;
    173         if (eob_run == 0x7FFF) emit_eob_run();
    174       }
    175       sti->num_nonzeros += num_nzeros;
    176       sti->num_future_nonzeros += num_future_nzeros;
    177       --restarts_to_go;
    178     }
    179     ta->num_tokens = m->next_token - ta->tokens;
    180   }
    181   if (eob_run > 0) {
    182     emit_eob_run();
    183     ++ta->num_tokens;
    184   }
    185   sti->num_tokens = m->total_num_tokens + ta->num_tokens - sti->token_offset;
    186   sti->restarts[restart_idx++] = m->total_num_tokens + ta->num_tokens;
    187 }
    188 
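        // Tokenizes an AC refinement pass (Ah > 0): coefficients that were already
        // nonzero in an earlier pass contribute one correction bit each, newly
        // nonzero coefficients (|coeff| >> Al == 1) produce (run, 1) tokens with the
        // sign kept in bit 1 of the symbol, and the remainder of each block after
        // its last new nonzero is merged into EOB-run tokens together with the
        // pending correction bits.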
    189 void TokenizeACRefinementScan(j_compress_ptr cinfo, int scan_index,
    190                               ScanTokenInfo* sti) {
    191   jpeg_comp_master* m = cinfo->master;
    192   const jpeg_scan_info* scan_info = &cinfo->scan_info[scan_index];
    193   const int comp_idx = scan_info->component_index[0];
    194   const jpeg_component_info* comp = &cinfo->comp_info[comp_idx];
    195   const int Al = scan_info->Al;
    196   const int Ss = scan_info->Ss;
    197   const int Se = scan_info->Se;
    198   const size_t restart_interval = sti->restart_interval;
    199   int restarts_to_go = restart_interval;
    200   RefToken token;
    201   int eob_run = 0;
    202   int eob_refbits = 0;
    203   size_t num_blocks = comp->height_in_blocks * comp->width_in_blocks;
    204   size_t num_restarts =
    205       restart_interval > 0 ? DivCeil(num_blocks, restart_interval) : 1;
    206   sti->tokens = m->next_refinement_token;
    207   sti->refbits = m->next_refinement_bit;
    208   sti->eobruns = Allocate<uint16_t>(cinfo, num_blocks / 2, JPOOL_IMAGE);
    209   sti->restarts = Allocate<size_t>(cinfo, num_restarts, JPOOL_IMAGE);
    210   RefToken* next_token = sti->tokens;
    211   RefToken* next_eob_token = next_token;
    212   uint8_t* next_ref_bit = sti->refbits;
    213   uint16_t* next_eobrun = sti->eobruns;
    214   size_t restart_idx = 0;
    215   for (JDIMENSION by = 0; by < comp->height_in_blocks; ++by) {
    216     JBLOCKARRAY ba = (*cinfo->mem->access_virt_barray)(
    217         reinterpret_cast<j_common_ptr>(cinfo), m->coeff_buffers[comp_idx], by,
    218         1, FALSE);
    219     for (JDIMENSION bx = 0; bx < comp->width_in_blocks; ++bx) {
    220       if (restart_interval > 0 && restarts_to_go == 0) {
    221         sti->restarts[restart_idx++] = next_token - sti->tokens;
    222         restarts_to_go = restart_interval;
    223         next_eob_token = next_token;
    224         eob_run = eob_refbits = 0;
    225       }
    226       const coeff_t* block = &ba[0][bx][0];
    227       int num_eob_refinement_bits = 0;
    228       int num_refinement_bits = 0;
    229       int num_nzeros = 0;
    230       int r = 0;
    231       for (int k = Ss; k <= Se; ++k) {
    232         int absval = block[k];
    233         if (absval == 0) {
    234           r++;
    235           continue;
    236         }
    237         const int mask = absval >> (8 * sizeof(int) - 1);
    238         absval += mask;
    239         absval ^= mask;
    240         absval >>= Al;
    241         if (absval == 0) {
    242           r++;
    243           continue;
    244         }
    245         while (r > 15) {
    246           token.symbol = 0xf0;
    247           token.refbits = num_refinement_bits;
    248           *next_token++ = token;
    249           r -= 16;
    250           num_eob_refinement_bits += num_refinement_bits;
    251           num_refinement_bits = 0;
    252         }
    253         if (absval > 1) {
    254           *next_ref_bit++ = absval & 1u;
    255           ++num_refinement_bits;
    256           continue;
    257         }
    258         int symbol = (r << 4u) + 1 + ((mask + 1) << 1);
    259         token.symbol = symbol;
    260         token.refbits = num_refinement_bits;
    261         *next_token++ = token;
    262         ++num_nzeros;
    263         num_refinement_bits = 0;
    264         num_eob_refinement_bits = 0;
    265         r = 0;
    266         next_eob_token = next_token;
    267         eob_run = eob_refbits = 0;
    268       }
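              // End-of-band bookkeeping: next_eob_token points at a reserved token
              // slot for the current EOB run. While consecutive blocks produce no
              // new nonzero coefficient, the run length and its accumulated
              // correction bits are folded into that slot (the run's extra-bits
              // value is tracked in sti->eobruns); a fresh slot is reserved when a
              // new nonzero interrupts the run or refbits would exceed 255.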
    269       if (r > 0 || num_eob_refinement_bits + num_refinement_bits > 0) {
    270         ++eob_run;
    271         eob_refbits += num_eob_refinement_bits + num_refinement_bits;
    272         if (eob_refbits > 255) {
    273           ++next_eob_token;
    274           eob_refbits = num_eob_refinement_bits + num_refinement_bits;
    275           eob_run = 1;
    276         }
    277         next_token = next_eob_token;
    278         next_token->refbits = eob_refbits;
    279         if (eob_run == 1) {
    280           next_token->symbol = 0;
    281         } else if (eob_run == 2) {
    282           next_token->symbol = 16;
    283           *next_eobrun++ = 0;
    284         } else if ((eob_run & (eob_run - 1)) == 0) {
    285           next_token->symbol += 16;
    286           next_eobrun[-1] = 0;
    287         } else {
    288           ++next_eobrun[-1];
    289         }
    290         ++next_token;
    291         if (eob_run == 0x7fff) {
    292           next_eob_token = next_token;
    293           eob_run = eob_refbits = 0;
    294         }
    295       }
    296       sti->num_nonzeros += num_nzeros;
    297       --restarts_to_go;
    298     }
    299   }
    300   sti->num_tokens = next_token - sti->tokens;
    301   sti->restarts[restart_idx++] = sti->num_tokens;
    302   m->next_refinement_token = next_token;
    303   m->next_refinement_bit = next_ref_bit;
    304 }
    305 
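        // Tokenizes one scan. Progressive AC scans are handled by the specialized
        // functions above; this function covers sequential scans (full 8x8 blocks,
        // tokenized via the SIMD-dispatched ComputeTokensSequential) and progressive
        // DC scans, where refinement passes only need to record the bit
        // (DC >> Al) & 1 for each block.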
    306 void TokenizeScan(j_compress_ptr cinfo, size_t scan_index, int ac_ctx_offset,
    307                   ScanTokenInfo* sti) {
    308   const jpeg_scan_info* scan_info = &cinfo->scan_info[scan_index];
    309   if (scan_info->Ss > 0) {
    310     if (scan_info->Ah == 0) {
    311       TokenizeACProgressiveScan(cinfo, scan_index, ac_ctx_offset, sti);
    312     } else {
    313       TokenizeACRefinementScan(cinfo, scan_index, sti);
    314     }
    315     return;
    316   }
    317 
    318   jpeg_comp_master* m = cinfo->master;
    319   size_t restart_interval = sti->restart_interval;
    320   int restarts_to_go = restart_interval;
    321   coeff_t last_dc_coeff[MAX_COMPS_IN_SCAN] = {0};
    322 
    323   // "Non-interleaved" means color data comes in separate scans, in other words
    324   // each scan can contain only one color component.
    325   const bool is_interleaved = (scan_info->comps_in_scan > 1);
    326   const bool is_progressive = FROM_JXL_BOOL(cinfo->progressive_mode);
    327   const int Ah = scan_info->Ah;
    328   const int Al = scan_info->Al;
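          // All-zero stand-in used for dummy blocks that lie outside the
          // component's real block dimensions (right/bottom edge of interleaved
          // MCUs).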
    329   HWY_ALIGN constexpr coeff_t kSinkBlock[DCTSIZE2] = {0};
    330 
    331   size_t restart_idx = 0;
    332   TokenArray* ta = &m->token_arrays[m->cur_token_array];
    333   sti->token_offset = Ah > 0 ? 0 : m->total_num_tokens + ta->num_tokens;
    334 
    335   if (Ah > 0) {
    336     sti->refbits = Allocate<uint8_t>(cinfo, sti->num_blocks, JPOOL_IMAGE);
    337   } else if (cinfo->progressive_mode) {
    338     if (ta->num_tokens + sti->num_blocks > m->num_tokens) {
    339       if (ta->tokens) {
    340         m->total_num_tokens += ta->num_tokens;
    341         ++m->cur_token_array;
    342         ta = &m->token_arrays[m->cur_token_array];
    343       }
    344       m->num_tokens = sti->num_blocks;
    345       ta->tokens = Allocate<Token>(cinfo, m->num_tokens, JPOOL_IMAGE);
    346       m->next_token = ta->tokens;
    347     }
    348   }
    349 
    350   JBLOCKARRAY ba[MAX_COMPS_IN_SCAN];
    351   size_t block_idx = 0;
    352   for (size_t mcu_y = 0; mcu_y < sti->MCU_rows_in_scan; ++mcu_y) {
    353     for (int i = 0; i < scan_info->comps_in_scan; ++i) {
    354       int comp_idx = scan_info->component_index[i];
    355       jpeg_component_info* comp = &cinfo->comp_info[comp_idx];
    356       int n_blocks_y = is_interleaved ? comp->v_samp_factor : 1;
    357       int by0 = mcu_y * n_blocks_y;
    358       int block_rows_left = comp->height_in_blocks - by0;
    359       int max_block_rows = std::min(n_blocks_y, block_rows_left);
    360       ba[i] = (*cinfo->mem->access_virt_barray)(
    361           reinterpret_cast<j_common_ptr>(cinfo), m->coeff_buffers[comp_idx],
    362           by0, max_block_rows, FALSE);
    363     }
    364     if (!cinfo->progressive_mode) {
    365       int max_tokens_per_mcu_row = MaxNumTokensPerMCURow(cinfo);
    366       if (ta->num_tokens + max_tokens_per_mcu_row > m->num_tokens) {
    367         if (ta->tokens) {
    368           m->total_num_tokens += ta->num_tokens;
    369           ++m->cur_token_array;
    370           ta = &m->token_arrays[m->cur_token_array];
    371         }
    372         m->num_tokens =
    373             EstimateNumTokens(cinfo, mcu_y, sti->MCU_rows_in_scan,
    374                               m->total_num_tokens, max_tokens_per_mcu_row);
    375         ta->tokens = Allocate<Token>(cinfo, m->num_tokens, JPOOL_IMAGE);
    376         m->next_token = ta->tokens;
    377       }
    378     }
    379     for (size_t mcu_x = 0; mcu_x < sti->MCUs_per_row; ++mcu_x) {
    380       // Possibly emit a restart marker.
    381       if (restart_interval > 0 && restarts_to_go == 0) {
    382         restarts_to_go = restart_interval;
    383         memset(last_dc_coeff, 0, sizeof(last_dc_coeff));
    384         ta->num_tokens = m->next_token - ta->tokens;
    385         sti->restarts[restart_idx++] =
    386             Ah > 0 ? block_idx : m->total_num_tokens + ta->num_tokens;
    387       }
    388       // Encode one MCU
    389       for (int i = 0; i < scan_info->comps_in_scan; ++i) {
    390         int comp_idx = scan_info->component_index[i];
    391         jpeg_component_info* comp = &cinfo->comp_info[comp_idx];
    392         int n_blocks_y = is_interleaved ? comp->v_samp_factor : 1;
    393         int n_blocks_x = is_interleaved ? comp->h_samp_factor : 1;
    394         for (int iy = 0; iy < n_blocks_y; ++iy) {
    395           for (int ix = 0; ix < n_blocks_x; ++ix) {
    396             size_t block_y = mcu_y * n_blocks_y + iy;
    397             size_t block_x = mcu_x * n_blocks_x + ix;
    398             const coeff_t* block;
    399             if (block_x >= comp->width_in_blocks ||
    400                 block_y >= comp->height_in_blocks) {
    401               block = kSinkBlock;
    402             } else {
    403               block = &ba[i][iy][block_x][0];
    404             }
    405             if (!is_progressive) {
    406               HWY_DYNAMIC_DISPATCH(ComputeTokensSequential)
    407               (block, last_dc_coeff[i], comp_idx, ac_ctx_offset + i,
    408                &m->next_token);
    409               last_dc_coeff[i] = block[0];
    410             } else {
    411               if (Ah == 0) {
    412                 TokenizeProgressiveDC(block, comp_idx, Al, last_dc_coeff + i,
    413                                       &m->next_token);
    414               } else {
    415                 sti->refbits[block_idx] = (block[0] >> Al) & 1;
    416               }
    417             }
    418             ++block_idx;
    419           }
    420         }
    421       }
    422       --restarts_to_go;
    423     }
    424     ta->num_tokens = m->next_token - ta->tokens;
    425   }
    426   JXL_DASSERT(block_idx == sti->num_blocks);
    427   sti->num_tokens =
    428       Ah > 0 ? sti->num_blocks
    429              : m->total_num_tokens + ta->num_tokens - sti->token_offset;
    430   sti->restarts[restart_idx++] =
    431       Ah > 0 ? sti->num_blocks : m->total_num_tokens + ta->num_tokens;
    432   if (Ah == 0 && cinfo->progressive_mode) {
    433     JXL_DASSERT(sti->num_blocks == sti->num_tokens);
    434   }
    435 }
    436 
    437 }  // namespace
    438 
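        // Tokenizes all scans. They are not processed in file order: the first pass
        // of every successively-approximated AC band goes first, so that the
        // refinement token and bit buffers can be sized from its nonzero counts;
        // the matching refinement passes then follow in decreasing Ah order, and
        // all remaining scans (DC, sequential, AC without later refinement) are
        // tokenized last.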
    439 void TokenizeJpeg(j_compress_ptr cinfo) {
    440   jpeg_comp_master* m = cinfo->master;
    441   std::vector<int> processed(cinfo->num_scans);
    442   size_t max_refinement_tokens = 0;
    443   size_t num_refinement_bits = 0;
    444   int num_refinement_scans[DCTSIZE2] = {};
    445   int max_num_refinement_scans = 0;
    446   for (int i = 0; i < cinfo->num_scans; ++i) {
    447     const jpeg_scan_info* si = &cinfo->scan_info[i];
    448     ScanTokenInfo* sti = &m->scan_token_info[i];
    449     if (si->Ss > 0 && si->Ah == 0 && si->Al > 0) {
    450       int offset = m->ac_ctx_offset[i];
    451       TokenizeScan(cinfo, i, offset, sti);
    452       processed[i] = 1;
    453       max_refinement_tokens += sti->num_future_nonzeros;
    454       for (int k = si->Ss; k <= si->Se; ++k) {
    455         num_refinement_scans[k] = si->Al;
    456       }
    457       max_num_refinement_scans = std::max(max_num_refinement_scans, si->Al);
    458       num_refinement_bits += sti->num_nonzeros;
    459     }
    460     if (si->Ss > 0 && si->Ah > 0) {
    461       int comp_idx = si->component_index[0];
    462       const jpeg_component_info* comp = &cinfo->comp_info[comp_idx];
    463       size_t num_blocks = comp->width_in_blocks * comp->height_in_blocks;
    464       max_refinement_tokens += (1 + (si->Se - si->Ss) / 16) * num_blocks;
    465     }
    466   }
    467   if (max_refinement_tokens > 0) {
    468     m->next_refinement_token =
    469         Allocate<RefToken>(cinfo, max_refinement_tokens, JPOOL_IMAGE);
    470   }
    471   for (int j = 0; j < max_num_refinement_scans; ++j) {
    472     uint8_t* refinement_bits =
    473         Allocate<uint8_t>(cinfo, num_refinement_bits, JPOOL_IMAGE);
    474     m->next_refinement_bit = refinement_bits;
    475     size_t new_refinement_bits = 0;
    476     for (int i = 0; i < cinfo->num_scans; ++i) {
    477       const jpeg_scan_info* si = &cinfo->scan_info[i];
    478       ScanTokenInfo* sti = &m->scan_token_info[i];
    479       if (si->Ss > 0 && si->Ah > 0 &&
    480           si->Ah == num_refinement_scans[si->Ss] - j) {
    481         int offset = m->ac_ctx_offset[i];
    482         TokenizeScan(cinfo, i, offset, sti);
    483         processed[i] = 1;
    484         new_refinement_bits += sti->num_nonzeros;
    485       }
    486     }
    487     JXL_DASSERT(m->next_refinement_bit ==
    488                 refinement_bits + num_refinement_bits);
    489     num_refinement_bits += new_refinement_bits;
    490   }
    491   for (int i = 0; i < cinfo->num_scans; ++i) {
    492     if (processed[i]) {
    493       continue;
    494     }
    495     int offset = m->ac_ctx_offset[i];
    496     TokenizeScan(cinfo, i, offset, &m->scan_token_info[i]);
    497     processed[i] = 1;
    498   }
    499 }
    500 
    501 namespace {
    502 
    503 struct Histogram {
    504   int count[kJpegHuffmanAlphabetSize];
    505   Histogram() { memset(count, 0, sizeof(count)); }
    506 };
    507 
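        // Counts symbol frequencies per context over all emitted tokens. For AC
        // refinement tokens the sign bit kept in bit 1 of the symbol is not Huffman
        // coded, so it is masked out (symbol & 253) before counting.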
    508 void BuildHistograms(j_compress_ptr cinfo, Histogram* histograms) {
    509   jpeg_comp_master* m = cinfo->master;
    510   size_t num_token_arrays = m->cur_token_array + 1;
    511   for (size_t i = 0; i < num_token_arrays; ++i) {
    512     Token* tokens = m->token_arrays[i].tokens;
    513     size_t num_tokens = m->token_arrays[i].num_tokens;
    514     for (size_t j = 0; j < num_tokens; ++j) {
    515       Token t = tokens[j];
    516       ++histograms[t.context].count[t.symbol];
    517     }
    518   }
    519   for (int i = 0; i < cinfo->num_scans; ++i) {
    520     const jpeg_scan_info& si = cinfo->scan_info[i];
    521     const ScanTokenInfo& sti = m->scan_token_info[i];
    522     if (si.Ss > 0 && si.Ah > 0) {
    523       int context = m->ac_ctx_offset[i];
    524       int* ac_histo = &histograms[context].count[0];
    525       for (size_t j = 0; j < sti.num_tokens; ++j) {
    526         ++ac_histo[sti.tokens[j].symbol & 253];
    527       }
    528     }
    529   }
    530 }
    531 
    532 struct JpegClusteredHistograms {
    533   std::vector<Histogram> histograms;
    534   std::vector<uint32_t> histogram_indexes;
    535   std::vector<uint32_t> slot_ids;
    536 };
    537 
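        // Estimated size in bits of coding a histogram with its own Huffman table:
        // DHT segment overhead (17 bytes plus one byte per used symbol) plus the
        // entropy-coded data bits under an optimal length-limited code.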
    538 float HistogramCost(const Histogram& histo) {
    539   std::vector<uint32_t> counts(kJpegHuffmanAlphabetSize + 1);
    540   std::vector<uint8_t> depths(kJpegHuffmanAlphabetSize + 1);
    541   for (size_t i = 0; i < kJpegHuffmanAlphabetSize; ++i) {
    542     counts[i] = histo.count[i];
    543   }
    544   counts[kJpegHuffmanAlphabetSize] = 1;
    545   CreateHuffmanTree(counts.data(), counts.size(), kJpegHuffmanMaxBitLength,
    546                     depths.data());
    547   size_t header_bits = (1 + kJpegHuffmanMaxBitLength) * 8;
    548   size_t data_bits = 0;
    549   for (size_t i = 0; i < kJpegHuffmanAlphabetSize; ++i) {
    550     if (depths[i] > 0) {
    551       header_bits += 8;
    552       data_bits += counts[i] * depths[i];
    553     }
    554   }
    555   return header_bits + data_bits;
    556 }
    557 
    558 void AddHistograms(const Histogram& a, const Histogram& b, Histogram* c) {
    559   for (size_t i = 0; i < kJpegHuffmanAlphabetSize; ++i) {
    560     c->count[i] = a.count[i] + b.count[i];
    561   }
    562 }
    563 
    564 bool IsEmptyHistogram(const Histogram& histo) {
    565   for (int count : histo.count) {
    566     if (count) return false;
    567   }
    568   return true;
    569 }
    570 
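        // Greedily clusters histograms for the (at most four) Huffman table slots:
        // each histogram either joins the existing cluster whose total cost it
        // increases the least, or starts a new cluster if that is cheaper.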
    571 void ClusterJpegHistograms(const Histogram* histograms, size_t num,
    572                            JpegClusteredHistograms* clusters) {
    573   clusters->histogram_indexes.resize(num);
    574   std::vector<uint32_t> slot_histograms;
    575   std::vector<float> slot_costs;
    576   for (size_t i = 0; i < num; ++i) {
    577     const Histogram& cur = histograms[i];
    578     if (IsEmptyHistogram(cur)) {
    579       continue;
    580     }
    581     float best_cost = HistogramCost(cur);
    582     size_t best_slot = slot_histograms.size();
    583     for (size_t j = 0; j < slot_histograms.size(); ++j) {
    584       size_t prev_idx = slot_histograms[j];
    585       const Histogram& prev = clusters->histograms[prev_idx];
    586       Histogram combined;
    587       AddHistograms(prev, cur, &combined);
    588       float combined_cost = HistogramCost(combined);
    589       float cost = combined_cost - slot_costs[j];
    590       if (cost < best_cost) {
    591         best_cost = cost;
    592         best_slot = j;
    593       }
    594     }
    595     if (best_slot == slot_histograms.size()) {
    596       // Create new histogram.
    597       size_t histogram_index = clusters->histograms.size();
    598       clusters->histograms.push_back(cur);
    599       clusters->histogram_indexes[i] = histogram_index;
    600       if (best_slot < 4) {
    601         // We have a free slot, so we put the new histogram there.
    602         slot_histograms.push_back(histogram_index);
    603         slot_costs.push_back(best_cost);
    604       } else {
    605         // TODO(szabadka) Find the best histogram to replace.
    606         best_slot = (clusters->slot_ids.back() + 1) % 4;
    607       }
    608       slot_histograms[best_slot] = histogram_index;
    609       slot_costs[best_slot] = best_cost;
    610       clusters->slot_ids.push_back(best_slot);
    611     } else {
    612       // Merge this histogram with a previous one.
    613       size_t histogram_index = slot_histograms[best_slot];
    614       const Histogram& prev = clusters->histograms[histogram_index];
    615       AddHistograms(prev, cur, &clusters->histograms[histogram_index]);
    616       clusters->histogram_indexes[i] = histogram_index;
    617       JXL_ASSERT(clusters->slot_ids[histogram_index] == best_slot);
    618       slot_costs[best_slot] += best_cost;
    619     }
    620   }
    621 }
    622 
    623 void CopyHuffmanTable(j_compress_ptr cinfo, int index, bool is_dc,
    624                       int* inv_slot_map, uint8_t* slot_id_map,
    625                       JHUFF_TBL* huffman_tables, size_t* num_huffman_tables) {
    626   const char* type = is_dc ? "DC" : "AC";
    627   if (index < 0 || index >= NUM_HUFF_TBLS) {
    628     JPEGLI_ERROR("Invalid %s Huffman table index %d", type, index);
    629   }
    630   // Check if we have already copied this Huffman table.
    631   int slot_idx = index + (is_dc ? 0 : NUM_HUFF_TBLS);
    632   if (inv_slot_map[slot_idx] != -1) {
    633     return;
    634   }
    635   inv_slot_map[slot_idx] = *num_huffman_tables;
    636   // Look up and validate Huffman table.
    637   JHUFF_TBL* table =
    638       is_dc ? cinfo->dc_huff_tbl_ptrs[index] : cinfo->ac_huff_tbl_ptrs[index];
    639   if (table == nullptr) {
    640     JPEGLI_ERROR("Missing %s Huffman table %d", type, index);
    641   }
    642   ValidateHuffmanTable(reinterpret_cast<j_common_ptr>(cinfo), table, is_dc);
    643   // Copy Huffman table to the end of the list and save slot id.
    644   slot_id_map[*num_huffman_tables] = index + (is_dc ? 0 : 0x10);
    645   memcpy(&huffman_tables[*num_huffman_tables], table, sizeof(JHUFF_TBL));
    646   ++(*num_huffman_tables);
    647 }
    648 
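        // Converts a symbol histogram into a JPEG Huffman table: bits[] holds the
        // number of codes per length, huffval[] the symbols in order of increasing
        // code length. The sentinel symbol with count 1 plays the same role as
        // libjpeg's dummy 257th symbol, reserving a codeword so that no real symbol
        // gets the all-ones code.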
    649 void BuildJpegHuffmanTable(const Histogram& histo, JHUFF_TBL* table) {
    650   std::vector<uint32_t> counts(kJpegHuffmanAlphabetSize + 1);
    651   std::vector<uint8_t> depths(kJpegHuffmanAlphabetSize + 1);
    652   for (size_t j = 0; j < kJpegHuffmanAlphabetSize; ++j) {
    653     counts[j] = histo.count[j];
    654   }
    655   counts[kJpegHuffmanAlphabetSize] = 1;
    656   CreateHuffmanTree(counts.data(), counts.size(), kJpegHuffmanMaxBitLength,
    657                     depths.data());
    658   memset(table, 0, sizeof(JHUFF_TBL));
    659   for (size_t i = 0; i < kJpegHuffmanAlphabetSize; ++i) {
    660     if (depths[i] > 0) {
    661       ++table->bits[depths[i]];
    662     }
    663   }
    664   int offset[kJpegHuffmanMaxBitLength + 1] = {0};
    665   for (size_t i = 1; i <= kJpegHuffmanMaxBitLength; ++i) {
    666     offset[i] = offset[i - 1] + table->bits[i - 1];
    667   }
    668   for (size_t i = 0; i < kJpegHuffmanAlphabetSize; ++i) {
    669     if (depths[i] > 0) {
    670       table->huffval[offset[depths[i]]++] = i;
    671     }
    672   }
    673 }
    674 
    675 }  // namespace
    676 
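        // Copies the Huffman tables referenced by the components directly from
        // cinfo (validating them and deduplicating repeated slot indices) and
        // derives the DC/AC context map from the slot assignments.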
    677 void CopyHuffmanTables(j_compress_ptr cinfo) {
    678   jpeg_comp_master* m = cinfo->master;
    679   size_t max_huff_tables = 2 * cinfo->num_components;
    680   // Copy Huffman tables and save slot ids.
    681   m->huffman_tables = Allocate<JHUFF_TBL>(cinfo, max_huff_tables, JPOOL_IMAGE);
    682   m->slot_id_map = Allocate<uint8_t>(cinfo, max_huff_tables, JPOOL_IMAGE);
    683   m->num_huffman_tables = 0;
    684   int inv_slot_map[8] = {-1, -1, -1, -1, -1, -1, -1, -1};
    685   for (int c = 0; c < cinfo->num_components; ++c) {
    686     jpeg_component_info* comp = &cinfo->comp_info[c];
    687     CopyHuffmanTable(cinfo, comp->dc_tbl_no, /*is_dc=*/true, &inv_slot_map[0],
    688                      m->slot_id_map, m->huffman_tables, &m->num_huffman_tables);
    689     CopyHuffmanTable(cinfo, comp->ac_tbl_no, /*is_dc=*/false, &inv_slot_map[0],
    690                      m->slot_id_map, m->huffman_tables, &m->num_huffman_tables);
    691   }
    692   // Compute context map.
    693   m->context_map = Allocate<uint8_t>(cinfo, 8, JPOOL_IMAGE);
    694   memset(m->context_map, 0, 8);
    695   for (int c = 0; c < cinfo->num_components; ++c) {
    696     m->context_map[c] = inv_slot_map[cinfo->comp_info[c].dc_tbl_no];
    697   }
    698   int ac_ctx = 4;
    699   for (int i = 0; i < cinfo->num_scans; ++i) {
    700     const jpeg_scan_info* si = &cinfo->scan_info[i];
    701     if (si->Se > 0) {
    702       for (int j = 0; j < si->comps_in_scan; ++j) {
    703         int c = si->component_index[j];
    704         jpeg_component_info* comp = &cinfo->comp_info[c];
    705         m->context_map[ac_ctx++] = inv_slot_map[comp->ac_tbl_no + 4];
    706       }
    707     }
    708   }
    709 }
    710 
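        // Builds per-context symbol histograms from the emitted tokens, clusters
        // them into at most four DC and four AC Huffman tables, and fills in the
        // context map (the first four context slots are DC, AC contexts follow).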
    711 void OptimizeHuffmanCodes(j_compress_ptr cinfo) {
    712   jpeg_comp_master* m = cinfo->master;
    713   // Build DC and AC histograms.
    714   std::vector<Histogram> histograms(m->num_contexts);
    715   BuildHistograms(cinfo, histograms.data());
    716 
    717   // Cluster DC histograms.
    718   JpegClusteredHistograms dc_clusters;
    719   ClusterJpegHistograms(histograms.data(), cinfo->num_components, &dc_clusters);
    720 
    721   // Cluster AC histograms.
    722   JpegClusteredHistograms ac_clusters;
    723   ClusterJpegHistograms(histograms.data() + 4, m->num_contexts - 4,
    724                         &ac_clusters);
    725 
    726   // Create Huffman tables and slot ids for the clustered histograms.
    727   size_t num_dc_huff = dc_clusters.histograms.size();
    728   m->num_huffman_tables = num_dc_huff + ac_clusters.histograms.size();
    729   m->huffman_tables =
    730       Allocate<JHUFF_TBL>(cinfo, m->num_huffman_tables, JPOOL_IMAGE);
    731   m->slot_id_map = Allocate<uint8_t>(cinfo, m->num_huffman_tables, JPOOL_IMAGE);
    732   for (size_t i = 0; i < m->num_huffman_tables; ++i) {
    733     JHUFF_TBL huff_table = {};
    734     if (i < dc_clusters.histograms.size()) {
    735       m->slot_id_map[i] = i;
    736       BuildJpegHuffmanTable(dc_clusters.histograms[i], &huff_table);
    737     } else {
    738       m->slot_id_map[i] = 16 + ac_clusters.slot_ids[i - num_dc_huff];
    739       BuildJpegHuffmanTable(ac_clusters.histograms[i - num_dc_huff],
    740                             &huff_table);
    741     }
    742     memcpy(&m->huffman_tables[i], &huff_table, sizeof(huff_table));
    743   }
    744 
    745   // Create context map from clustered histogram indexes.
    746   m->context_map = Allocate<uint8_t>(cinfo, m->num_contexts, JPOOL_IMAGE);
    747   memset(m->context_map, 0, m->num_contexts);
    748   for (size_t i = 0; i < m->num_contexts; ++i) {
    749     if (i < static_cast<size_t>(cinfo->num_components)) {
    750       m->context_map[i] = dc_clusters.histogram_indexes[i];
    751     } else if (i >= 4) {
    752       m->context_map[i] = num_dc_huff + ac_clusters.histogram_indexes[i - 4];
    753     }
    754   }
    755 }
    756 
    757 namespace {
    758 
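        // Number of extra (appended) bits per Huffman symbol: for regular
        // (run, size) symbols it is the size in the low nibble, for EOBn symbols
        // (low nibble 0) it is the run-length category in the high nibble, and for
        // ZRL (0xf0) it is zero.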
    759 constexpr uint8_t kNumExtraBits[256] = {
    760     0,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    761     1,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    762     2,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    763     3,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    764     4,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    765     5,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    766     6,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    767     7,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    768     8,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    769     9,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    770     10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    771     11, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    772     12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    773     13, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    774     14, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    775     0,  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,  //
    776 };
    777 
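        // Expands a JHUFF_TBL into per-symbol canonical codes. The stored depth
        // includes the symbol's extra bits and the code is pre-shifted left by that
        // amount, so a Huffman code and its appended bits can be emitted together.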
    778 void BuildHuffmanCodeTable(const JHUFF_TBL& table, HuffmanCodeTable* code) {
    779   int huff_code[kJpegHuffmanAlphabetSize];
    780   // +1 for a sentinel element.
    781   uint32_t huff_size[kJpegHuffmanAlphabetSize + 1];
    782   int p = 0;
    783   for (size_t l = 1; l <= kJpegHuffmanMaxBitLength; ++l) {
    784     int i = table.bits[l];
    785     while (i--) huff_size[p++] = l;
    786   }
    787 
    788   // Reuse sentinel element.
    789   int last_p = p;
    790   huff_size[last_p] = 0;
    791 
    792   int next_code = 0;
    793   uint32_t si = huff_size[0];
    794   p = 0;
    795   while (huff_size[p]) {
    796     while ((huff_size[p]) == si) {
    797       huff_code[p++] = next_code;
    798       next_code++;
    799     }
    800     next_code <<= 1;
    801     si++;
    802   }
    803   for (p = 0; p < last_p; p++) {
    804     int i = table.huffval[p];
    805     int nbits = kNumExtraBits[i];
    806     code->depth[i] = huff_size[p] + nbits;
    807     code->code[i] = huff_code[p] << nbits;
    808   }
    809 }
    810 
    811 }  // namespace
    812 
    813 void InitEntropyCoder(j_compress_ptr cinfo) {
    814   jpeg_comp_master* m = cinfo->master;
    815   m->coding_tables =
    816       Allocate<HuffmanCodeTable>(cinfo, m->num_huffman_tables, JPOOL_IMAGE);
    817   for (size_t i = 0; i < m->num_huffman_tables; ++i) {
    818     BuildHuffmanCodeTable(m->huffman_tables[i], &m->coding_tables[i]);
    819   }
    820 }
    821 
    822 }  // namespace jpegli
    823 #endif  // HWY_ONCE