libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

dec_frame.h (14042B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #ifndef LIB_JXL_DEC_FRAME_H_
      7 #define LIB_JXL_DEC_FRAME_H_
      8 
      9 #include <jxl/decode.h>
     10 #include <jxl/types.h>
     11 #include <stdint.h>
     12 
     13 #include <algorithm>
     14 #include <cstddef>
     15 #include <limits>
     16 #include <utility>
     17 #include <vector>
     18 
     19 #include "lib/jxl/base/common.h"
     20 #include "lib/jxl/base/compiler_specific.h"
     21 #include "lib/jxl/base/data_parallel.h"
     22 #include "lib/jxl/base/status.h"
     23 #include "lib/jxl/common.h"  // JXL_HIGH_PRECISION
     24 #include "lib/jxl/dec_bit_reader.h"
     25 #include "lib/jxl/dec_cache.h"
     26 #include "lib/jxl/dec_modular.h"
     27 #include "lib/jxl/frame_header.h"
     28 #include "lib/jxl/image_bundle.h"
     29 #include "lib/jxl/image_metadata.h"
     30 
     31 namespace jxl {
     32 
     33 // Decodes a frame. Groups may be processed in parallel by `pool`.
     34 // `metadata` is the metadata that applies to all frames of the codestream
     35 // `decoded->metadata` must already be set and must match metadata.m.
     36 // Used in the encoder to model decoder behaviour, and in tests.
     37 Status DecodeFrame(PassesDecoderState* dec_state, ThreadPool* JXL_RESTRICT pool,
     38                    const uint8_t* next_in, size_t avail_in,
     39                    FrameHeader* frame_header, ImageBundle* decoded,
     40                    const CodecMetadata& metadata,
     41                    bool use_slow_rendering_pipeline = false);
     42 
     43 // TODO(veluca): implement "forced drawing".
     44 class FrameDecoder {
     45  public:
     46   // All parameters must outlive the FrameDecoder.
     47   FrameDecoder(PassesDecoderState* dec_state, const CodecMetadata& metadata,
     48                ThreadPool* pool, bool use_slow_rendering_pipeline)
     49       : dec_state_(dec_state),
     50         pool_(pool),
     51         frame_header_(&metadata),
     52         use_slow_rendering_pipeline_(use_slow_rendering_pipeline) {}
     53 
     54   void SetRenderSpotcolors(bool rsc) { render_spotcolors_ = rsc; }
     55   void SetCoalescing(bool c) { coalescing_ = c; }
     56 
     57   // Read FrameHeader and table of contents from the given BitReader.
     58   Status InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded,
     59                    bool is_preview);
     60 
     61   // Checks frame dimensions for their limits, and sets the output
     62   // image buffer.
     63   Status InitFrameOutput();
     64 
     65   struct SectionInfo {
     66     BitReader* JXL_RESTRICT br;
     67     // Logical index of the section, regardless of any permutation that may be
     68     // applied in the table of contents or of the physical position in the file.
     69     size_t id;
     70     // Index of the section in the order of the bytes inside the frame.
     71     size_t index;
     72   };
     73 
     74   struct TocEntry {
     75     size_t size;
     76     size_t id;
     77   };
     78 
     79   enum SectionStatus {
     80     // Processed correctly.
     81     kDone = 0,
     82     // Skipped because other required sections were not yet processed.
     83     kSkipped = 1,
     84     // Skipped because the section was already processed.
     85     kDuplicate = 2,
     86     // Only partially decoded: the section will need to be processed again.
     87     kPartial = 3,
     88   };
     89 
     90   // Processes `num` sections; each SectionInfo contains the index
     91   // of the section and a BitReader that only contains the data of the section.
     92   // `section_status` should point to `num` elements, and will be filled with
     93   // information about whether each section was processed or not.
     94   // A section is a part of the encoded file that is indexed by the TOC.
     95   Status ProcessSections(const SectionInfo* sections, size_t num,
     96                          SectionStatus* section_status);
     97 
     98   // Flushes all the data decoded so far to pixels.
     99   Status Flush();
    100 
    101   // Runs final operations once a frame data is decoded.
    102   // Must be called exactly once per frame, after all calls to ProcessSections.
    103   Status FinalizeFrame();
    104 
    105   // Returns dependencies of this frame on reference ids as a bit mask: bits 0-3
    106   // indicate reference frame 0-3 for patches and blending, bits 4-7 indicate DC
    107   // frames this frame depends on. Only returns a valid result after all calls
    108   // to ProcessSections are finished and before FinalizeFrame.
    109   int References() const;
    110 
    111   // Returns reference id of storage location where this frame is stored as a
    112   // bit flag, or 0 if not stored.
    113   // Matches the bit mask used for GetReferences: bits 0-3 indicate it is stored
    114   // for patching or blending, bits 4-7 indicate DC frame.
    115   // Unlike References, can be ran at any time as
    116   // soon as the frame header is known.
    117   static int SavedAs(const FrameHeader& header);
    118 
    119   uint64_t SumSectionSizes() const { return section_sizes_sum_; }
    120   const std::vector<TocEntry>& Toc() const { return toc_; }
    121 
    122   const FrameHeader& GetFrameHeader() const { return frame_header_; }
    123 
    124   // Returns whether a DC image has been decoded, accessible at low resolution
    125   // at passes.shared_storage.dc_storage
    126   bool HasDecodedDC() const { return finalized_dc_; }
    127   bool HasDecodedAll() const { return toc_.size() == num_sections_done_; }
    128 
    129   size_t NumCompletePasses() const {
    130     return *std::min_element(decoded_passes_per_ac_group_.begin(),
    131                              decoded_passes_per_ac_group_.end());
    132   }
    133 
    134   // If enabled, ProcessSections will stop and return true when the DC
    135   // sections have been processed, instead of starting the AC sections. This
    136   // will only occur if supported (that is, flushing will produce a valid
    137   // 1/8th*1/8th resolution image). The return value of true then does not mean
    138   // all sections have been processed, use HasDecodedDC and HasDecodedAll
    139   // to check the true finished state.
    140   // Returns the progressive detail that will be effective for the frame.
    141   JxlProgressiveDetail SetPauseAtProgressive(JxlProgressiveDetail prog_detail) {
    142     bool single_section =
    143         frame_dim_.num_groups == 1 && frame_header_.passes.num_passes == 1;
    144     if (frame_header_.frame_type != kSkipProgressive &&
    145         // If there's only one group and one pass, there is no separate section
    146         // for DC and the entire full resolution image is available at once.
    147         !single_section &&
    148         // If extra channels are encoded with modular without squeeze, they
    149         // don't support DC. If the are encoded with squeeze, DC works in theory
    150         // but the implementation may not yet correctly support this for Flush.
    151         // Therefore, can't correctly pause for a progressive step if there is
    152         // an extra channel (including alpha channel)
    153         // TODO(firsching): Check if this is still the case.
    154         decoded_->metadata()->extra_channel_info.empty() &&
    155         // DC is not guaranteed to be available in modular mode and may be a
    156         // black image. If squeeze is used, it may be available depending on the
    157         // current implementation.
    158         // TODO(lode): do return DC if it's known that flushing at this point
    159         // will produce a valid 1/8th downscaled image with modular encoding.
    160         frame_header_.encoding == FrameEncoding::kVarDCT) {
    161       progressive_detail_ = prog_detail;
    162     } else {
    163       progressive_detail_ = JxlProgressiveDetail::kFrames;
    164     }
    165     if (progressive_detail_ >= JxlProgressiveDetail::kPasses) {
    166       for (size_t i = 1; i < frame_header_.passes.num_passes; ++i) {
    167         passes_to_pause_.push_back(i);
    168       }
    169     } else if (progressive_detail_ >= JxlProgressiveDetail::kLastPasses) {
    170       for (size_t i = 0; i < frame_header_.passes.num_downsample; ++i) {
    171         passes_to_pause_.push_back(frame_header_.passes.last_pass[i] + 1);
    172       }
    173       // The format does not guarantee that these values are sorted.
    174       std::sort(passes_to_pause_.begin(), passes_to_pause_.end());
    175     }
    176     return progressive_detail_;
    177   }
    178 
    179   size_t NextNumPassesToPause() const {
    180     auto it = std::upper_bound(passes_to_pause_.begin(), passes_to_pause_.end(),
    181                                NumCompletePasses());
    182     return (it != passes_to_pause_.end() ? *it
    183                                          : std::numeric_limits<size_t>::max());
    184   }
    185 
    186   // Sets the pixel callback or image buffer where the pixels will be decoded.
    187   //
    188   // @param undo_orientation: if true, indicates the frame decoder should apply
    189   // the exif orientation to bring the image to the intended display
    190   // orientation.
    191   void SetImageOutput(const PixelCallback& pixel_callback, void* image_buffer,
    192                       size_t image_buffer_size, size_t xsize, size_t ysize,
    193                       JxlPixelFormat format, size_t bits_per_sample,
    194                       bool unpremul_alpha, bool undo_orientation) const {
    195     dec_state_->width = xsize;
    196     dec_state_->height = ysize;
    197     dec_state_->main_output.format = format;
    198     dec_state_->main_output.bits_per_sample = bits_per_sample;
    199     dec_state_->main_output.callback = pixel_callback;
    200     dec_state_->main_output.buffer = image_buffer;
    201     dec_state_->main_output.buffer_size = image_buffer_size;
    202     dec_state_->main_output.stride = GetStride(xsize, format);
    203     const jxl::ExtraChannelInfo* alpha =
    204         decoded_->metadata()->Find(jxl::ExtraChannel::kAlpha);
    205     if (alpha && alpha->alpha_associated && unpremul_alpha) {
    206       dec_state_->unpremul_alpha = true;
    207     }
    208     if (undo_orientation) {
    209       dec_state_->undo_orientation = decoded_->metadata()->GetOrientation();
    210       if (static_cast<int>(dec_state_->undo_orientation) > 4) {
    211         std::swap(dec_state_->width, dec_state_->height);
    212       }
    213     }
    214     dec_state_->extra_output.clear();
    215 #if !JXL_HIGH_PRECISION
    216     if (dec_state_->main_output.buffer &&
    217         (format.data_type == JXL_TYPE_UINT8) && (format.num_channels >= 3) &&
    218         !dec_state_->unpremul_alpha &&
    219         (dec_state_->undo_orientation == Orientation::kIdentity) &&
    220         decoded_->metadata()->xyb_encoded &&
    221         dec_state_->output_encoding_info.color_encoding.IsSRGB() &&
    222         dec_state_->output_encoding_info.all_default_opsin &&
    223         (dec_state_->output_encoding_info.desired_intensity_target ==
    224          dec_state_->output_encoding_info.orig_intensity_target) &&
    225         HasFastXYBTosRGB8() && frame_header_.needs_color_transform()) {
    226       dec_state_->fast_xyb_srgb8_conversion = true;
    227     }
    228 #endif
    229   }
    230 
    231   void AddExtraChannelOutput(void* buffer, size_t buffer_size, size_t xsize,
    232                              JxlPixelFormat format, size_t bits_per_sample) {
    233     ImageOutput out;
    234     out.format = format;
    235     out.bits_per_sample = bits_per_sample;
    236     out.buffer = buffer;
    237     out.buffer_size = buffer_size;
    238     out.stride = GetStride(xsize, format);
    239     dec_state_->extra_output.push_back(out);
    240   }
    241 
    242  private:
    243   Status ProcessDCGlobal(BitReader* br);
    244   Status ProcessDCGroup(size_t dc_group_id, BitReader* br);
    245   Status FinalizeDC();
    246   Status AllocateOutput();
    247   Status ProcessACGlobal(BitReader* br);
    248   Status ProcessACGroup(size_t ac_group_id, BitReader* JXL_RESTRICT* br,
    249                         size_t num_passes, size_t thread, bool force_draw,
    250                         bool dc_only);
    251   void MarkSections(const SectionInfo* sections, size_t num,
    252                     const SectionStatus* section_status);
    253 
    254   // Allocates storage for parallel decoding using up to `num_threads` threads
    255   // of up to `num_tasks` tasks. The value of `thread` passed to
    256   // `GetStorageLocation` must be smaller than the `num_threads` value passed
    257   // here. The value of `task` passed to `GetStorageLocation` must be smaller
    258   // than the value of `num_tasks` passed here.
    259   Status PrepareStorage(size_t num_threads, size_t num_tasks) {
    260     size_t storage_size = std::min(num_threads, num_tasks);
    261     if (storage_size > group_dec_caches_.size()) {
    262       group_dec_caches_.resize(storage_size);
    263     }
    264     use_task_id_ = num_threads > num_tasks;
    265     bool use_noise = (frame_header_.flags & FrameHeader::kNoise) != 0;
    266     bool use_group_ids =
    267         (modular_frame_decoder_.UsesFullImage() &&
    268          (frame_header_.encoding == FrameEncoding::kVarDCT || use_noise));
    269     if (dec_state_->render_pipeline) {
    270       JXL_RETURN_IF_ERROR(dec_state_->render_pipeline->PrepareForThreads(
    271           storage_size, use_group_ids));
    272     }
    273     return true;
    274   }
    275 
    276   size_t GetStorageLocation(size_t thread, size_t task) const {
    277     if (use_task_id_) return task;
    278     return thread;
    279   }
    280 
    281   static size_t BytesPerChannel(JxlDataType data_type) {
    282     return (data_type == JXL_TYPE_UINT8   ? 1u
    283             : data_type == JXL_TYPE_FLOAT ? 4u
    284                                           : 2u);
    285   }
    286 
    287   static size_t GetStride(const size_t xsize, JxlPixelFormat format) {
    288     size_t stride =
    289         (xsize * BytesPerChannel(format.data_type) * format.num_channels);
    290     if (format.align > 1) {
    291       stride = (jxl::DivCeil(stride, format.align) * format.align);
    292     }
    293     return stride;
    294   }
    295 
    296   bool HasDcGroupToDecode() const {
    297     return std::any_of(decoded_dc_groups_.cbegin(), decoded_dc_groups_.cend(),
    298                        [](uint8_t ready) { return ready == 0; });
    299   }
    300 
    301   PassesDecoderState* dec_state_;
    302   ThreadPool* pool_;
    303   std::vector<TocEntry> toc_;
    304   uint64_t section_sizes_sum_;
    305   // TODO(veluca): figure out the duplication between these and dec_state_.
    306   FrameHeader frame_header_;
    307   FrameDimensions frame_dim_;
    308   ImageBundle* decoded_;
    309   ModularFrameDecoder modular_frame_decoder_;
    310   bool render_spotcolors_ = true;
    311   bool coalescing_ = true;
    312 
    313   std::vector<uint8_t> processed_section_;
    314   std::vector<uint8_t> decoded_passes_per_ac_group_;
    315   std::vector<uint8_t> decoded_dc_groups_;
    316   bool decoded_dc_global_;
    317   bool decoded_ac_global_;
    318   bool HasEverything() const;
    319   bool finalized_dc_ = true;
    320   size_t num_sections_done_ = 0;
    321   bool is_finalized_ = true;
    322   bool allocated_ = false;
    323 
    324   std::vector<GroupDecCache> group_dec_caches_;
    325 
    326   // Whether or not the task id should be used for storage indexing, instead of
    327   // the thread id.
    328   bool use_task_id_ = false;
    329 
    330   // Testing setting: whether or not to use the slow rendering pipeline.
    331   bool use_slow_rendering_pipeline_;
    332 
    333   JxlProgressiveDetail progressive_detail_ = kFrames;
    334   // Number of completed passes where section decoding should pause.
    335   // Used for progressive details at least kLastPasses.
    336   std::vector<int> passes_to_pause_;
    337 };
    338 
    339 }  // namespace jxl
    340 
    341 #endif  // LIB_JXL_DEC_FRAME_H_