libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

color_quantize.cc (18323B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jpegli/color_quantize.h"
      7 
      8 #include <cmath>
      9 #include <limits>
     10 #include <unordered_map>
     11 
     12 #include "lib/jpegli/decode_internal.h"
     13 #include "lib/jpegli/error.h"
     14 #include "lib/jxl/base/status.h"
     15 
     16 namespace jpegli {
     17 
     18 namespace {
     19 
// Number of leading bits of each component used to address a cell of the
// inverse colormap grid (see FindCandidatesForCell and LookupColorIndex).
constexpr int kNumColorCellBits[kMaxComponents] = {3, 4, 3, 3};
// Per-component weights applied to component differences before they are
// squared and summed into a palette distance.
constexpr int kCompW[kMaxComponents] = {2, 3, 1, 1};
     22 
     23 int Pow(int a, int b) {
     24   int r = 1;
     25   for (int i = 0; i < b; ++i) {
     26     r *= a;
     27   }
     28   return r;
     29 }
     30 
     31 int ComponentOrder(j_decompress_ptr cinfo, int i) {
     32   if (cinfo->out_color_components == 3) {
     33     return i < 2 ? 1 - i : i;
     34   }
     35   return i;
     36 }
     37 
     38 int GetColorComponent(int i, int N) {
     39   return (i * 255 + (N - 1) / 2) / (N - 1);
     40 }
     41 
     42 }  // namespace
     43 
     44 void ChooseColorMap1Pass(j_decompress_ptr cinfo) {
     45   jpeg_decomp_master* m = cinfo->master;
     46   int components = cinfo->out_color_components;
     47   int desired = std::min(cinfo->desired_number_of_colors, 256);
     48   int num = 1;
     49   while (Pow(num + 1, components) <= desired) {
     50     ++num;
     51   }
     52   if (num == 1) {
     53     JPEGLI_ERROR("Too few colors (%d) in requested colormap", desired);
     54   }
     55   int actual = Pow(num, components);
     56   for (int i = 0; i < components; ++i) {
     57     m->num_colors_[i] = num;
     58   }
     59   while (actual < desired) {
     60     int total = actual;
     61     for (int i = 0; i < components; ++i) {
     62       int c = ComponentOrder(cinfo, i);
     63       int new_total = (actual / m->num_colors_[c]) * (m->num_colors_[c] + 1);
     64       if (new_total <= desired) {
     65         ++m->num_colors_[c];
     66         actual = new_total;
     67       }
     68     }
     69     if (actual == total) {
     70       break;
     71     }
     72   }
     73   cinfo->actual_number_of_colors = actual;
     74   cinfo->colormap = (*cinfo->mem->alloc_sarray)(
     75       reinterpret_cast<j_common_ptr>(cinfo), JPOOL_IMAGE, actual, components);
     76   int next_color[kMaxComponents] = {0};
     77   for (int i = 0; i < actual; ++i) {
     78     for (int c = 0; c < components; ++c) {
     79       cinfo->colormap[c][i] =
     80           GetColorComponent(next_color[c], m->num_colors_[c]);
     81     }
     82     int c = components - 1;
     83     while (c > 0 && next_color[c] + 1 == m->num_colors_[c]) {
     84       next_color[c--] = 0;
     85     }
     86     ++next_color[c];
     87   }
     88   if (!m->colormap_lut_) {
     89     m->colormap_lut_ = Allocate<uint8_t>(cinfo, components * 256, JPOOL_IMAGE);
     90   }
     91   int stride = actual;
     92   for (int c = 0; c < components; ++c) {
     93     int N = m->num_colors_[c];
     94     stride /= N;
     95     for (int i = 0; i < 256; ++i) {
     96       int index = ((2 * i - 1) * (N - 1) + 254) / 510;
     97       m->colormap_lut_[c * 256 + i] = index * stride;
     98     }
     99   }
    100 }
    101 
    102 namespace {
    103 
    104 // 2^13 priority levels for the PQ seems to be a good compromise between
    105 // accuracy, running time and stack space usage.
    106 const int kMaxPriority = 1 << 13;
    107 const int kMaxLevel = 3;
    108 
    109 // This function is used in the multi-resolution grid to be able to compute
    110 // the keys for the different resolutions by just shifting the first key.
    111 inline int InterlaceBitsRGB(uint8_t r, uint8_t g, uint8_t b) {
    112   int z = 0;
    113   for (int i = 0; i < 7; ++i) {
    114     z += (r >> 5) & 4;
    115     z += (g >> 6) & 2;
    116     z += (b >> 7);
    117     z <<= 3;
    118     r <<= 1;
    119     g <<= 1;
    120     b <<= 1;
    121   }
    122   z += (r >> 5) & 4;
    123   z += (g >> 6) & 2;
    124   z += (b >> 7);
    125   return z;
    126 }
    127 
    128 // This function will compute the actual priorities of the colors based on
    129 // the current distance from the palette, the population count and the signals
    130 // from the multi-resolution grid.
    131 inline int Priority(int d, int n, const int* density, const int* radius) {
    132   int p = d * n;
    133   for (int level = 0; level < kMaxLevel; ++level) {
    134     if (d > radius[level]) {
    135       p += density[level] * (d - radius[level]);
    136     }
    137   }
    138   return std::min(kMaxPriority - 1, p >> 4);
    139 }
    140 
    141 inline int ColorIntQuadDistanceRGB(uint8_t r1, uint8_t g1, uint8_t b1,
    142                                    uint8_t r2, uint8_t g2, uint8_t b2) {
    143   // weights for the intensity calculation
    144   static constexpr int ired = 2;
    145   static constexpr int igreen = 5;
    146   static constexpr int iblue = 1;
    147   // normalization factor for the intensity calculation (2^ishift)
    148   static constexpr int ishift = 3;
    149   const int rd = r1 - r2;
    150   const int gd = g1 - g2;
    151   const int bd = b1 - b2;
    152   const int id = ired * rd + igreen * gd + iblue * bd;
    153   return rd * rd + gd * gd + bd * bd + ((id * id) >> (2 * ishift));
    154 }
    155 
    156 inline int ScaleQuadDistanceRGB(int d) {
    157   return static_cast<int>(std::lround(sqrt(d * 0.25)));
    158 }
    159 
    160 // The function updates the minimal distances, the clustering and the
    161 // quantization error after the insertion of the new color into the palette.
    162 void AddToRGBPalette(const uint8_t* red, const uint8_t* green,
    163                      const uint8_t* blue,
    164                      const int* count,  // histogram of colors
    165                      const int index,   // index of color to be added
    166                      const int k,       // size of current palette
    167                      const int n,       // number of colors
    168                      int* dist,         // array of distances from palette
    169                      int* cluster,      // mapping of color indices to palette
    170                      int* center,       // the inverse mapping
    171                      int64_t* error) {  // measure of the quantization error
    172   center[k] = index;
    173   cluster[index] = k;
    174   *error -=
    175       static_cast<int64_t>(dist[index]) * static_cast<int64_t>(count[index]);
    176   dist[index] = 0;
    177   for (int j = 0; j < n; ++j) {
    178     if (dist[j] > 0) {
    179       const int d = ColorIntQuadDistanceRGB(
    180           red[index], green[index], blue[index], red[j], green[j], blue[j]);
    181       if (d < dist[j]) {
    182         *error += static_cast<int64_t>((d - dist[j])) *
    183                   static_cast<int64_t>(count[j]);
    184         dist[j] = d;
    185         cluster[j] = k;
    186       }
    187     }
    188   }
    189 }
    190 
    191 struct RGBPixelHasher {
    192   // A quick but good-enough hash to get 24 bits of RGB into the lower 12 bits.
    193   size_t operator()(uint32_t a) const { return (a ^ (a >> 12)) * 0x9e3779b9; }
    194 };
    195 
    196 struct WangHasher {
    197   // Thomas Wang's Hash.  Nearly perfect and still quite fast.  Above (for
    198   // pixels) we use a simpler hash because the number of hash calls is
    199   // proportional to the number of pixels and that hash dominates; we want the
    200   // cost to be minimal and we start with a large table.  We can use a better
    201   // hash for the histogram since the number of hash calls is proportional to
    202   // the number of unique colors in the image, which is hopefully much smaller.
    203   // Note that the difference is slight; e.g. replacing RGBPixelHasher with
    204   // WangHasher only slows things down by 5% on an Opteron.
    205   size_t operator()(uint32_t a) const {
    206     a = (a ^ 61) ^ (a >> 16);
    207     a = a + (a << 3);
    208     a = a ^ (a >> 4);
    209     a = a * 0x27d4eb2d;
    210     a = a ^ (a >> 15);
    211     return a;
    212   }
    213 };
    214 
    215 // Build an index of all the different colors in the input
    216 // image. To do this we map the 24 bit RGB representation of the colors
    217 // to a unique integer index assigned to the different colors in order of
    218 // appearance in the image.  Return the number of unique colors found.
    219 // The colors are pre-quantized to 3 * 6 bits precision.
    220 int BuildRGBColorIndex(const uint8_t* const image, int const num_pixels,
    221                        int* const count, uint8_t* const red,
    222                        uint8_t* const green, uint8_t* const blue) {
    223   // Impossible because rgb are in the low 24 bits, and the upper 8 bits is 0.
    224   const uint32_t impossible_pixel_value = 0x10000000;
    225   std::unordered_map<uint32_t, int, RGBPixelHasher> index_map(1 << 12);
    226   std::unordered_map<uint32_t, int, RGBPixelHasher>::iterator index_map_lookup;
    227   const uint8_t* imagep = &image[0];
    228   uint32_t prev_pixel = impossible_pixel_value;
    229   int index = 0;
    230   int n = 0;
    231   for (int i = 0; i < num_pixels; ++i) {
    232     uint8_t r = ((*imagep++) & 0xfc) + 2;
    233     uint8_t g = ((*imagep++) & 0xfc) + 2;
    234     uint8_t b = ((*imagep++) & 0xfc) + 2;
    235     uint32_t pixel = (b << 16) | (g << 8) | r;
    236     if (pixel != prev_pixel) {
    237       prev_pixel = pixel;
    238       index_map_lookup = index_map.find(pixel);
    239       if (index_map_lookup != index_map.end()) {
    240         index = index_map_lookup->second;
    241       } else {
    242         index_map[pixel] = index = n++;
    243         red[index] = r;
    244         green[index] = g;
    245         blue[index] = b;
    246       }
    247     }
    248     ++count[index];
    249   }
    250   return n;
    251 }
    252 
    253 }  // namespace
    254 
    255 void ChooseColorMap2Pass(j_decompress_ptr cinfo) {
    256   if (cinfo->out_color_space != JCS_RGB) {
    257     JPEGLI_ERROR("Two-pass quantizer must use RGB output color space.");
    258   }
    259   jpeg_decomp_master* m = cinfo->master;
    260   const size_t num_pixels = cinfo->output_width * cinfo->output_height;
    261   const int max_color_count = std::max<size_t>(num_pixels, 1u << 18);
    262   const int max_palette_size = cinfo->desired_number_of_colors;
    263   std::unique_ptr<uint8_t[]> red(new uint8_t[max_color_count]);
    264   std::unique_ptr<uint8_t[]> green(new uint8_t[max_color_count]);
    265   std::unique_ptr<uint8_t[]> blue(new uint8_t[max_color_count]);
    266   std::vector<int> count(max_color_count, 0);
    267   // number of colors
    268   int n = BuildRGBColorIndex(m->pixels_, num_pixels, count.data(), &red[0],
    269                              &green[0], &blue[0]);
    270 
    271   std::vector<int> dist(n, std::numeric_limits<int>::max());
    272   std::vector<int> cluster(n);
    273   std::vector<bool> in_palette(n, false);
    274   int center[256];
    275   int k = 0;  // palette size
    276   const int count_threshold = (num_pixels * 4) / max_palette_size;
    277   static constexpr int kAveragePixelErrorThreshold = 1;
    278   const int64_t error_threshold = num_pixels * kAveragePixelErrorThreshold;
    279   int64_t error = 0;  // quantization error
    280 
    281   int max_count = 0;
    282   int winner = 0;
    283   for (int i = 0; i < n; ++i) {
    284     if (count[i] > max_count) {
    285       max_count = count[i];
    286       winner = i;
    287     }
    288     if (!in_palette[i] && count[i] > count_threshold) {
    289       AddToRGBPalette(&red[0], &green[0], &blue[0], count.data(), i, k++, n,
    290                       dist.data(), cluster.data(), &center[0], &error);
    291       in_palette[i] = true;
    292     }
    293   }
    294   if (k == 0) {
    295     AddToRGBPalette(&red[0], &green[0], &blue[0], count.data(), winner, k++, n,
    296                     dist.data(), cluster.data(), &center[0], &error);
    297     in_palette[winner] = true;
    298   }
    299 
    300   // Calculation of the multi-resolution density grid.
    301   std::vector<int> density(n * kMaxLevel);
    302   std::vector<int> radius(n * kMaxLevel);
    303   std::unordered_map<uint32_t, int, WangHasher> histogram[kMaxLevel];
    304   for (int level = 0; level < kMaxLevel; ++level) {
    305     // This value is never used because key = InterlaceBitsRGB(...) >> 6
    306   }
    307 
    308   for (int i = 0; i < n; ++i) {
    309     if (!in_palette[i]) {
    310       const int key = InterlaceBitsRGB(red[i], green[i], blue[i]) >> 6;
    311       for (int level = 0; level < kMaxLevel; ++level) {
    312         histogram[level][key >> (3 * level)] += count[i];
    313       }
    314     }
    315   }
    316   for (int i = 0; i < n; ++i) {
    317     if (!in_palette[i]) {
    318       for (int level = 0; level < kMaxLevel; ++level) {
    319         const int mask = (4 << level) - 1;
    320         const int rd = std::max(red[i] & mask, mask - (red[i] & mask));
    321         const int gd = std::max(green[i] & mask, mask - (green[i] & mask));
    322         const int bd = std::max(blue[i] & mask, mask - (blue[i] & mask));
    323         radius[i * kMaxLevel + level] =
    324             ScaleQuadDistanceRGB(ColorIntQuadDistanceRGB(0, 0, 0, rd, gd, bd));
    325       }
    326       const int key = InterlaceBitsRGB(red[i], green[i], blue[i]) >> 6;
    327       if (kMaxLevel > 0) {
    328         density[i * kMaxLevel] = histogram[0][key] - count[i];
    329       }
    330       for (int level = 1; level < kMaxLevel; ++level) {
    331         density[i * kMaxLevel + level] =
    332             (histogram[level][key >> (3 * level)] -
    333              histogram[level - 1][key >> (3 * level - 3)]);
    334       }
    335     }
    336   }
    337 
    338   // Calculate the initial error now that the palette has been initialized.
    339   error = 0;
    340   for (int i = 0; i < n; ++i) {
    341     error += static_cast<int64_t>(dist[i]) * static_cast<int64_t>(count[i]);
    342   }
    343 
    344   std::unique_ptr<std::vector<int>[]> bucket_array(
    345       new std::vector<int>[kMaxPriority]);
    346   int top_priority = -1;
    347   for (int i = 0; i < n; ++i) {
    348     if (!in_palette[i]) {
    349       int priority = Priority(ScaleQuadDistanceRGB(dist[i]), count[i],
    350                               &density[i * kMaxLevel], &radius[i * kMaxLevel]);
    351       bucket_array[priority].push_back(i);
    352       top_priority = std::max(priority, top_priority);
    353     }
    354   }
    355   double error_accum = 0;
    356   while (top_priority >= 0 && k < max_palette_size) {
    357     if (error < error_threshold) {
    358       error_accum += std::min(error_threshold, error_threshold - error);
    359       if (error_accum >= 10 * error_threshold) {
    360         break;
    361       }
    362     }
    363     int i = bucket_array[top_priority].back();
    364     int priority = Priority(ScaleQuadDistanceRGB(dist[i]), count[i],
    365                             &density[i * kMaxLevel], &radius[i * kMaxLevel]);
    366     if (priority < top_priority) {
    367       bucket_array[priority].push_back(i);
    368     } else {
    369       AddToRGBPalette(&red[0], &green[0], &blue[0], count.data(), i, k++, n,
    370                       dist.data(), cluster.data(), &center[0], &error);
    371     }
    372     bucket_array[top_priority].pop_back();
    373     while (top_priority >= 0 && bucket_array[top_priority].empty()) {
    374       --top_priority;
    375     }
    376   }
    377 
    378   cinfo->actual_number_of_colors = k;
    379   cinfo->colormap = (*cinfo->mem->alloc_sarray)(
    380       reinterpret_cast<j_common_ptr>(cinfo), JPOOL_IMAGE, k, 3);
    381   for (int i = 0; i < k; ++i) {
    382     int index = center[i];
    383     cinfo->colormap[0][i] = red[index];
    384     cinfo->colormap[1][i] = green[index];
    385     cinfo->colormap[2][i] = blue[index];
    386   }
    387 }
    388 
    389 namespace {
    390 
    391 void FindCandidatesForCell(j_decompress_ptr cinfo, int ncomp, const int cell[],
    392                            std::vector<uint8_t>* candidates) {
    393   int cell_min[kMaxComponents];
    394   int cell_max[kMaxComponents];
    395   int cell_center[kMaxComponents];
    396   for (int c = 0; c < ncomp; ++c) {
    397     cell_min[c] = cell[c] << (8 - kNumColorCellBits[c]);
    398     cell_max[c] = cell_min[c] + (1 << (8 - kNumColorCellBits[c])) - 1;
    399     cell_center[c] = (cell_min[c] + cell_max[c]) >> 1;
    400   }
    401   int min_maxdist = std::numeric_limits<int>::max();
    402   int mindist[256];
    403   for (int i = 0; i < cinfo->actual_number_of_colors; ++i) {
    404     int dmin = 0;
    405     int dmax = 0;
    406     for (int c = 0; c < ncomp; ++c) {
    407       int palette_c = cinfo->colormap[c][i];
    408       int dminc = 0;
    409       int dmaxc;
    410       if (palette_c < cell_min[c]) {
    411         dminc = cell_min[c] - palette_c;
    412         dmaxc = cell_max[c] - palette_c;
    413       } else if (palette_c > cell_max[c]) {
    414         dminc = palette_c - cell_max[c];
    415         dmaxc = palette_c - cell_min[c];
    416       } else if (palette_c > cell_center[c]) {
    417         dmaxc = palette_c - cell_min[c];
    418       } else {
    419         dmaxc = cell_max[c] - palette_c;
    420       }
    421       dminc *= kCompW[c];
    422       dmaxc *= kCompW[c];
    423       dmin += dminc * dminc;
    424       dmax += dmaxc * dmaxc;
    425     }
    426     mindist[i] = dmin;
    427     min_maxdist = std::min(dmax, min_maxdist);
    428   }
    429   for (int i = 0; i < cinfo->actual_number_of_colors; ++i) {
    430     if (mindist[i] < min_maxdist) {
    431       candidates->push_back(i);
    432     }
    433   }
    434 }
    435 
    436 }  // namespace
    437 
    438 void CreateInverseColorMap(j_decompress_ptr cinfo) {
    439   jpeg_decomp_master* m = cinfo->master;
    440   int ncomp = cinfo->out_color_components;
    441   JXL_ASSERT(ncomp > 0);
    442   JXL_ASSERT(ncomp <= kMaxComponents);
    443   int num_cells = 1;
    444   for (int c = 0; c < ncomp; ++c) {
    445     num_cells *= (1 << kNumColorCellBits[c]);
    446   }
    447   m->candidate_lists_.resize(num_cells);
    448 
    449   int next_cell[kMaxComponents] = {0};
    450   for (int i = 0; i < num_cells; ++i) {
    451     m->candidate_lists_[i].clear();
    452     FindCandidatesForCell(cinfo, ncomp, next_cell, &m->candidate_lists_[i]);
    453     int c = ncomp - 1;
    454     while (c > 0 && next_cell[c] + 1 == (1 << kNumColorCellBits[c])) {
    455       next_cell[c--] = 0;
    456     }
    457     ++next_cell[c];
    458   }
    459   m->regenerate_inverse_colormap_ = false;
    460 }
    461 
    462 int LookupColorIndex(j_decompress_ptr cinfo, const JSAMPLE* pixel) {
    463   jpeg_decomp_master* m = cinfo->master;
    464   int num_channels = cinfo->out_color_components;
    465   int index = 0;
    466   if (m->quant_mode_ == 1) {
    467     for (int c = 0; c < num_channels; ++c) {
    468       index += m->colormap_lut_[c * 256 + pixel[c]];
    469     }
    470   } else {
    471     size_t cell_idx = 0;
    472     size_t stride = 1;
    473     for (int c = num_channels - 1; c >= 0; --c) {
    474       cell_idx += (pixel[c] >> (8 - kNumColorCellBits[c])) * stride;
    475       stride <<= kNumColorCellBits[c];
    476     }
    477     JXL_ASSERT(cell_idx < m->candidate_lists_.size());
    478     int mindist = std::numeric_limits<int>::max();
    479     const auto& candidates = m->candidate_lists_[cell_idx];
    480     for (uint8_t i : candidates) {
    481       int dist = 0;
    482       for (int c = 0; c < num_channels; ++c) {
    483         int d = (cinfo->colormap[c][i] - pixel[c]) * kCompW[c];
    484         dist += d * d;
    485       }
    486       if (dist < mindist) {
    487         mindist = dist;
    488         index = i;
    489       }
    490     }
    491   }
    492   JXL_ASSERT(index < cinfo->actual_number_of_colors);
    493   return index;
    494 }
    495 
    496 void CreateOrderedDitherTables(j_decompress_ptr cinfo) {
    497   jpeg_decomp_master* m = cinfo->master;
    498   static constexpr size_t kDitherSize = 4;
    499   static constexpr size_t kDitherMask = kDitherSize - 1;
    500   static constexpr float kBaseDitherMatrix[] = {
    501       0,  8,  2,  10,  //
    502       12, 4,  14, 6,   //
    503       3,  11, 1,  9,   //
    504       15, 7,  13, 5,   //
    505   };
    506   m->dither_size_ = kDitherSize;
    507   m->dither_mask_ = kDitherMask;
    508   size_t ncells = m->dither_size_ * m->dither_size_;
    509   for (int c = 0; c < cinfo->out_color_components; ++c) {
    510     float spread = 1.0f / (m->num_colors_[c] - 1);
    511     float mul = spread / ncells;
    512     float offset = 0.5f * spread;
    513     if (m->dither_[c] == nullptr) {
    514       m->dither_[c] = Allocate<float>(cinfo, ncells, JPOOL_IMAGE_ALIGNED);
    515     }
    516     for (size_t idx = 0; idx < ncells; ++idx) {
    517       m->dither_[c][idx] = kBaseDitherMatrix[idx] * mul - offset;
    518     }
    519   }
    520 }
    521 
    522 void InitFSDitherState(j_decompress_ptr cinfo) {
    523   jpeg_decomp_master* m = cinfo->master;
    524   for (int c = 0; c < cinfo->out_color_components; ++c) {
    525     if (m->error_row_[c] == nullptr) {
    526       m->error_row_[c] =
    527           Allocate<float>(cinfo, cinfo->output_width, JPOOL_IMAGE_ALIGNED);
    528       m->error_row_[c + kMaxComponents] =
    529           Allocate<float>(cinfo, cinfo->output_width, JPOOL_IMAGE_ALIGNED);
    530     }
    531     memset(m->error_row_[c], 0.0, cinfo->output_width * sizeof(float));
    532     memset(m->error_row_[c + kMaxComponents], 0.0,
    533            cinfo->output_width * sizeof(float));
    534   }
    535 }
    536 
    537 }  // namespace jpegli