enc_cluster.cc (12456B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_cluster.h" 7 8 #include <algorithm> 9 #include <cmath> 10 #include <limits> 11 #include <map> 12 #include <memory> 13 #include <numeric> 14 #include <queue> 15 #include <tuple> 16 17 #undef HWY_TARGET_INCLUDE 18 #define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc" 19 #include <hwy/foreach_target.h> 20 #include <hwy/highway.h> 21 22 #include "lib/jxl/ac_context.h" 23 #include "lib/jxl/base/fast_math-inl.h" 24 #include "lib/jxl/enc_ans.h" 25 HWY_BEFORE_NAMESPACE(); 26 namespace jxl { 27 namespace HWY_NAMESPACE { 28 29 // These templates are not found via ADL. 30 using hwy::HWY_NAMESPACE::Eq; 31 using hwy::HWY_NAMESPACE::IfThenZeroElse; 32 33 template <class V> 34 V Entropy(V count, V inv_total, V total) { 35 const HWY_CAPPED(float, Histogram::kRounding) d; 36 const auto zero = Set(d, 0.0f); 37 // TODO(eustas): why (0 - x) instead of Neg(x)? 38 return IfThenZeroElse( 39 Eq(count, total), 40 Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count))))); 41 } 42 43 void HistogramEntropy(const Histogram& a) { 44 a.entropy_ = 0.0f; 45 if (a.total_count_ == 0) return; 46 47 const HWY_CAPPED(float, Histogram::kRounding) df; 48 const HWY_CAPPED(int32_t, Histogram::kRounding) di; 49 50 const auto inv_tot = Set(df, 1.0f / a.total_count_); 51 auto entropy_lanes = Zero(df); 52 auto total = Set(df, a.total_count_); 53 54 for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) { 55 const auto counts = LoadU(di, &a.data_[i]); 56 entropy_lanes = 57 Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total)); 58 } 59 a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes)); 60 } 61 62 float HistogramDistance(const Histogram& a, const Histogram& b) { 63 if (a.total_count_ == 0 || b.total_count_ == 0) return 0; 64 65 const HWY_CAPPED(float, Histogram::kRounding) df; 66 const HWY_CAPPED(int32_t, Histogram::kRounding) di; 67 68 const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_)); 69 auto distance_lanes = Zero(df); 70 auto total = Set(df, a.total_count_ + b.total_count_); 71 72 for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size()); 73 i += Lanes(di)) { 74 const auto a_counts = 75 a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di); 76 const auto b_counts = 77 b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di); 78 const auto counts = ConvertTo(df, Add(a_counts, b_counts)); 79 distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total)); 80 } 81 const float total_distance = GetLane(SumOfLanes(df, distance_lanes)); 82 return total_distance - a.entropy_ - b.entropy_; 83 } 84 85 constexpr const float kInfinity = std::numeric_limits<float>::infinity(); 86 87 float HistogramKLDivergence(const Histogram& actual, const Histogram& coding) { 88 if (actual.total_count_ == 0) return 0; 89 if (coding.total_count_ == 0) return kInfinity; 90 91 const HWY_CAPPED(float, Histogram::kRounding) df; 92 const HWY_CAPPED(int32_t, Histogram::kRounding) di; 93 94 const auto coding_inv = Set(df, 1.0f / coding.total_count_); 95 auto cost_lanes = Zero(df); 96 97 for (size_t i = 0; i < actual.data_.size(); i += Lanes(di)) { 98 const auto counts = LoadU(di, &actual.data_[i]); 99 const auto coding_counts = 100 coding.data_.size() > i ? LoadU(di, &coding.data_[i]) : Zero(di); 101 const auto coding_probs = Mul(ConvertTo(df, coding_counts), coding_inv); 102 const auto neg_coding_cost = BitCast( 103 df, 104 IfThenZeroElse(Eq(counts, Zero(di)), 105 IfThenElse(Eq(coding_counts, Zero(di)), 106 BitCast(di, Set(df, -kInfinity)), 107 BitCast(di, FastLog2f(df, coding_probs))))); 108 cost_lanes = NegMulAdd(ConvertTo(df, counts), neg_coding_cost, cost_lanes); 109 } 110 const float total_cost = GetLane(SumOfLanes(df, cost_lanes)); 111 return total_cost - actual.entropy_; 112 } 113 114 // First step of a k-means clustering with a fancy distance metric. 115 void FastClusterHistograms(const std::vector<Histogram>& in, 116 size_t max_histograms, std::vector<Histogram>* out, 117 std::vector<uint32_t>* histogram_symbols) { 118 const size_t prev_histograms = out->size(); 119 out->reserve(max_histograms); 120 histogram_symbols->clear(); 121 histogram_symbols->resize(in.size(), max_histograms); 122 123 std::vector<float> dists(in.size(), std::numeric_limits<float>::max()); 124 size_t largest_idx = 0; 125 for (size_t i = 0; i < in.size(); i++) { 126 if (in[i].total_count_ == 0) { 127 (*histogram_symbols)[i] = 0; 128 dists[i] = 0.0f; 129 continue; 130 } 131 HistogramEntropy(in[i]); 132 if (in[i].total_count_ > in[largest_idx].total_count_) { 133 largest_idx = i; 134 } 135 } 136 137 if (prev_histograms > 0) { 138 for (size_t j = 0; j < prev_histograms; ++j) { 139 HistogramEntropy((*out)[j]); 140 } 141 for (size_t i = 0; i < in.size(); i++) { 142 if (dists[i] == 0.0f) continue; 143 for (size_t j = 0; j < prev_histograms; ++j) { 144 dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]); 145 } 146 } 147 auto max_dist = std::max_element(dists.begin(), dists.end()); 148 if (*max_dist > 0.0f) { 149 largest_idx = max_dist - dists.begin(); 150 } 151 } 152 153 constexpr float kMinDistanceForDistinct = 48.0f; 154 while (out->size() < max_histograms) { 155 (*histogram_symbols)[largest_idx] = out->size(); 156 out->push_back(in[largest_idx]); 157 dists[largest_idx] = 0.0f; 158 largest_idx = 0; 159 for (size_t i = 0; i < in.size(); i++) { 160 if (dists[i] == 0.0f) continue; 161 dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]); 162 if (dists[i] > dists[largest_idx]) largest_idx = i; 163 } 164 if (dists[largest_idx] < kMinDistanceForDistinct) break; 165 } 166 167 for (size_t i = 0; i < in.size(); i++) { 168 if ((*histogram_symbols)[i] != max_histograms) continue; 169 size_t best = 0; 170 float best_dist = std::numeric_limits<float>::max(); 171 for (size_t j = 0; j < out->size(); j++) { 172 float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j]) 173 : HistogramDistance(in[i], (*out)[j]); 174 if (dist < best_dist) { 175 best = j; 176 best_dist = dist; 177 } 178 } 179 JXL_ASSERT(best_dist < std::numeric_limits<float>::max()); 180 if (best >= prev_histograms) { 181 (*out)[best].AddHistogram(in[i]); 182 HistogramEntropy((*out)[best]); 183 } 184 (*histogram_symbols)[i] = best; 185 } 186 } 187 188 // NOLINTNEXTLINE(google-readability-namespace-comments) 189 } // namespace HWY_NAMESPACE 190 } // namespace jxl 191 HWY_AFTER_NAMESPACE(); 192 193 #if HWY_ONCE 194 namespace jxl { 195 HWY_EXPORT(FastClusterHistograms); // Local function 196 HWY_EXPORT(HistogramEntropy); // Local function 197 198 float Histogram::PopulationCost() const { 199 return ANSPopulationCost(data_.data(), data_.size()); 200 } 201 202 float Histogram::ShannonEntropy() const { 203 HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this); 204 return entropy_; 205 } 206 207 namespace { 208 // ----------------------------------------------------------------------------- 209 // Histogram refinement 210 211 // Reorder histograms in *out so that the new symbols in *symbols come in 212 // increasing order. 213 void HistogramReindex(std::vector<Histogram>* out, size_t prev_histograms, 214 std::vector<uint32_t>* symbols) { 215 std::vector<Histogram> tmp(*out); 216 std::map<int, int> new_index; 217 for (size_t i = 0; i < prev_histograms; ++i) { 218 new_index[i] = i; 219 } 220 int next_index = prev_histograms; 221 for (uint32_t symbol : *symbols) { 222 if (new_index.find(symbol) == new_index.end()) { 223 new_index[symbol] = next_index; 224 (*out)[next_index] = tmp[symbol]; 225 ++next_index; 226 } 227 } 228 out->resize(next_index); 229 for (uint32_t& symbol : *symbols) { 230 symbol = new_index[symbol]; 231 } 232 } 233 234 } // namespace 235 236 // Clusters similar histograms in 'in' together, the selected histograms are 237 // placed in 'out', and for each index in 'in', *histogram_symbols will 238 // indicate which of the 'out' histograms is the best approximation. 239 void ClusterHistograms(const HistogramParams& params, 240 const std::vector<Histogram>& in, size_t max_histograms, 241 std::vector<Histogram>* out, 242 std::vector<uint32_t>* histogram_symbols) { 243 size_t prev_histograms = out->size(); 244 max_histograms = std::min(max_histograms, params.max_histograms); 245 max_histograms = std::min(max_histograms, in.size()); 246 if (params.clustering == HistogramParams::ClusteringType::kFastest) { 247 max_histograms = std::min(max_histograms, static_cast<size_t>(4)); 248 } 249 250 HWY_DYNAMIC_DISPATCH(FastClusterHistograms) 251 (in, prev_histograms + max_histograms, out, histogram_symbols); 252 253 if (prev_histograms == 0 && 254 params.clustering == HistogramParams::ClusteringType::kBest) { 255 for (auto& histo : *out) { 256 histo.entropy_ = 257 ANSPopulationCost(histo.data_.data(), histo.data_.size()); 258 } 259 uint32_t next_version = 2; 260 std::vector<uint32_t> version(out->size(), 1); 261 std::vector<uint32_t> renumbering(out->size()); 262 std::iota(renumbering.begin(), renumbering.end(), 0); 263 264 // Try to pair up clusters if doing so reduces the total cost. 265 266 struct HistogramPair { 267 // validity of a pair: p.version == max(version[i], version[j]) 268 float cost; 269 uint32_t first; 270 uint32_t second; 271 uint32_t version; 272 // We use > because priority queues sort in *decreasing* order, but we 273 // want lower cost elements to appear first. 274 bool operator<(const HistogramPair& other) const { 275 return std::make_tuple(cost, first, second, version) > 276 std::make_tuple(other.cost, other.first, other.second, 277 other.version); 278 } 279 }; 280 281 // Create list of all pairs by increasing merging cost. 282 std::priority_queue<HistogramPair> pairs_to_merge; 283 for (uint32_t i = 0; i < out->size(); i++) { 284 for (uint32_t j = i + 1; j < out->size(); j++) { 285 Histogram histo; 286 histo.AddHistogram((*out)[i]); 287 histo.AddHistogram((*out)[j]); 288 float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) - 289 (*out)[i].entropy_ - (*out)[j].entropy_; 290 // Avoid enqueueing pairs that are not advantageous to merge. 291 if (cost >= 0) continue; 292 pairs_to_merge.push( 293 HistogramPair{cost, i, j, std::max(version[i], version[j])}); 294 } 295 } 296 297 // Merge the best pair to merge, add new pairs that get formed as a 298 // consequence. 299 while (!pairs_to_merge.empty()) { 300 uint32_t first = pairs_to_merge.top().first; 301 uint32_t second = pairs_to_merge.top().second; 302 uint32_t ver = pairs_to_merge.top().version; 303 pairs_to_merge.pop(); 304 if (ver != std::max(version[first], version[second]) || 305 version[first] == 0 || version[second] == 0) { 306 continue; 307 } 308 (*out)[first].AddHistogram((*out)[second]); 309 (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(), 310 (*out)[first].data_.size()); 311 for (uint32_t& item : renumbering) { 312 if (item == second) { 313 item = first; 314 } 315 } 316 version[second] = 0; 317 version[first] = next_version++; 318 for (uint32_t j = 0; j < out->size(); j++) { 319 if (j == first) continue; 320 if (version[j] == 0) continue; 321 Histogram histo; 322 histo.AddHistogram((*out)[first]); 323 histo.AddHistogram((*out)[j]); 324 float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) - 325 (*out)[first].entropy_ - (*out)[j].entropy_; 326 // Avoid enqueueing pairs that are not advantageous to merge. 327 if (cost >= 0) continue; 328 pairs_to_merge.push( 329 HistogramPair{cost, std::min(first, j), std::max(first, j), 330 std::max(version[first], version[j])}); 331 } 332 } 333 std::vector<uint32_t> reverse_renumbering(out->size(), -1); 334 size_t num_alive = 0; 335 for (size_t i = 0; i < out->size(); i++) { 336 if (version[i] == 0) continue; 337 (*out)[num_alive++] = (*out)[i]; 338 reverse_renumbering[i] = num_alive - 1; 339 } 340 out->resize(num_alive); 341 for (uint32_t& item : *histogram_symbols) { 342 item = reverse_renumbering[renumbering[item]]; 343 } 344 } 345 346 // Convert the context map to a canonical form. 347 HistogramReindex(out, prev_histograms, histogram_symbols); 348 } 349 350 } // namespace jxl 351 #endif // HWY_ONCE