libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_patch_dictionary.cc (31706B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
#include "lib/jxl/enc_patch_dictionary.h"

#include <jxl/types.h>
#include <stdint.h>
#include <stdlib.h>
#include <sys/types.h>

#include <algorithm>
#include <atomic>
#include <cmath>
#include <utility>
#include <vector>

#include "lib/jxl/base/common.h"
#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/override.h"
#include "lib/jxl/base/printf_macros.h"
#include "lib/jxl/base/random.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/dec_cache.h"
#include "lib/jxl/dec_frame.h"
#include "lib/jxl/enc_ans.h"
#include "lib/jxl/enc_aux_out.h"
#include "lib/jxl/enc_cache.h"
#include "lib/jxl/enc_debug_image.h"
#include "lib/jxl/enc_dot_dictionary.h"
#include "lib/jxl/enc_frame.h"
#include "lib/jxl/frame_header.h"
#include "lib/jxl/image.h"
#include "lib/jxl/image_bundle.h"
#include "lib/jxl/image_ops.h"
#include "lib/jxl/pack_signed.h"
#include "lib/jxl/patch_dictionary_internal.h"
     38 
     39 namespace jxl {
     40 
// Reference-frame index used for the frame that holds the patch pixels: it is
// stored in PatchReferencePosition::ref of every patch produced by
// FindBestPatchDictionary and passed to RoundtripPatchFrame.
static constexpr size_t kPatchFrameReferenceId = 3;
     42 
     43 // static
     44 void PatchDictionaryEncoder::Encode(const PatchDictionary& pdic,
     45                                     BitWriter* writer, size_t layer,
     46                                     AuxOut* aux_out) {
     47   JXL_ASSERT(pdic.HasAny());
     48   std::vector<std::vector<Token>> tokens(1);
     49   size_t num_ec = pdic.shared_->metadata->m.num_extra_channels;
     50 
     51   auto add_num = [&](int context, size_t num) {
     52     tokens[0].emplace_back(context, num);
     53   };
     54   size_t num_ref_patch = 0;
     55   for (size_t i = 0; i < pdic.positions_.size();) {
     56     size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx;
     57     while (i < pdic.positions_.size() &&
     58            pdic.positions_[i].ref_pos_idx == ref_pos_idx) {
     59       i++;
     60     }
     61     num_ref_patch++;
     62   }
     63   add_num(kNumRefPatchContext, num_ref_patch);
     64   size_t blend_pos = 0;
     65   for (size_t i = 0; i < pdic.positions_.size();) {
     66     size_t i_start = i;
     67     size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx;
     68     const auto& ref_pos = pdic.ref_positions_[ref_pos_idx];
     69     while (i < pdic.positions_.size() &&
     70            pdic.positions_[i].ref_pos_idx == ref_pos_idx) {
     71       i++;
     72     }
     73     size_t num = i - i_start;
     74     JXL_ASSERT(num > 0);
     75     add_num(kReferenceFrameContext, ref_pos.ref);
     76     add_num(kPatchReferencePositionContext, ref_pos.x0);
     77     add_num(kPatchReferencePositionContext, ref_pos.y0);
     78     add_num(kPatchSizeContext, ref_pos.xsize - 1);
     79     add_num(kPatchSizeContext, ref_pos.ysize - 1);
     80     add_num(kPatchCountContext, num - 1);
     81     for (size_t j = i_start; j < i; j++) {
     82       const PatchPosition& pos = pdic.positions_[j];
     83       if (j == i_start) {
     84         add_num(kPatchPositionContext, pos.x);
     85         add_num(kPatchPositionContext, pos.y);
     86       } else {
     87         add_num(kPatchOffsetContext,
     88                 PackSigned(pos.x - pdic.positions_[j - 1].x));
     89         add_num(kPatchOffsetContext,
     90                 PackSigned(pos.y - pdic.positions_[j - 1].y));
     91       }
     92       for (size_t j = 0; j < num_ec + 1; ++j, ++blend_pos) {
     93         const PatchBlending& info = pdic.blendings_[blend_pos];
     94         add_num(kPatchBlendModeContext, static_cast<uint32_t>(info.mode));
     95         if (UsesAlpha(info.mode) &&
     96             pdic.shared_->metadata->m.extra_channel_info.size() > 1) {
     97           add_num(kPatchAlphaChannelContext, info.alpha_channel);
     98         }
     99         if (UsesClamp(info.mode)) {
    100           add_num(kPatchClampContext, TO_JXL_BOOL(info.clamp));
    101         }
    102       }
    103     }
    104   }
    105 
    106   EntropyEncodingData codes;
    107   std::vector<uint8_t> context_map;
    108   BuildAndEncodeHistograms(HistogramParams(), kNumPatchDictionaryContexts,
    109                            tokens, &codes, &context_map, writer, layer,
    110                            aux_out);
    111   WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out);
    112 }
    113 
    114 // static
    115 void PatchDictionaryEncoder::SubtractFrom(const PatchDictionary& pdic,
    116                                           Image3F* opsin) {
    117   size_t num_ec = pdic.shared_->metadata->m.num_extra_channels;
    118   // TODO(veluca): this can likely be optimized knowing it runs on full images.
    119   for (size_t y = 0; y < opsin->ysize(); y++) {
    120     float* JXL_RESTRICT rows[3] = {
    121         opsin->PlaneRow(0, y),
    122         opsin->PlaneRow(1, y),
    123         opsin->PlaneRow(2, y),
    124     };
    125     for (size_t pos_idx : pdic.GetPatchesForRow(y)) {
    126       const size_t blending_idx = pos_idx * (num_ec + 1);
    127       const PatchPosition& pos = pdic.positions_[pos_idx];
    128       const PatchReferencePosition& ref_pos =
    129           pdic.ref_positions_[pos.ref_pos_idx];
    130       const PatchBlendMode mode = pdic.blendings_[blending_idx].mode;
    131       size_t by = pos.y;
    132       size_t bx = pos.x;
    133       size_t xsize = ref_pos.xsize;
    134       JXL_DASSERT(y >= by);
    135       JXL_DASSERT(y < by + ref_pos.ysize);
    136       size_t iy = y - by;
    137       size_t ref = ref_pos.ref;
    138       const float* JXL_RESTRICT ref_rows[3] = {
    139           pdic.shared_->reference_frames[ref].frame.color().ConstPlaneRow(
    140               0, ref_pos.y0 + iy) +
    141               ref_pos.x0,
    142           pdic.shared_->reference_frames[ref].frame.color().ConstPlaneRow(
    143               1, ref_pos.y0 + iy) +
    144               ref_pos.x0,
    145           pdic.shared_->reference_frames[ref].frame.color().ConstPlaneRow(
    146               2, ref_pos.y0 + iy) +
    147               ref_pos.x0,
    148       };
    149       for (size_t ix = 0; ix < xsize; ix++) {
    150         for (size_t c = 0; c < 3; c++) {
    151           if (mode == PatchBlendMode::kAdd) {
    152             rows[c][bx + ix] -= ref_rows[c][ix];
    153           } else if (mode == PatchBlendMode::kReplace) {
    154             rows[c][bx + ix] = 0;
    155           } else if (mode == PatchBlendMode::kNone) {
    156             // Nothing to do.
    157           } else {
    158             JXL_UNREACHABLE("Blending mode %u not yet implemented",
    159                             static_cast<uint32_t>(mode));
    160           }
    161         }
    162       }
    163     }
    164   }
    165 }
    166 
    167 namespace {
    168 
    169 struct PatchColorspaceInfo {
    170   float kChannelDequant[3];
    171   float kChannelWeights[3];
    172 
    173   explicit PatchColorspaceInfo(bool is_xyb) {
    174     if (is_xyb) {
    175       kChannelDequant[0] = 0.01615;
    176       kChannelDequant[1] = 0.08875;
    177       kChannelDequant[2] = 0.1922;
    178       kChannelWeights[0] = 30.0;
    179       kChannelWeights[1] = 3.0;
    180       kChannelWeights[2] = 1.0;
    181     } else {
    182       kChannelDequant[0] = 20.0f / 255;
    183       kChannelDequant[1] = 22.0f / 255;
    184       kChannelDequant[2] = 20.0f / 255;
    185       kChannelWeights[0] = 0.017 * 255;
    186       kChannelWeights[1] = 0.02 * 255;
    187       kChannelWeights[2] = 0.017 * 255;
    188     }
    189   }
    190 
    191   float ScaleForQuantization(float val, size_t c) {
    192     return val / kChannelDequant[c];
    193   }
    194 
    195   int Quantize(float val, size_t c) {
    196     return truncf(ScaleForQuantization(val, c));
    197   }
    198 
    199   bool is_similar_v(const float v1[3], const float v2[3], float threshold) {
    200     float distance = 0;
    201     for (size_t c = 0; c < 3; c++) {
    202       distance += std::fabs(v1[c] - v2[c]) * kChannelWeights[c];
    203     }
    204     return distance <= threshold;
    205   }
    206 };
    207 
    208 StatusOr<std::vector<PatchInfo>> FindTextLikePatches(
    209     const CompressParams& cparams, const Image3F& opsin,
    210     const PassesEncoderState* JXL_RESTRICT state, ThreadPool* pool,
    211     AuxOut* aux_out, bool is_xyb) {
    212   std::vector<PatchInfo> info;
    213   if (state->cparams.patches == Override::kOff) return info;
    214   const auto& frame_dim = state->shared.frame_dim;
    215 
    216   PatchColorspaceInfo pci(is_xyb);
    217   float kSimilarThreshold = 0.8f;
    218 
    219   auto is_similar_impl = [&pci](std::pair<uint32_t, uint32_t> p1,
    220                                 std::pair<uint32_t, uint32_t> p2,
    221                                 const float* JXL_RESTRICT rows[3],
    222                                 size_t stride, float threshold) {
    223     float v1[3];
    224     float v2[3];
    225     for (size_t c = 0; c < 3; c++) {
    226       v1[c] = rows[c][p1.second * stride + p1.first];
    227       v2[c] = rows[c][p2.second * stride + p2.first];
    228     }
    229     return pci.is_similar_v(v1, v2, threshold);
    230   };
    231 
    232   std::atomic<bool> has_screenshot_areas{false};
    233   const size_t opsin_stride = opsin.PixelsPerRow();
    234   const float* JXL_RESTRICT opsin_rows[3] = {opsin.ConstPlaneRow(0, 0),
    235                                              opsin.ConstPlaneRow(1, 0),
    236                                              opsin.ConstPlaneRow(2, 0)};
    237 
    238   auto is_same = [&opsin_rows, opsin_stride](std::pair<uint32_t, uint32_t> p1,
    239                                              std::pair<uint32_t, uint32_t> p2) {
    240     for (size_t c = 0; c < 3; c++) {
    241       float v1 = opsin_rows[c][p1.second * opsin_stride + p1.first];
    242       float v2 = opsin_rows[c][p2.second * opsin_stride + p2.first];
    243       if (std::fabs(v1 - v2) > 1e-4) {
    244         return false;
    245       }
    246     }
    247     return true;
    248   };
    249 
    250   auto is_similar = [&](std::pair<uint32_t, uint32_t> p1,
    251                         std::pair<uint32_t, uint32_t> p2) {
    252     return is_similar_impl(p1, p2, opsin_rows, opsin_stride, kSimilarThreshold);
    253   };
    254 
    255   constexpr int64_t kPatchSide = 4;
    256   constexpr int64_t kExtraSide = 4;
    257 
    258   // Look for kPatchSide size squares, naturally aligned, that all have the same
    259   // pixel values.
    260   JXL_ASSIGN_OR_RETURN(ImageB is_screenshot_like,
    261                        ImageB::Create(DivCeil(frame_dim.xsize, kPatchSide),
    262                                       DivCeil(frame_dim.ysize, kPatchSide)));
    263   ZeroFillImage(&is_screenshot_like);
    264   uint8_t* JXL_RESTRICT screenshot_row = is_screenshot_like.Row(0);
    265   const size_t screenshot_stride = is_screenshot_like.PixelsPerRow();
    266   const auto process_row = [&](const uint32_t y, size_t /* thread */) {
    267     for (uint64_t x = 0; x < frame_dim.xsize / kPatchSide; x++) {
    268       bool all_same = true;
    269       for (size_t iy = 0; iy < static_cast<size_t>(kPatchSide); iy++) {
    270         for (size_t ix = 0; ix < static_cast<size_t>(kPatchSide); ix++) {
    271           size_t cx = x * kPatchSide + ix;
    272           size_t cy = y * kPatchSide + iy;
    273           if (!is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) {
    274             all_same = false;
    275             break;
    276           }
    277         }
    278       }
    279       if (!all_same) continue;
    280       size_t num = 0;
    281       size_t num_same = 0;
    282       for (int64_t iy = -kExtraSide; iy < kExtraSide + kPatchSide; iy++) {
    283         for (int64_t ix = -kExtraSide; ix < kExtraSide + kPatchSide; ix++) {
    284           int64_t cx = x * kPatchSide + ix;
    285           int64_t cy = y * kPatchSide + iy;
    286           if (cx < 0 || static_cast<uint64_t>(cx) >= frame_dim.xsize ||  //
    287               cy < 0 || static_cast<uint64_t>(cy) >= frame_dim.ysize) {
    288             continue;
    289           }
    290           num++;
    291           if (is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) num_same++;
    292         }
    293       }
    294       // Too few equal pixels nearby.
    295       if (num_same * 8 < num * 7) continue;
    296       screenshot_row[y * screenshot_stride + x] = 1;
    297       has_screenshot_areas = true;
    298     }
    299   };
    300   JXL_CHECK(RunOnPool(pool, 0, frame_dim.ysize / kPatchSide, ThreadPool::NoInit,
    301                       process_row, "IsScreenshotLike"));
    302 
    303   // TODO(veluca): also parallelize the rest of this function.
    304   if (WantDebugOutput(cparams)) {
    305     JXL_RETURN_IF_ERROR(
    306         DumpPlaneNormalized(cparams, "screenshot_like", is_screenshot_like));
    307   }
    308 
    309   constexpr int kSearchRadius = 1;
    310 
    311   if (!ApplyOverride(state->cparams.patches, has_screenshot_areas)) {
    312     return info;
    313   }
    314 
    315   // Search for "similar enough" pixels near the screenshot-like areas.
    316   JXL_ASSIGN_OR_RETURN(ImageB is_background,
    317                        ImageB::Create(frame_dim.xsize, frame_dim.ysize));
    318   ZeroFillImage(&is_background);
    319   JXL_ASSIGN_OR_RETURN(Image3F background,
    320                        Image3F::Create(frame_dim.xsize, frame_dim.ysize));
    321   ZeroFillImage(&background);
    322   constexpr size_t kDistanceLimit = 50;
    323   float* JXL_RESTRICT background_rows[3] = {
    324       background.PlaneRow(0, 0),
    325       background.PlaneRow(1, 0),
    326       background.PlaneRow(2, 0),
    327   };
    328   const size_t background_stride = background.PixelsPerRow();
    329   uint8_t* JXL_RESTRICT is_background_row = is_background.Row(0);
    330   const size_t is_background_stride = is_background.PixelsPerRow();
    331   std::vector<
    332       std::pair<std::pair<uint32_t, uint32_t>, std::pair<uint32_t, uint32_t>>>
    333       queue;
    334   size_t queue_front = 0;
    335   for (size_t y = 0; y < frame_dim.ysize; y++) {
    336     for (size_t x = 0; x < frame_dim.xsize; x++) {
    337       if (!screenshot_row[screenshot_stride * (y / kPatchSide) +
    338                           (x / kPatchSide)])
    339         continue;
    340       queue.push_back({{x, y}, {x, y}});
    341     }
    342   }
    343   while (queue.size() != queue_front) {
    344     std::pair<uint32_t, uint32_t> cur = queue[queue_front].first;
    345     std::pair<uint32_t, uint32_t> src = queue[queue_front].second;
    346     queue_front++;
    347     if (is_background_row[cur.second * is_background_stride + cur.first])
    348       continue;
    349     is_background_row[cur.second * is_background_stride + cur.first] = 1;
    350     for (size_t c = 0; c < 3; c++) {
    351       background_rows[c][cur.second * background_stride + cur.first] =
    352           opsin_rows[c][src.second * opsin_stride + src.first];
    353     }
    354     for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
    355       for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
    356         if (dx == 0 && dy == 0) continue;
    357         int next_first = cur.first + dx;
    358         int next_second = cur.second + dy;
    359         if (next_first < 0 || next_second < 0 ||
    360             static_cast<uint32_t>(next_first) >= frame_dim.xsize ||
    361             static_cast<uint32_t>(next_second) >= frame_dim.ysize) {
    362           continue;
    363         }
    364         if (static_cast<uint32_t>(
    365                 std::abs(next_first - static_cast<int>(src.first)) +
    366                 std::abs(next_second - static_cast<int>(src.second))) >
    367             kDistanceLimit) {
    368           continue;
    369         }
    370         std::pair<uint32_t, uint32_t> next{next_first, next_second};
    371         if (is_similar(src, next)) {
    372           if (!screenshot_row[next.second / kPatchSide * screenshot_stride +
    373                               next.first / kPatchSide] ||
    374               is_same(src, next)) {
    375             if (!is_background_row[next.second * is_background_stride +
    376                                    next.first])
    377               queue.emplace_back(next, src);
    378           }
    379         }
    380       }
    381     }
    382   }
    383   queue.clear();
    384 
    385   ImageF ccs;
    386   Rng rng(0);
    387   bool paint_ccs = false;
    388   if (WantDebugOutput(cparams)) {
    389     JXL_RETURN_IF_ERROR(
    390         DumpPlaneNormalized(cparams, "is_background", is_background));
    391     if (is_xyb) {
    392       JXL_RETURN_IF_ERROR(DumpXybImage(cparams, "background", background));
    393     } else {
    394       JXL_RETURN_IF_ERROR(DumpImage(cparams, "background", background));
    395     }
    396     JXL_ASSIGN_OR_RETURN(ccs, ImageF::Create(frame_dim.xsize, frame_dim.ysize));
    397     ZeroFillImage(&ccs);
    398     paint_ccs = true;
    399   }
    400 
    401   constexpr float kVerySimilarThreshold = 0.03f;
    402   constexpr float kHasSimilarThreshold = 0.03f;
    403 
    404   const float* JXL_RESTRICT const_background_rows[3] = {
    405       background_rows[0], background_rows[1], background_rows[2]};
    406   auto is_similar_b = [&](std::pair<int, int> p1, std::pair<int, int> p2) {
    407     return is_similar_impl(p1, p2, const_background_rows, background_stride,
    408                            kVerySimilarThreshold);
    409   };
    410 
    411   constexpr int kMinPeak = 2;
    412   constexpr int kHasSimilarRadius = 2;
    413 
    414   // Find small CC outside the "similar enough" areas, compute bounding boxes,
    415   // and run heuristics to exclude some patches.
    416   JXL_ASSIGN_OR_RETURN(ImageB visited,
    417                        ImageB::Create(frame_dim.xsize, frame_dim.ysize));
    418   ZeroFillImage(&visited);
    419   uint8_t* JXL_RESTRICT visited_row = visited.Row(0);
    420   const size_t visited_stride = visited.PixelsPerRow();
    421   std::vector<std::pair<uint32_t, uint32_t>> cc;
    422   std::vector<std::pair<uint32_t, uint32_t>> stack;
    423   for (size_t y = 0; y < frame_dim.ysize; y++) {
    424     for (size_t x = 0; x < frame_dim.xsize; x++) {
    425       if (is_background_row[y * is_background_stride + x]) continue;
    426       cc.clear();
    427       stack.clear();
    428       stack.emplace_back(x, y);
    429       size_t min_x = x;
    430       size_t max_x = x;
    431       size_t min_y = y;
    432       size_t max_y = y;
    433       std::pair<uint32_t, uint32_t> reference;
    434       bool found_border = false;
    435       bool all_similar = true;
    436       while (!stack.empty()) {
    437         std::pair<uint32_t, uint32_t> cur = stack.back();
    438         stack.pop_back();
    439         if (visited_row[cur.second * visited_stride + cur.first]) continue;
    440         visited_row[cur.second * visited_stride + cur.first] = 1;
    441         if (cur.first < min_x) min_x = cur.first;
    442         if (cur.first > max_x) max_x = cur.first;
    443         if (cur.second < min_y) min_y = cur.second;
    444         if (cur.second > max_y) max_y = cur.second;
    445         if (paint_ccs) {
    446           cc.push_back(cur);
    447         }
    448         for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
    449           for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
    450             if (dx == 0 && dy == 0) continue;
    451             int next_first = static_cast<int32_t>(cur.first) + dx;
    452             int next_second = static_cast<int32_t>(cur.second) + dy;
    453             if (next_first < 0 || next_second < 0 ||
    454                 static_cast<uint32_t>(next_first) >= frame_dim.xsize ||
    455                 static_cast<uint32_t>(next_second) >= frame_dim.ysize) {
    456               continue;
    457             }
    458             std::pair<uint32_t, uint32_t> next{next_first, next_second};
    459             if (!is_background_row[next.second * is_background_stride +
    460                                    next.first]) {
    461               stack.push_back(next);
    462             } else {
    463               if (!found_border) {
    464                 reference = next;
    465                 found_border = true;
    466               } else {
    467                 if (!is_similar_b(next, reference)) all_similar = false;
    468               }
    469             }
    470           }
    471         }
    472       }
    473       if (!found_border || !all_similar || max_x - min_x >= kMaxPatchSize ||
    474           max_y - min_y >= kMaxPatchSize) {
    475         continue;
    476       }
    477       size_t bpos = background_stride * reference.second + reference.first;
    478       float ref[3] = {background_rows[0][bpos], background_rows[1][bpos],
    479                       background_rows[2][bpos]};
    480       bool has_similar = false;
    481       for (size_t iy = std::max<int>(
    482                static_cast<int32_t>(min_y) - kHasSimilarRadius, 0);
    483            iy < std::min(max_y + kHasSimilarRadius + 1, frame_dim.ysize);
    484            iy++) {
    485         for (size_t ix = std::max<int>(
    486                  static_cast<int32_t>(min_x) - kHasSimilarRadius, 0);
    487              ix < std::min(max_x + kHasSimilarRadius + 1, frame_dim.xsize);
    488              ix++) {
    489           size_t opos = opsin_stride * iy + ix;
    490           float px[3] = {opsin_rows[0][opos], opsin_rows[1][opos],
    491                          opsin_rows[2][opos]};
    492           if (pci.is_similar_v(ref, px, kHasSimilarThreshold)) {
    493             has_similar = true;
    494           }
    495         }
    496       }
    497       if (!has_similar) continue;
    498       info.emplace_back();
    499       info.back().second.emplace_back(min_x, min_y);
    500       QuantizedPatch& patch = info.back().first;
    501       patch.xsize = max_x - min_x + 1;
    502       patch.ysize = max_y - min_y + 1;
    503       int max_value = 0;
    504       for (size_t c : {1, 0, 2}) {
    505         for (size_t iy = min_y; iy <= max_y; iy++) {
    506           for (size_t ix = min_x; ix <= max_x; ix++) {
    507             size_t offset = (iy - min_y) * patch.xsize + ix - min_x;
    508             patch.fpixels[c][offset] =
    509                 opsin_rows[c][iy * opsin_stride + ix] - ref[c];
    510             int val = pci.Quantize(patch.fpixels[c][offset], c);
    511             patch.pixels[c][offset] = val;
    512             if (std::abs(val) > max_value) max_value = std::abs(val);
    513           }
    514         }
    515       }
    516       if (max_value < kMinPeak) {
    517         info.pop_back();
    518         continue;
    519       }
    520       if (paint_ccs) {
    521         float cc_color = rng.UniformF(0.5, 1.0);
    522         for (std::pair<uint32_t, uint32_t> p : cc) {
    523           ccs.Row(p.second)[p.first] = cc_color;
    524         }
    525       }
    526     }
    527   }
    528 
    529   if (paint_ccs) {
    530     JXL_ASSERT(WantDebugOutput(cparams));
    531     JXL_RETURN_IF_ERROR(DumpPlaneNormalized(cparams, "ccs", ccs));
    532   }
    533   if (info.empty()) {
    534     return info;
    535   }
    536 
    537   // Remove duplicates.
    538   constexpr size_t kMinPatchOccurrences = 2;
    539   std::sort(info.begin(), info.end());
    540   size_t unique = 0;
    541   for (size_t i = 1; i < info.size(); i++) {
    542     if (info[i].first == info[unique].first) {
    543       info[unique].second.insert(info[unique].second.end(),
    544                                  info[i].second.begin(), info[i].second.end());
    545     } else {
    546       if (info[unique].second.size() >= kMinPatchOccurrences) {
    547         unique++;
    548       }
    549       info[unique] = info[i];
    550     }
    551   }
    552   if (info[unique].second.size() >= kMinPatchOccurrences) {
    553     unique++;
    554   }
    555   info.resize(unique);
    556 
    557   size_t max_patch_size = 0;
    558 
    559   for (size_t i = 0; i < info.size(); i++) {
    560     size_t pixels = info[i].first.xsize * info[i].first.ysize;
    561     if (pixels > max_patch_size) max_patch_size = pixels;
    562   }
    563 
    564   // don't use patches if all patches are smaller than this
    565   constexpr size_t kMinMaxPatchSize = 20;
    566   if (max_patch_size < kMinMaxPatchSize) {
    567     info.clear();
    568   }
    569 
    570   return info;
    571 }
    572 
    573 }  // namespace
    574 
    575 Status FindBestPatchDictionary(const Image3F& opsin,
    576                                PassesEncoderState* JXL_RESTRICT state,
    577                                const JxlCmsInterface& cms, ThreadPool* pool,
    578                                AuxOut* aux_out, bool is_xyb) {
    579   JXL_ASSIGN_OR_RETURN(
    580       std::vector<PatchInfo> info,
    581       FindTextLikePatches(state->cparams, opsin, state, pool, aux_out, is_xyb));
    582 
    583   // TODO(veluca): this doesn't work if both dots and patches are enabled.
    584   // For now, since dots and patches are not likely to occur in the same kind of
    585   // images, disable dots if some patches were found.
    586   if (info.empty() &&
    587       ApplyOverride(
    588           state->cparams.dots,
    589           state->cparams.speed_tier <= SpeedTier::kSquirrel &&
    590               state->cparams.butteraugli_distance >= kMinButteraugliForDots)) {
    591     Rect rect(0, 0, state->shared.frame_dim.xsize,
    592               state->shared.frame_dim.ysize);
    593     JXL_ASSIGN_OR_RETURN(info, FindDotDictionary(state->cparams, opsin, rect,
    594                                                  state->shared.cmap, pool));
    595   }
    596 
    597   if (info.empty()) return true;
    598 
    599   std::sort(
    600       info.begin(), info.end(), [&](const PatchInfo& a, const PatchInfo& b) {
    601         return a.first.xsize * a.first.ysize > b.first.xsize * b.first.ysize;
    602       });
    603 
    604   size_t max_x_size = 0;
    605   size_t max_y_size = 0;
    606   size_t total_pixels = 0;
    607 
    608   for (size_t i = 0; i < info.size(); i++) {
    609     size_t pixels = info[i].first.xsize * info[i].first.ysize;
    610     if (max_x_size < info[i].first.xsize) max_x_size = info[i].first.xsize;
    611     if (max_y_size < info[i].first.ysize) max_y_size = info[i].first.ysize;
    612     total_pixels += pixels;
    613   }
    614 
    615   // Bin-packing & conversion of patches.
    616   constexpr float kBinPackingSlackness = 1.05f;
    617   size_t ref_xsize = std::max<float>(max_x_size, std::sqrt(total_pixels));
    618   size_t ref_ysize = std::max<float>(max_y_size, std::sqrt(total_pixels));
    619   std::vector<std::pair<size_t, size_t>> ref_positions(info.size());
    620   // TODO(veluca): allow partial overlaps of patches that have the same pixels.
    621   size_t max_y = 0;
    622   do {
    623     max_y = 0;
    624     // Increase packed image size.
    625     ref_xsize = ref_xsize * kBinPackingSlackness + 1;
    626     ref_ysize = ref_ysize * kBinPackingSlackness + 1;
    627 
    628     JXL_ASSIGN_OR_RETURN(ImageB occupied, ImageB::Create(ref_xsize, ref_ysize));
    629     ZeroFillImage(&occupied);
    630     uint8_t* JXL_RESTRICT occupied_rows = occupied.Row(0);
    631     size_t occupied_stride = occupied.PixelsPerRow();
    632 
    633     bool success = true;
    634     // For every patch...
    635     for (size_t patch = 0; patch < info.size(); patch++) {
    636       size_t x0 = 0;
    637       size_t y0 = 0;
    638       size_t xsize = info[patch].first.xsize;
    639       size_t ysize = info[patch].first.ysize;
    640       bool found = false;
    641       // For every possible start position ...
    642       for (; y0 + ysize <= ref_ysize; y0++) {
    643         x0 = 0;
    644         for (; x0 + xsize <= ref_xsize; x0++) {
    645           bool has_occupied_pixel = false;
    646           size_t x = x0;
    647           // Check if it is possible to place the patch in this position in the
    648           // reference frame.
    649           for (size_t y = y0; y < y0 + ysize; y++) {
    650             x = x0;
    651             for (; x < x0 + xsize; x++) {
    652               if (occupied_rows[y * occupied_stride + x]) {
    653                 has_occupied_pixel = true;
    654                 break;
    655               }
    656             }
    657           }  // end of positioning check
    658           if (!has_occupied_pixel) {
    659             found = true;
    660             break;
    661           }
    662           x0 = x;  // Jump to next pixel after the occupied one.
    663         }
    664         if (found) break;
    665       }  // end of start position checking
    666 
    667       // We didn't find a possible position: repeat from the beginning with a
    668       // larger reference frame size.
    669       if (!found) {
    670         success = false;
    671         break;
    672       }
    673 
    674       // We found a position: mark the corresponding positions in the reference
    675       // image as used.
    676       ref_positions[patch] = {x0, y0};
    677       for (size_t y = y0; y < y0 + ysize; y++) {
    678         for (size_t x = x0; x < x0 + xsize; x++) {
    679           occupied_rows[y * occupied_stride + x] = JXL_TRUE;
    680         }
    681       }
    682       max_y = std::max(max_y, y0 + ysize);
    683     }
    684 
    685     if (success) break;
    686   } while (true);
    687 
    688   JXL_ASSERT(ref_ysize >= max_y);
    689 
    690   ref_ysize = max_y;
    691 
    692   JXL_ASSIGN_OR_RETURN(Image3F reference_frame,
    693                        Image3F::Create(ref_xsize, ref_ysize));
    694   // TODO(veluca): figure out a better way to fill the image.
    695   ZeroFillImage(&reference_frame);
    696   std::vector<PatchPosition> positions;
    697   std::vector<PatchReferencePosition> pref_positions;
    698   std::vector<PatchBlending> blendings;
    699   float* JXL_RESTRICT ref_rows[3] = {
    700       reference_frame.PlaneRow(0, 0),
    701       reference_frame.PlaneRow(1, 0),
    702       reference_frame.PlaneRow(2, 0),
    703   };
    704   size_t ref_stride = reference_frame.PixelsPerRow();
    705   size_t num_ec = state->shared.metadata->m.num_extra_channels;
    706 
    707   for (size_t i = 0; i < info.size(); i++) {
    708     PatchReferencePosition ref_pos;
    709     ref_pos.xsize = info[i].first.xsize;
    710     ref_pos.ysize = info[i].first.ysize;
    711     ref_pos.x0 = ref_positions[i].first;
    712     ref_pos.y0 = ref_positions[i].second;
    713     ref_pos.ref = kPatchFrameReferenceId;
    714     for (size_t y = 0; y < ref_pos.ysize; y++) {
    715       for (size_t x = 0; x < ref_pos.xsize; x++) {
    716         for (size_t c = 0; c < 3; c++) {
    717           ref_rows[c][(y + ref_pos.y0) * ref_stride + x + ref_pos.x0] =
    718               info[i].first.fpixels[c][y * ref_pos.xsize + x];
    719         }
    720       }
    721     }
    722     for (const auto& pos : info[i].second) {
    723       JXL_DEBUG_V(4, "Patch %" PRIuS "x%" PRIuS " at position %u,%u",
    724                   ref_pos.xsize, ref_pos.ysize, pos.first, pos.second);
    725       positions.emplace_back(
    726           PatchPosition{pos.first, pos.second, pref_positions.size()});
    727       // Add blending for color channels, ignore other channels.
    728       blendings.push_back({PatchBlendMode::kAdd, 0, false});
    729       for (size_t j = 0; j < num_ec; ++j) {
    730         blendings.push_back({PatchBlendMode::kNone, 0, false});
    731       }
    732     }
    733     pref_positions.emplace_back(ref_pos);
    734   }
    735 
    736   CompressParams cparams = state->cparams;
    737   // Recursive application of patches could create very weird issues.
    738   cparams.patches = Override::kOff;
    739 
    740   JXL_RETURN_IF_ERROR(RoundtripPatchFrame(&reference_frame, state,
    741                                           kPatchFrameReferenceId, cparams, cms,
    742                                           pool, aux_out, /*subtract=*/true));
    743 
    744   // TODO(veluca): this assumes that applying patches is commutative, which is
    745   // not true for all blending modes. This code only produces kAdd patches, so
    746   // this works out.
    747   PatchDictionaryEncoder::SetPositions(
    748       &state->shared.image_features.patches, std::move(positions),
    749       std::move(pref_positions), std::move(blendings));
    750   return true;
    751 }
    752 
    753 Status RoundtripPatchFrame(Image3F* reference_frame,
    754                            PassesEncoderState* JXL_RESTRICT state, int idx,
    755                            CompressParams& cparams, const JxlCmsInterface& cms,
    756                            ThreadPool* pool, AuxOut* aux_out, bool subtract) {
    757   FrameInfo patch_frame_info;
    758   cparams.resampling = 1;
    759   cparams.ec_resampling = 1;
    760   cparams.dots = Override::kOff;
    761   cparams.noise = Override::kOff;
    762   cparams.modular_mode = true;
    763   cparams.responsive = 0;
    764   cparams.progressive_dc = 0;
    765   cparams.progressive_mode = Override::kOff;
    766   cparams.qprogressive_mode = Override::kOff;
    767   // Use gradient predictor and not Predictor::Best.
    768   cparams.options.predictor = Predictor::Gradient;
    769   patch_frame_info.save_as_reference = idx;  // always saved.
    770   patch_frame_info.frame_type = FrameType::kReferenceOnly;
    771   patch_frame_info.save_before_color_transform = true;
    772   ImageBundle ib(&state->shared.metadata->m);
    773   // TODO(veluca): metadata.color_encoding is a lie: ib is in XYB, but there is
    774   // no simple way to express that yet.
    775   patch_frame_info.ib_needs_color_transform = false;
    776   ib.SetFromImage(std::move(*reference_frame),
    777                   state->shared.metadata->m.color_encoding);
    778   if (!ib.metadata()->extra_channel_info.empty()) {
    779     // Add placeholder extra channels to the patch image: patch encoding does
    780     // not yet support extra channels, but the codec expects that the amount of
    781     // extra channels in frames matches that in the metadata of the codestream.
    782     std::vector<ImageF> extra_channels;
    783     extra_channels.reserve(ib.metadata()->extra_channel_info.size());
    784     for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) {
    785       JXL_ASSIGN_OR_RETURN(ImageF ch, ImageF::Create(ib.xsize(), ib.ysize()));
    786       extra_channels.emplace_back(std::move(ch));
    787       // Must initialize the image with data to not affect blending with
    788       // uninitialized memory.
    789       // TODO(lode): patches must copy and use the real extra channels instead.
    790       ZeroFillImage(&extra_channels.back());
    791     }
    792     ib.SetExtraChannels(std::move(extra_channels));
    793   }
    794   auto special_frame = std::unique_ptr<BitWriter>(new BitWriter());
    795   AuxOut patch_aux_out;
    796   JXL_CHECK(EncodeFrame(cparams, patch_frame_info, state->shared.metadata, ib,
    797                         cms, pool, special_frame.get(),
    798                         aux_out ? &patch_aux_out : nullptr));
    799   if (aux_out) {
    800     for (const auto& l : patch_aux_out.layers) {
    801       aux_out->layers[kLayerDictionary].Assimilate(l);
    802     }
    803   }
    804   const Span<const uint8_t> encoded = special_frame->GetSpan();
    805   state->special_frames.emplace_back(std::move(special_frame));
    806   if (subtract) {
    807     ImageBundle decoded(&state->shared.metadata->m);
    808     PassesDecoderState dec_state;
    809     JXL_CHECK(dec_state.output_encoding_info.SetFromMetadata(
    810         *state->shared.metadata));
    811     const uint8_t* frame_start = encoded.data();
    812     size_t encoded_size = encoded.size();
    813     JXL_CHECK(DecodeFrame(&dec_state, pool, frame_start, encoded_size,
    814                           /*frame_header=*/nullptr, &decoded,
    815                           *state->shared.metadata));
    816     frame_start += decoded.decoded_bytes();
    817     encoded_size -= decoded.decoded_bytes();
    818     size_t ref_xsize =
    819         dec_state.shared_storage.reference_frames[idx].frame.color()->xsize();
    820     // if the frame itself uses patches, we need to decode another frame
    821     if (!ref_xsize) {
    822       JXL_CHECK(DecodeFrame(&dec_state, pool, frame_start, encoded_size,
    823                             /*frame_header=*/nullptr, &decoded,
    824                             *state->shared.metadata));
    825     }
    826     JXL_CHECK(encoded_size == 0);
    827     state->shared.reference_frames[idx] =
    828         std::move(dec_state.shared_storage.reference_frames[idx]);
    829   } else {
    830     state->shared.reference_frames[idx].frame = std::move(ib);
    831   }
    832   return true;
    833 }
    834 
    835 }  // namespace jxl