libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

dec_patch_dictionary.cc (13036B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_patch_dictionary.h"
      7 
      8 #include <stdint.h>
      9 #include <stdlib.h>
     10 #include <sys/types.h>
     11 
     12 #include <algorithm>
     13 #include <utility>
     14 #include <vector>
     15 
     16 #include "lib/jxl/base/printf_macros.h"
     17 #include "lib/jxl/base/status.h"
     18 #include "lib/jxl/blending.h"
     19 #include "lib/jxl/common.h"  // kMaxNumReferenceFrames
     20 #include "lib/jxl/dec_ans.h"
     21 #include "lib/jxl/image.h"
     22 #include "lib/jxl/image_bundle.h"
     23 #include "lib/jxl/pack_signed.h"
     24 #include "lib/jxl/patch_dictionary_internal.h"
     25 
     26 namespace jxl {
     27 
     28 Status PatchDictionary::Decode(BitReader* br, size_t xsize, size_t ysize,
     29                                bool* uses_extra_channels) {
     30   positions_.clear();
     31   std::vector<uint8_t> context_map;
     32   ANSCode code;
     33   JXL_RETURN_IF_ERROR(
     34       DecodeHistograms(br, kNumPatchDictionaryContexts, &code, &context_map));
     35   ANSSymbolReader decoder(&code, br);
     36 
     37   auto read_num = [&](size_t context) {
     38     size_t r = decoder.ReadHybridUint(context, br, context_map);
     39     return r;
     40   };
     41 
     42   size_t num_ref_patch = read_num(kNumRefPatchContext);
     43   // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8
     44   // bytes per size_t)
     45   const size_t num_pixels = xsize * ysize;
     46   const size_t max_ref_patches = 1024 + num_pixels / 4;
     47   const size_t max_patches = max_ref_patches * 4;
     48   const size_t max_blending_infos = max_patches * 4;
     49   if (num_ref_patch > max_ref_patches) {
     50     return JXL_FAILURE("Too many patches in dictionary");
     51   }
     52   size_t num_ec = shared_->metadata->m.num_extra_channels;
     53 
     54   size_t total_patches = 0;
     55   size_t next_size = 1;
     56 
     57   for (size_t id = 0; id < num_ref_patch; id++) {
     58     PatchReferencePosition ref_pos;
     59     ref_pos.ref = read_num(kReferenceFrameContext);
     60     if (ref_pos.ref >= kMaxNumReferenceFrames ||
     61         shared_->reference_frames[ref_pos.ref].frame.xsize() == 0) {
     62       return JXL_FAILURE("Invalid reference frame ID");
     63     }
     64     if (!shared_->reference_frames[ref_pos.ref].ib_is_in_xyb) {
     65       return JXL_FAILURE(
     66           "Patches cannot use frames saved post color transforms");
     67     }
     68     const ImageBundle& ib = shared_->reference_frames[ref_pos.ref].frame;
     69     ref_pos.x0 = read_num(kPatchReferencePositionContext);
     70     ref_pos.y0 = read_num(kPatchReferencePositionContext);
     71     ref_pos.xsize = read_num(kPatchSizeContext) + 1;
     72     ref_pos.ysize = read_num(kPatchSizeContext) + 1;
     73     if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) {
     74       return JXL_FAILURE("Invalid position specified in reference frame");
     75     }
     76     if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) {
     77       return JXL_FAILURE("Invalid position specified in reference frame");
     78     }
     79     size_t id_count = read_num(kPatchCountContext);
     80     if (id_count > max_patches) {
     81       return JXL_FAILURE("Too many patches in dictionary");
     82     }
     83     id_count++;
     84     total_patches += id_count;
     85     if (total_patches > max_patches) {
     86       return JXL_FAILURE("Too many patches in dictionary");
     87     }
     88     if (next_size < total_patches) {
     89       next_size *= 2;
     90       next_size = std::min<size_t>(next_size, max_patches);
     91     }
     92     if (next_size * (num_ec + 1) > max_blending_infos) {
     93       return JXL_FAILURE("Too many patches in dictionary");
     94     }
     95     positions_.reserve(next_size);
     96     blendings_.reserve(next_size * (num_ec + 1));
     97     for (size_t i = 0; i < id_count; i++) {
     98       PatchPosition pos;
     99       pos.ref_pos_idx = ref_positions_.size();
    100       if (i == 0) {
    101         pos.x = read_num(kPatchPositionContext);
    102         pos.y = read_num(kPatchPositionContext);
    103       } else {
    104         ssize_t deltax = UnpackSigned(read_num(kPatchOffsetContext));
    105         if (deltax < 0 && static_cast<size_t>(-deltax) > positions_.back().x) {
    106           return JXL_FAILURE("Invalid patch: negative x coordinate (%" PRIuS
    107                              " base x %" PRIdS " delta x)",
    108                              positions_.back().x, deltax);
    109         }
    110         pos.x = positions_.back().x + deltax;
    111         ssize_t deltay = UnpackSigned(read_num(kPatchOffsetContext));
    112         if (deltay < 0 && static_cast<size_t>(-deltay) > positions_.back().y) {
    113           return JXL_FAILURE("Invalid patch: negative y coordinate (%" PRIuS
    114                              " base y %" PRIdS " delta y)",
    115                              positions_.back().y, deltay);
    116         }
    117         pos.y = positions_.back().y + deltay;
    118       }
    119       if (pos.x + ref_pos.xsize > xsize) {
    120         return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS
    121                            " > %" PRIuS,
    122                            pos.x, ref_pos.xsize, xsize);
    123       }
    124       if (pos.y + ref_pos.ysize > ysize) {
    125         return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS
    126                            " > %" PRIuS,
    127                            pos.y, ref_pos.ysize, ysize);
    128       }
    129       for (size_t j = 0; j < num_ec + 1; j++) {
    130         uint32_t blend_mode = read_num(kPatchBlendModeContext);
    131         if (blend_mode >=
    132             static_cast<uint32_t>(PatchBlendMode::kNumBlendModes)) {
    133           return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode);
    134         }
    135         PatchBlending info;
    136         info.mode = static_cast<PatchBlendMode>(blend_mode);
    137         if (UsesAlpha(info.mode)) {
    138           *uses_extra_channels = true;
    139         }
    140         if (info.mode != PatchBlendMode::kNone && j > 0) {
    141           *uses_extra_channels = true;
    142         }
    143         if (UsesAlpha(info.mode) &&
    144             shared_->metadata->m.extra_channel_info.size() > 1) {
    145           info.alpha_channel = read_num(kPatchAlphaChannelContext);
    146           if (info.alpha_channel >=
    147               shared_->metadata->m.extra_channel_info.size()) {
    148             return JXL_FAILURE(
    149                 "Invalid alpha channel for blending: %u out of %u\n",
    150                 info.alpha_channel,
    151                 static_cast<uint32_t>(
    152                     shared_->metadata->m.extra_channel_info.size()));
    153           }
    154         } else {
    155           info.alpha_channel = 0;
    156         }
    157         if (UsesClamp(info.mode)) {
    158           info.clamp = static_cast<bool>(read_num(kPatchClampContext));
    159         } else {
    160           info.clamp = false;
    161         }
    162         blendings_.push_back(info);
    163       }
    164       positions_.emplace_back(pos);
    165     }
    166     ref_positions_.emplace_back(ref_pos);
    167   }
    168   positions_.shrink_to_fit();
    169 
    170   if (!decoder.CheckANSFinalState()) {
    171     return JXL_FAILURE("ANS checksum failure.");
    172   }
    173 
    174   ComputePatchTree();
    175   return true;
    176 }
    177 
    178 int PatchDictionary::GetReferences() const {
    179   int result = 0;
    180   for (const auto& ref_pos : ref_positions_) {
    181     result |= (1 << static_cast<int>(ref_pos.ref));
    182   }
    183   return result;
    184 }
    185 
    namespace {
    // Range of rows covered by one patch, as used by ComputePatchTree().
    struct PatchInterval {
      // Index of the patch in positions_.
      size_t idx;
      // First row covered by the patch.
      size_t y0;
      // One past the last row covered by the patch.
      size_t y1;
    };
    }  // namespace
    192 
    193 void PatchDictionary::ComputePatchTree() {
          // Rebuilds the structures GetPatchesForRow() relies on from positions_
          // and ref_positions_:
          //   * num_patches_[y]: how many patches cover row y;
          //   * patch_tree_: a tree of nodes, each holding the patches whose row
          //     range contains the node's y_center;
          //   * sorted_patches_y0_ / sorted_patches_y1_: for every node, its
          //     patches as (y0, patch index) in increasing y0 order and as
          //     (y1, patch index) in decreasing y1 order. Both lists use the same
          //     node.start / node.num slice.
          // With no positions, all four containers are simply left empty.
    194   patch_tree_.clear();
    195   num_patches_.clear();
    196   sorted_patches_y0_.clear();
    197   sorted_patches_y1_.clear();
    198   if (positions_.empty()) {
    199     return;
    200   }
    201   // Create a y-interval for each patch.
          // The interval is half-open: y0 is the first covered row, y1 is one past
          // the last covered row.
    202   std::vector<PatchInterval> intervals(positions_.size());
    203   for (size_t i = 0; i < positions_.size(); ++i) {
    204     const auto& pos = positions_[i];
    205     intervals[i].idx = i;
    206     intervals[i].y0 = pos.y;
    207     intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize;
    208   }
          // Helpers sorting the sub-range [start, end) of `intervals` in place.
    209   auto sort_by_y0 = [&intervals](size_t start, size_t end) {
    210     std::sort(intervals.data() + start, intervals.data() + end,
    211               [](const PatchInterval& i0, const PatchInterval& i1) {
    212                 return i0.y0 < i1.y0;
    213               });
    214   };
    215   auto sort_by_y1 = [&intervals](size_t start, size_t end) {
    216     std::sort(intervals.data() + start, intervals.data() + end,
    217               [](const PatchInterval& i0, const PatchInterval& i1) {
    218                 return i0.y1 < i1.y1;
    219               });
    220   };
    221   // Count the number of patches for each row.
          // After sorting by y1 the last interval has the largest y1, which is the
          // number of rows num_patches_ needs to cover.
    222   sort_by_y1(0, intervals.size());
    223   num_patches_.resize(intervals.back().y1);
    224   for (auto iv : intervals) {
    225     for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++;
    226   }
          // The root node initially owns every interval. Nodes are processed in the
          // order they were appended to patch_tree_; children go to the back.
    227   PatchTreeNode root;
    228   root.start = 0;
    229   root.num = intervals.size();
    230   patch_tree_.push_back(root);
    231   size_t next = 0;
    232   while (next < patch_tree_.size()) {
            // `node` is only used before the push_back calls at the end of this
            // iteration; after those the code indexes patch_tree_[next] again.
            // NOTE(review): presumably because patch_tree_ is a std::vector, whose
            // push_back may reallocate and invalidate this reference -- confirm
            // against the header.
    233     auto& node = patch_tree_[next];
            // On entry node.start / node.num describe a range of `intervals`.
    234     size_t start = node.start;
    235     size_t end = node.start + node.num;
    236     // Choose the y_center for this node to be the median of interval starts.
    237     sort_by_y0(start, end);
    238     size_t middle_idx = start + node.num / 2;
    239     node.y_center = intervals[middle_idx].y0;
    240     // Divide the intervals in [start, end) into three groups:
    241     //   * those completely to the right of y_center: [right_start, end)
    242     //   * those overlapping y_center: [left_end, right_start)
    243     //   * those completely to the left of y_center: [start, left_end)
            // Skip past every interval that starts exactly at y_center; all of
            // [start, right_start) then has y0 <= y_center.
    244     size_t right_start = middle_idx;
    245     while (right_start < end && intervals[right_start].y0 == node.y_center) {
    246       ++right_start;
    247     }
            // Among those, the intervals ending after y_center form the tail of
            // the range once it is sorted by y1.
    248     sort_by_y1(start, right_start);
    249     size_t left_end = right_start;
    250     while (left_end > start && intervals[left_end - 1].y1 > node.y_center) {
    251       --left_end;
    252     }
    253     // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node.
            // From here on node.start / node.num describe the node's slice of the
            // two sorted lists instead of a range of `intervals`.
    254     node.num = right_start - left_end;
    255     node.start = sorted_patches_y0_.size();
            // Walk backwards so the entries are stored by decreasing y1.
    256     for (ssize_t i = static_cast<ssize_t>(right_start) - 1;
    257          i >= static_cast<ssize_t>(left_end); --i) {
    258       sorted_patches_y1_.emplace_back(intervals[i].y1, intervals[i].idx);
    259     }
    260     sort_by_y0(left_end, right_start);
    261     for (size_t i = left_end; i < right_start; ++i) {
    262       sorted_patches_y0_.emplace_back(intervals[i].y0, intervals[i].idx);
    263     }
    264     // Create the left and right nodes (if not empty).
            // -1 marks a missing child; GetPatchesForRow() stops descending there.
    265     node.left_child = node.right_child = -1;
    266     if (left_end > start) {
    267       PatchTreeNode left;
    268       left.start = start;
    269       left.num = left_end - left.start;
    270       patch_tree_[next].left_child = patch_tree_.size();
    271       patch_tree_.push_back(left);
    272     }
    273     if (right_start < end) {
    274       PatchTreeNode right;
    275       right.start = right_start;
    276       right.num = end - right.start;
    277       patch_tree_[next].right_child = patch_tree_.size();
    278       patch_tree_.push_back(right);
    279     }
    280     ++next;
    281   }
    282 }
    283 
    284 std::vector<size_t> PatchDictionary::GetPatchesForRow(size_t y) const {
    285   std::vector<size_t> result;
    286   if (y < num_patches_.size() && num_patches_[y] > 0) {
    287     result.reserve(num_patches_[y]);
    288     for (ssize_t tree_idx = 0; tree_idx != -1;) {
    289       JXL_DASSERT(tree_idx < static_cast<ssize_t>(patch_tree_.size()));
    290       const auto& node = patch_tree_[tree_idx];
    291       if (y <= node.y_center) {
    292         for (size_t i = 0; i < node.num; ++i) {
    293           const auto& p = sorted_patches_y0_[node.start + i];
    294           if (y < p.first) break;
    295           result.push_back(p.second);
    296         }
    297         tree_idx = y < node.y_center ? node.left_child : -1;
    298       } else {
    299         for (size_t i = 0; i < node.num; ++i) {
    300           const auto& p = sorted_patches_y1_[node.start + i];
    301           if (y >= p.first) break;
    302           result.push_back(p.second);
    303         }
    304         tree_idx = node.right_child;
    305       }
    306     }
    307     // Ensure that he relative order of patches that affect the same pixels is
    308     // preserved. This is important for patches that have a blend mode
    309     // different from kAdd.
    310     std::sort(result.begin(), result.end());
    311   }
    312   return result;
    313 }
    314 
    315 // Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed
    316 // to be located at position (x0, y) in the frame.
    317 Status PatchDictionary::AddOneRow(float* const* inout, size_t y, size_t x0,
    318                                   size_t xsize) const {
    319   size_t num_ec = shared_->metadata->m.num_extra_channels;
    320   std::vector<const float*> fg_ptrs(3 + num_ec);
    321   for (size_t pos_idx : GetPatchesForRow(y)) {
    322     const size_t blending_idx = pos_idx * (num_ec + 1);
    323     const PatchPosition& pos = positions_[pos_idx];
    324     const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx];
    325     size_t by = pos.y;
    326     size_t bx = pos.x;
    327     size_t patch_xsize = ref_pos.xsize;
    328     JXL_DASSERT(y >= by);
    329     JXL_DASSERT(y < by + ref_pos.ysize);
    330     size_t iy = y - by;
    331     size_t ref = ref_pos.ref;
    332     if (bx >= x0 + xsize) continue;
    333     if (bx + patch_xsize < x0) continue;
    334     size_t patch_x0 = std::max(bx, x0);
    335     size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize);
    336     for (size_t c = 0; c < 3; c++) {
    337       fg_ptrs[c] = shared_->reference_frames[ref].frame.color().ConstPlaneRow(
    338                        c, ref_pos.y0 + iy) +
    339                    ref_pos.x0 + x0 - bx;
    340     }
    341     for (size_t i = 0; i < num_ec; i++) {
    342       fg_ptrs[3 + i] =
    343           shared_->reference_frames[ref].frame.extra_channels()[i].ConstRow(
    344               ref_pos.y0 + iy) +
    345           ref_pos.x0 + x0 - bx;
    346     }
    347     JXL_RETURN_IF_ERROR(PerformBlending(
    348         inout, fg_ptrs.data(), inout, patch_x0 - x0, patch_x1 - patch_x0,
    349         blendings_[blending_idx], blendings_.data() + blending_idx + 1,
    350         shared_->metadata->m.extra_channel_info));
    351   }
    352   return true;
    353 }
    354 }  // namespace jxl