libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_cluster.cc (12456B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_cluster.h"
      7 
      8 #include <algorithm>
      9 #include <cmath>
     10 #include <limits>
     11 #include <map>
     12 #include <memory>
     13 #include <numeric>
     14 #include <queue>
     15 #include <tuple>
     16 
     17 #undef HWY_TARGET_INCLUDE
     18 #define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc"
     19 #include <hwy/foreach_target.h>
     20 #include <hwy/highway.h>
     21 
     22 #include "lib/jxl/ac_context.h"
     23 #include "lib/jxl/base/fast_math-inl.h"
     24 #include "lib/jxl/enc_ans.h"
     25 HWY_BEFORE_NAMESPACE();
     26 namespace jxl {
     27 namespace HWY_NAMESPACE {
     28 
     29 // These templates are not found via ADL.
     30 using hwy::HWY_NAMESPACE::Eq;
     31 using hwy::HWY_NAMESPACE::IfThenZeroElse;
     32 
     33 template <class V>
     34 V Entropy(V count, V inv_total, V total) {
     35   const HWY_CAPPED(float, Histogram::kRounding) d;
     36   const auto zero = Set(d, 0.0f);
     37   // TODO(eustas): why (0 - x) instead of Neg(x)?
     38   return IfThenZeroElse(
     39       Eq(count, total),
     40       Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count)))));
     41 }
     42 
     43 void HistogramEntropy(const Histogram& a) {
     44   a.entropy_ = 0.0f;
     45   if (a.total_count_ == 0) return;
     46 
     47   const HWY_CAPPED(float, Histogram::kRounding) df;
     48   const HWY_CAPPED(int32_t, Histogram::kRounding) di;
     49 
     50   const auto inv_tot = Set(df, 1.0f / a.total_count_);
     51   auto entropy_lanes = Zero(df);
     52   auto total = Set(df, a.total_count_);
     53 
     54   for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) {
     55     const auto counts = LoadU(di, &a.data_[i]);
     56     entropy_lanes =
     57         Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total));
     58   }
     59   a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes));
     60 }
     61 
     62 float HistogramDistance(const Histogram& a, const Histogram& b) {
     63   if (a.total_count_ == 0 || b.total_count_ == 0) return 0;
     64 
     65   const HWY_CAPPED(float, Histogram::kRounding) df;
     66   const HWY_CAPPED(int32_t, Histogram::kRounding) di;
     67 
     68   const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_));
     69   auto distance_lanes = Zero(df);
     70   auto total = Set(df, a.total_count_ + b.total_count_);
     71 
     72   for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size());
     73        i += Lanes(di)) {
     74     const auto a_counts =
     75         a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di);
     76     const auto b_counts =
     77         b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di);
     78     const auto counts = ConvertTo(df, Add(a_counts, b_counts));
     79     distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total));
     80   }
     81   const float total_distance = GetLane(SumOfLanes(df, distance_lanes));
     82   return total_distance - a.entropy_ - b.entropy_;
     83 }
     84 
     85 constexpr const float kInfinity = std::numeric_limits<float>::infinity();
     86 
     87 float HistogramKLDivergence(const Histogram& actual, const Histogram& coding) {
     88   if (actual.total_count_ == 0) return 0;
     89   if (coding.total_count_ == 0) return kInfinity;
     90 
     91   const HWY_CAPPED(float, Histogram::kRounding) df;
     92   const HWY_CAPPED(int32_t, Histogram::kRounding) di;
     93 
     94   const auto coding_inv = Set(df, 1.0f / coding.total_count_);
     95   auto cost_lanes = Zero(df);
     96 
     97   for (size_t i = 0; i < actual.data_.size(); i += Lanes(di)) {
     98     const auto counts = LoadU(di, &actual.data_[i]);
     99     const auto coding_counts =
    100         coding.data_.size() > i ? LoadU(di, &coding.data_[i]) : Zero(di);
    101     const auto coding_probs = Mul(ConvertTo(df, coding_counts), coding_inv);
    102     const auto neg_coding_cost = BitCast(
    103         df,
    104         IfThenZeroElse(Eq(counts, Zero(di)),
    105                        IfThenElse(Eq(coding_counts, Zero(di)),
    106                                   BitCast(di, Set(df, -kInfinity)),
    107                                   BitCast(di, FastLog2f(df, coding_probs)))));
    108     cost_lanes = NegMulAdd(ConvertTo(df, counts), neg_coding_cost, cost_lanes);
    109   }
    110   const float total_cost = GetLane(SumOfLanes(df, cost_lanes));
    111   return total_cost - actual.entropy_;
    112 }
    113 
    114 // First step of a k-means clustering with a fancy distance metric.
    115 void FastClusterHistograms(const std::vector<Histogram>& in,
    116                            size_t max_histograms, std::vector<Histogram>* out,
    117                            std::vector<uint32_t>* histogram_symbols) {
    118   const size_t prev_histograms = out->size();
    119   out->reserve(max_histograms);
    120   histogram_symbols->clear();
    121   histogram_symbols->resize(in.size(), max_histograms);
    122 
    123   std::vector<float> dists(in.size(), std::numeric_limits<float>::max());
    124   size_t largest_idx = 0;
    125   for (size_t i = 0; i < in.size(); i++) {
    126     if (in[i].total_count_ == 0) {
    127       (*histogram_symbols)[i] = 0;
    128       dists[i] = 0.0f;
    129       continue;
    130     }
    131     HistogramEntropy(in[i]);
    132     if (in[i].total_count_ > in[largest_idx].total_count_) {
    133       largest_idx = i;
    134     }
    135   }
    136 
    137   if (prev_histograms > 0) {
    138     for (size_t j = 0; j < prev_histograms; ++j) {
    139       HistogramEntropy((*out)[j]);
    140     }
    141     for (size_t i = 0; i < in.size(); i++) {
    142       if (dists[i] == 0.0f) continue;
    143       for (size_t j = 0; j < prev_histograms; ++j) {
    144         dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]);
    145       }
    146     }
    147     auto max_dist = std::max_element(dists.begin(), dists.end());
    148     if (*max_dist > 0.0f) {
    149       largest_idx = max_dist - dists.begin();
    150     }
    151   }
    152 
    153   constexpr float kMinDistanceForDistinct = 48.0f;
    154   while (out->size() < max_histograms) {
    155     (*histogram_symbols)[largest_idx] = out->size();
    156     out->push_back(in[largest_idx]);
    157     dists[largest_idx] = 0.0f;
    158     largest_idx = 0;
    159     for (size_t i = 0; i < in.size(); i++) {
    160       if (dists[i] == 0.0f) continue;
    161       dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]);
    162       if (dists[i] > dists[largest_idx]) largest_idx = i;
    163     }
    164     if (dists[largest_idx] < kMinDistanceForDistinct) break;
    165   }
    166 
    167   for (size_t i = 0; i < in.size(); i++) {
    168     if ((*histogram_symbols)[i] != max_histograms) continue;
    169     size_t best = 0;
    170     float best_dist = std::numeric_limits<float>::max();
    171     for (size_t j = 0; j < out->size(); j++) {
    172       float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j])
    173                                        : HistogramDistance(in[i], (*out)[j]);
    174       if (dist < best_dist) {
    175         best = j;
    176         best_dist = dist;
    177       }
    178     }
    179     JXL_ASSERT(best_dist < std::numeric_limits<float>::max());
    180     if (best >= prev_histograms) {
    181       (*out)[best].AddHistogram(in[i]);
    182       HistogramEntropy((*out)[best]);
    183     }
    184     (*histogram_symbols)[i] = best;
    185   }
    186 }
    187 
    188 // NOLINTNEXTLINE(google-readability-namespace-comments)
    189 }  // namespace HWY_NAMESPACE
    190 }  // namespace jxl
    191 HWY_AFTER_NAMESPACE();
    192 
    193 #if HWY_ONCE
    194 namespace jxl {
    195 HWY_EXPORT(FastClusterHistograms);  // Local function
    196 HWY_EXPORT(HistogramEntropy);       // Local function
    197 
    198 float Histogram::PopulationCost() const {
    199   return ANSPopulationCost(data_.data(), data_.size());
    200 }
    201 
    202 float Histogram::ShannonEntropy() const {
    203   HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this);
    204   return entropy_;
    205 }
    206 
    207 namespace {
    208 // -----------------------------------------------------------------------------
    209 // Histogram refinement
    210 
    211 // Reorder histograms in *out so that the new symbols in *symbols come in
    212 // increasing order.
    213 void HistogramReindex(std::vector<Histogram>* out, size_t prev_histograms,
    214                       std::vector<uint32_t>* symbols) {
    215   std::vector<Histogram> tmp(*out);
    216   std::map<int, int> new_index;
    217   for (size_t i = 0; i < prev_histograms; ++i) {
    218     new_index[i] = i;
    219   }
    220   int next_index = prev_histograms;
    221   for (uint32_t symbol : *symbols) {
    222     if (new_index.find(symbol) == new_index.end()) {
    223       new_index[symbol] = next_index;
    224       (*out)[next_index] = tmp[symbol];
    225       ++next_index;
    226     }
    227   }
    228   out->resize(next_index);
    229   for (uint32_t& symbol : *symbols) {
    230     symbol = new_index[symbol];
    231   }
    232 }
    233 
    234 }  // namespace
    235 
    236 // Clusters similar histograms in 'in' together, the selected histograms are
    237 // placed in 'out', and for each index in 'in', *histogram_symbols will
    238 // indicate which of the 'out' histograms is the best approximation.
    239 void ClusterHistograms(const HistogramParams& params,
    240                        const std::vector<Histogram>& in, size_t max_histograms,
    241                        std::vector<Histogram>* out,
    242                        std::vector<uint32_t>* histogram_symbols) {
    243   size_t prev_histograms = out->size();
    244   max_histograms = std::min(max_histograms, params.max_histograms);
    245   max_histograms = std::min(max_histograms, in.size());
    246   if (params.clustering == HistogramParams::ClusteringType::kFastest) {
    247     max_histograms = std::min(max_histograms, static_cast<size_t>(4));
    248   }
    249 
    250   HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
    251   (in, prev_histograms + max_histograms, out, histogram_symbols);
    252 
    253   if (prev_histograms == 0 &&
    254       params.clustering == HistogramParams::ClusteringType::kBest) {
    255     for (auto& histo : *out) {
    256       histo.entropy_ =
    257           ANSPopulationCost(histo.data_.data(), histo.data_.size());
    258     }
    259     uint32_t next_version = 2;
    260     std::vector<uint32_t> version(out->size(), 1);
    261     std::vector<uint32_t> renumbering(out->size());
    262     std::iota(renumbering.begin(), renumbering.end(), 0);
    263 
    264     // Try to pair up clusters if doing so reduces the total cost.
    265 
    266     struct HistogramPair {
    267       // validity of a pair: p.version == max(version[i], version[j])
    268       float cost;
    269       uint32_t first;
    270       uint32_t second;
    271       uint32_t version;
    272       // We use > because priority queues sort in *decreasing* order, but we
    273       // want lower cost elements to appear first.
    274       bool operator<(const HistogramPair& other) const {
    275         return std::make_tuple(cost, first, second, version) >
    276                std::make_tuple(other.cost, other.first, other.second,
    277                                other.version);
    278       }
    279     };
    280 
    281     // Create list of all pairs by increasing merging cost.
    282     std::priority_queue<HistogramPair> pairs_to_merge;
    283     for (uint32_t i = 0; i < out->size(); i++) {
    284       for (uint32_t j = i + 1; j < out->size(); j++) {
    285         Histogram histo;
    286         histo.AddHistogram((*out)[i]);
    287         histo.AddHistogram((*out)[j]);
    288         float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
    289                      (*out)[i].entropy_ - (*out)[j].entropy_;
    290         // Avoid enqueueing pairs that are not advantageous to merge.
    291         if (cost >= 0) continue;
    292         pairs_to_merge.push(
    293             HistogramPair{cost, i, j, std::max(version[i], version[j])});
    294       }
    295     }
    296 
    297     // Merge the best pair to merge, add new pairs that get formed as a
    298     // consequence.
    299     while (!pairs_to_merge.empty()) {
    300       uint32_t first = pairs_to_merge.top().first;
    301       uint32_t second = pairs_to_merge.top().second;
    302       uint32_t ver = pairs_to_merge.top().version;
    303       pairs_to_merge.pop();
    304       if (ver != std::max(version[first], version[second]) ||
    305           version[first] == 0 || version[second] == 0) {
    306         continue;
    307       }
    308       (*out)[first].AddHistogram((*out)[second]);
    309       (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(),
    310                                                  (*out)[first].data_.size());
    311       for (uint32_t& item : renumbering) {
    312         if (item == second) {
    313           item = first;
    314         }
    315       }
    316       version[second] = 0;
    317       version[first] = next_version++;
    318       for (uint32_t j = 0; j < out->size(); j++) {
    319         if (j == first) continue;
    320         if (version[j] == 0) continue;
    321         Histogram histo;
    322         histo.AddHistogram((*out)[first]);
    323         histo.AddHistogram((*out)[j]);
    324         float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
    325                      (*out)[first].entropy_ - (*out)[j].entropy_;
    326         // Avoid enqueueing pairs that are not advantageous to merge.
    327         if (cost >= 0) continue;
    328         pairs_to_merge.push(
    329             HistogramPair{cost, std::min(first, j), std::max(first, j),
    330                           std::max(version[first], version[j])});
    331       }
    332     }
    333     std::vector<uint32_t> reverse_renumbering(out->size(), -1);
    334     size_t num_alive = 0;
    335     for (size_t i = 0; i < out->size(); i++) {
    336       if (version[i] == 0) continue;
    337       (*out)[num_alive++] = (*out)[i];
    338       reverse_renumbering[i] = num_alive - 1;
    339     }
    340     out->resize(num_alive);
    341     for (uint32_t& item : *histogram_symbols) {
    342       item = reverse_renumbering[renumbering[item]];
    343     }
    344   }
    345 
    346   // Convert the context map to a canonical form.
    347   HistogramReindex(out, prev_histograms, histogram_symbols);
    348 }
    349 
    350 }  // namespace jxl
    351 #endif  // HWY_ONCE