libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_frame.cc (95997B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_frame.h"
      7 
      8 #include <stddef.h>
      9 #include <stdint.h>
     10 
     11 #include <algorithm>
     12 #include <array>
     13 #include <atomic>
     14 #include <cmath>
     15 #include <memory>
     16 #include <mutex>
     17 #include <numeric>
     18 #include <utility>
     19 #include <vector>
     20 
     21 #include "lib/jxl/ac_context.h"
     22 #include "lib/jxl/ac_strategy.h"
     23 #include "lib/jxl/base/bits.h"
     24 #include "lib/jxl/base/common.h"
     25 #include "lib/jxl/base/compiler_specific.h"
     26 #include "lib/jxl/base/data_parallel.h"
     27 #include "lib/jxl/base/override.h"
     28 #include "lib/jxl/base/printf_macros.h"
     29 #include "lib/jxl/base/status.h"
     30 #include "lib/jxl/chroma_from_luma.h"
     31 #include "lib/jxl/coeff_order.h"
     32 #include "lib/jxl/coeff_order_fwd.h"
     33 #include "lib/jxl/color_encoding_internal.h"
     34 #include "lib/jxl/common.h"  // kMaxNumPasses
     35 #include "lib/jxl/dct_util.h"
     36 #include "lib/jxl/dec_external_image.h"
     37 #include "lib/jxl/enc_ac_strategy.h"
     38 #include "lib/jxl/enc_adaptive_quantization.h"
     39 #include "lib/jxl/enc_ans.h"
     40 #include "lib/jxl/enc_ar_control_field.h"
     41 #include "lib/jxl/enc_aux_out.h"
     42 #include "lib/jxl/enc_bit_writer.h"
     43 #include "lib/jxl/enc_cache.h"
     44 #include "lib/jxl/enc_chroma_from_luma.h"
     45 #include "lib/jxl/enc_coeff_order.h"
     46 #include "lib/jxl/enc_context_map.h"
     47 #include "lib/jxl/enc_entropy_coder.h"
     48 #include "lib/jxl/enc_external_image.h"
     49 #include "lib/jxl/enc_fields.h"
     50 #include "lib/jxl/enc_group.h"
     51 #include "lib/jxl/enc_heuristics.h"
     52 #include "lib/jxl/enc_modular.h"
     53 #include "lib/jxl/enc_noise.h"
     54 #include "lib/jxl/enc_params.h"
     55 #include "lib/jxl/enc_patch_dictionary.h"
     56 #include "lib/jxl/enc_photon_noise.h"
     57 #include "lib/jxl/enc_quant_weights.h"
     58 #include "lib/jxl/enc_splines.h"
     59 #include "lib/jxl/enc_toc.h"
     60 #include "lib/jxl/enc_xyb.h"
     61 #include "lib/jxl/fields.h"
     62 #include "lib/jxl/frame_dimensions.h"
     63 #include "lib/jxl/frame_header.h"
     64 #include "lib/jxl/image.h"
     65 #include "lib/jxl/image_bundle.h"
     66 #include "lib/jxl/image_ops.h"
     67 #include "lib/jxl/jpeg/enc_jpeg_data.h"
     68 #include "lib/jxl/loop_filter.h"
     69 #include "lib/jxl/modular/options.h"
     70 #include "lib/jxl/quant_weights.h"
     71 #include "lib/jxl/quantizer.h"
     72 #include "lib/jxl/splines.h"
     73 #include "lib/jxl/toc.h"
     74 
     75 namespace jxl {
     76 
     77 Status ParamsPostInit(CompressParams* p) {
     78   if (!p->manual_noise.empty() &&
     79       p->manual_noise.size() != NoiseParams::kNumNoisePoints) {
     80     return JXL_FAILURE("Invalid number of noise lut entries");
     81   }
     82   if (!p->manual_xyb_factors.empty() && p->manual_xyb_factors.size() != 3) {
     83     return JXL_FAILURE("Invalid number of XYB quantization factors");
     84   }
     85   if (!p->modular_mode && p->butteraugli_distance == 0.0) {
     86     p->butteraugli_distance = kMinButteraugliDistance;
     87   }
     88   if (p->original_butteraugli_distance == -1.0) {
     89     p->original_butteraugli_distance = p->butteraugli_distance;
     90   }
     91   if (p->resampling <= 0) {
     92     p->resampling = 1;
     93     // For very low bit rates, using 2x2 resampling gives better results on
     94     // most photographic images, with an adjusted butteraugli score chosen to
     95     // give roughly the same amount of bits per pixel.
     96     if (!p->already_downsampled && p->butteraugli_distance >= 20) {
     97       p->resampling = 2;
     98       p->butteraugli_distance = 6 + ((p->butteraugli_distance - 20) * 0.25);
     99     }
    100   }
    101   if (p->ec_resampling <= 0) {
    102     p->ec_resampling = p->resampling;
    103   }
    104   return true;
    105 }
    106 
    107 namespace {
    108 
    109 template <typename T>
    110 uint32_t GetBitDepth(JxlBitDepth bit_depth, const T& metadata,
    111                      JxlPixelFormat format) {
    112   if (bit_depth.type == JXL_BIT_DEPTH_FROM_PIXEL_FORMAT) {
    113     return BitsPerChannel(format.data_type);
    114   } else if (bit_depth.type == JXL_BIT_DEPTH_FROM_CODESTREAM) {
    115     return metadata.bit_depth.bits_per_sample;
    116   } else if (bit_depth.type == JXL_BIT_DEPTH_CUSTOM) {
    117     return bit_depth.bits_per_sample;
    118   } else {
    119     return 0;
    120   }
    121 }
    122 
    123 Status CopyColorChannels(JxlChunkedFrameInputSource input, Rect rect,
    124                          const FrameInfo& frame_info,
    125                          const ImageMetadata& metadata, ThreadPool* pool,
    126                          Image3F* color, ImageF* alpha,
    127                          bool* has_interleaved_alpha) {
    128   JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
    129   input.get_color_channels_pixel_format(input.opaque, &format);
    130   *has_interleaved_alpha = format.num_channels == 2 || format.num_channels == 4;
    131   size_t bits_per_sample =
    132       GetBitDepth(frame_info.image_bit_depth, metadata, format);
    133   size_t row_offset;
    134   auto buffer = GetColorBuffer(input, rect.x0(), rect.y0(), rect.xsize(),
    135                                rect.ysize(), &row_offset);
    136   if (!buffer) {
    137     return JXL_FAILURE("no buffer for color channels given");
    138   }
    139   size_t color_channels = frame_info.ib_needs_color_transform
    140                               ? metadata.color_encoding.Channels()
    141                               : 3;
    142   if (format.num_channels < color_channels) {
    143     return JXL_FAILURE("Expected %" PRIuS
    144                        " color channels, received only %u channels",
    145                        color_channels, format.num_channels);
    146   }
    147   const uint8_t* data = reinterpret_cast<const uint8_t*>(buffer.get());
    148   for (size_t c = 0; c < color_channels; ++c) {
    149     JXL_RETURN_IF_ERROR(ConvertFromExternalNoSizeCheck(
    150         data, rect.xsize(), rect.ysize(), row_offset, bits_per_sample, format,
    151         c, pool, &color->Plane(c)));
    152   }
    153   if (color_channels == 1) {
    154     CopyImageTo(color->Plane(0), &color->Plane(1));
    155     CopyImageTo(color->Plane(0), &color->Plane(2));
    156   }
    157   if (alpha) {
    158     if (*has_interleaved_alpha) {
    159       JXL_RETURN_IF_ERROR(ConvertFromExternalNoSizeCheck(
    160           data, rect.xsize(), rect.ysize(), row_offset, bits_per_sample, format,
    161           format.num_channels - 1, pool, alpha));
    162     } else {
    163       // if alpha is not passed, but it is expected, then assume
    164       // it is all-opaque
    165       FillImage(1.0f, alpha);
    166     }
    167   }
    168   return true;
    169 }
    170 
    171 Status CopyExtraChannels(JxlChunkedFrameInputSource input, Rect rect,
    172                          const FrameInfo& frame_info,
    173                          const ImageMetadata& metadata,
    174                          bool has_interleaved_alpha, ThreadPool* pool,
    175                          std::vector<ImageF>* extra_channels) {
    176   for (size_t ec = 0; ec < metadata.num_extra_channels; ec++) {
    177     if (has_interleaved_alpha &&
    178         metadata.extra_channel_info[ec].type == ExtraChannel::kAlpha) {
    179       // Skip this alpha channel, but still request additional alpha channels
    180       // if they exist.
    181       has_interleaved_alpha = false;
    182       continue;
    183     }
    184     JxlPixelFormat ec_format = {1, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0};
    185     input.get_extra_channel_pixel_format(input.opaque, ec, &ec_format);
    186     ec_format.num_channels = 1;
    187     size_t row_offset;
    188     auto buffer =
    189         GetExtraChannelBuffer(input, ec, rect.x0(), rect.y0(), rect.xsize(),
    190                               rect.ysize(), &row_offset);
    191     if (!buffer) {
    192       return JXL_FAILURE("no buffer for extra channel given");
    193     }
    194     size_t bits_per_sample = GetBitDepth(
    195         frame_info.image_bit_depth, metadata.extra_channel_info[ec], ec_format);
    196     if (!ConvertFromExternalNoSizeCheck(
    197             reinterpret_cast<const uint8_t*>(buffer.get()), rect.xsize(),
    198             rect.ysize(), row_offset, bits_per_sample, ec_format, 0, pool,
    199             &(*extra_channels)[ec])) {
    200       return JXL_FAILURE("Failed to set buffer for extra channel");
    201     }
    202   }
    203   return true;
    204 }
    205 
    206 void SetProgressiveMode(const CompressParams& cparams,
    207                         ProgressiveSplitter* progressive_splitter) {
    208   constexpr PassDefinition progressive_passes_dc_vlf_lf_full_ac[] = {
    209       {/*num_coefficients=*/2, /*shift=*/0,
    210        /*suitable_for_downsampling_of_at_least=*/4},
    211       {/*num_coefficients=*/3, /*shift=*/0,
    212        /*suitable_for_downsampling_of_at_least=*/2},
    213       {/*num_coefficients=*/8, /*shift=*/0,
    214        /*suitable_for_downsampling_of_at_least=*/0},
    215   };
    216   constexpr PassDefinition progressive_passes_dc_quant_ac_full_ac[] = {
    217       {/*num_coefficients=*/8, /*shift=*/1,
    218        /*suitable_for_downsampling_of_at_least=*/2},
    219       {/*num_coefficients=*/8, /*shift=*/0,
    220        /*suitable_for_downsampling_of_at_least=*/0},
    221   };
    222   bool progressive_mode = ApplyOverride(cparams.progressive_mode, false);
    223   bool qprogressive_mode = ApplyOverride(cparams.qprogressive_mode, false);
    224   if (cparams.custom_progressive_mode) {
    225     progressive_splitter->SetProgressiveMode(*cparams.custom_progressive_mode);
    226   } else if (qprogressive_mode) {
    227     progressive_splitter->SetProgressiveMode(
    228         ProgressiveMode{progressive_passes_dc_quant_ac_full_ac});
    229   } else if (progressive_mode) {
    230     progressive_splitter->SetProgressiveMode(
    231         ProgressiveMode{progressive_passes_dc_vlf_lf_full_ac});
    232   }
    233 }
    234 
    235 uint64_t FrameFlagsFromParams(const CompressParams& cparams) {
    236   uint64_t flags = 0;
    237 
    238   const float dist = cparams.butteraugli_distance;
    239 
    240   // We don't add noise at low butteraugli distances because the original
    241   // noise is stored within the compressed image and adding noise makes things
    242   // worse.
    243   if (ApplyOverride(cparams.noise, dist >= kMinButteraugliForNoise) ||
    244       cparams.photon_noise_iso > 0 ||
    245       cparams.manual_noise.size() == NoiseParams::kNumNoisePoints) {
    246     flags |= FrameHeader::kNoise;
    247   }
    248 
    249   if (cparams.progressive_dc > 0 && cparams.modular_mode == false) {
    250     flags |= FrameHeader::kUseDcFrame;
    251   }
    252 
    253   return flags;
    254 }
    255 
    256 Status LoopFilterFromParams(const CompressParams& cparams, bool streaming_mode,
    257                             FrameHeader* JXL_RESTRICT frame_header) {
    258   LoopFilter* loop_filter = &frame_header->loop_filter;
    259 
    260   // Gaborish defaults to enabled in Hare or slower.
    261   loop_filter->gab = ApplyOverride(
    262       cparams.gaborish, cparams.speed_tier <= SpeedTier::kHare &&
    263                             frame_header->encoding == FrameEncoding::kVarDCT &&
    264                             cparams.decoding_speed_tier < 4);
    265 
    266   if (cparams.epf != -1) {
    267     loop_filter->epf_iters = cparams.epf;
    268   } else {
    269     if (frame_header->encoding == FrameEncoding::kModular) {
    270       loop_filter->epf_iters = 0;
    271     } else {
    272       constexpr float kThresholds[3] = {0.7, 1.5, 4.0};
    273       loop_filter->epf_iters = 0;
    274       if (cparams.decoding_speed_tier < 3) {
    275         for (size_t i = cparams.decoding_speed_tier == 2 ? 1 : 0; i < 3; i++) {
    276           if (cparams.butteraugli_distance >= kThresholds[i]) {
    277             loop_filter->epf_iters++;
    278           }
    279         }
    280       }
    281     }
    282   }
    283   // Strength of EPF in modular mode.
    284   if (frame_header->encoding == FrameEncoding::kModular &&
    285       !cparams.IsLossless()) {
    286     // TODO(veluca): this formula is nonsense.
    287     loop_filter->epf_sigma_for_modular =
    288         std::max(cparams.butteraugli_distance, 1.0f);
    289   }
    290   if (frame_header->encoding == FrameEncoding::kModular &&
    291       cparams.lossy_palette) {
    292     loop_filter->epf_sigma_for_modular = 1.0f;
    293   }
    294 
    295   return true;
    296 }
    297 
    298 Status MakeFrameHeader(size_t xsize, size_t ysize,
    299                        const CompressParams& cparams,
    300                        const ProgressiveSplitter& progressive_splitter,
    301                        const FrameInfo& frame_info,
    302                        const jpeg::JPEGData* jpeg_data, bool streaming_mode,
    303                        FrameHeader* JXL_RESTRICT frame_header) {
    304   frame_header->nonserialized_is_preview = frame_info.is_preview;
    305   frame_header->is_last = frame_info.is_last;
    306   frame_header->save_before_color_transform =
    307       frame_info.save_before_color_transform;
    308   frame_header->frame_type = frame_info.frame_type;
    309   frame_header->name = frame_info.name;
    310 
    311   progressive_splitter.InitPasses(&frame_header->passes);
    312 
    313   if (cparams.modular_mode) {
    314     frame_header->encoding = FrameEncoding::kModular;
    315     if (cparams.modular_group_size_shift == -1) {
    316       frame_header->group_size_shift = 1;
    317       // no point using groups when only one group is full and the others are
    318       // less than half full: multithreading will not really help much, while
    319       // compression does suffer
    320       if (xsize <= 400 && ysize <= 400) {
    321         frame_header->group_size_shift = 2;
    322       }
    323     } else {
    324       frame_header->group_size_shift = cparams.modular_group_size_shift;
    325     }
    326   }
    327 
    328   if (jpeg_data) {
    329     // we are transcoding a JPEG, so we don't get to choose
    330     frame_header->encoding = FrameEncoding::kVarDCT;
    331     frame_header->x_qm_scale = 2;
    332     frame_header->b_qm_scale = 2;
    333     JXL_RETURN_IF_ERROR(SetChromaSubsamplingFromJpegData(
    334         *jpeg_data, &frame_header->chroma_subsampling));
    335     JXL_RETURN_IF_ERROR(SetColorTransformFromJpegData(
    336         *jpeg_data, &frame_header->color_transform));
    337   } else {
    338     frame_header->color_transform = cparams.color_transform;
    339     if (!cparams.modular_mode &&
    340         (frame_header->chroma_subsampling.MaxHShift() != 0 ||
    341          frame_header->chroma_subsampling.MaxVShift() != 0)) {
    342       return JXL_FAILURE(
    343           "Chroma subsampling is not supported in VarDCT mode when not "
    344           "recompressing JPEGs");
    345     }
    346   }
    347   if (frame_header->color_transform != ColorTransform::kYCbCr &&
    348       (frame_header->chroma_subsampling.MaxHShift() != 0 ||
    349        frame_header->chroma_subsampling.MaxVShift() != 0)) {
    350     return JXL_FAILURE(
    351         "Chroma subsampling is not supported when color transform is not "
    352         "YCbCr");
    353   }
    354 
    355   frame_header->flags = FrameFlagsFromParams(cparams);
    356   // Non-photon noise is not supported in the Modular encoder for now.
    357   if (frame_header->encoding != FrameEncoding::kVarDCT &&
    358       cparams.photon_noise_iso == 0 && cparams.manual_noise.empty()) {
    359     frame_header->UpdateFlag(false, FrameHeader::Flags::kNoise);
    360   }
    361 
    362   JXL_RETURN_IF_ERROR(
    363       LoopFilterFromParams(cparams, streaming_mode, frame_header));
    364 
    365   frame_header->dc_level = frame_info.dc_level;
    366   if (frame_header->dc_level > 2) {
    367     // With 3 or more progressive_dc frames, the implementation does not yet
    368     // work, see enc_cache.cc.
    369     return JXL_FAILURE("progressive_dc > 2 is not yet supported");
    370   }
    371   if (cparams.progressive_dc > 0 &&
    372       (cparams.ec_resampling != 1 || cparams.resampling != 1)) {
    373     return JXL_FAILURE("Resampling not supported with DC frames");
    374   }
    375   if (cparams.resampling != 1 && cparams.resampling != 2 &&
    376       cparams.resampling != 4 && cparams.resampling != 8) {
    377     return JXL_FAILURE("Invalid resampling factor");
    378   }
    379   if (cparams.ec_resampling != 1 && cparams.ec_resampling != 2 &&
    380       cparams.ec_resampling != 4 && cparams.ec_resampling != 8) {
    381     return JXL_FAILURE("Invalid ec_resampling factor");
    382   }
    383   // Resized frames.
    384   if (frame_info.frame_type != FrameType::kDCFrame) {
    385     frame_header->frame_origin = frame_info.origin;
    386     size_t ups = 1;
    387     if (cparams.already_downsampled) ups = cparams.resampling;
    388 
    389     // TODO(lode): this is not correct in case of odd original image sizes in
    390     // combination with cparams.already_downsampled. Likely these values should
    391     // be set to respectively frame_header->default_xsize() and
    392     // frame_header->default_ysize() instead, the original (non downsampled)
    393     // intended decoded image dimensions. But it may be more subtle than that
    394     // if combined with crop. This issue causes custom_size_or_origin to be
    395     // incorrectly set to true in case of already_downsampled with odd output
    396     // image size when no cropping is used.
    397     frame_header->frame_size.xsize = xsize * ups;
    398     frame_header->frame_size.ysize = ysize * ups;
    399     if (frame_info.origin.x0 != 0 || frame_info.origin.y0 != 0 ||
    400         frame_header->frame_size.xsize != frame_header->default_xsize() ||
    401         frame_header->frame_size.ysize != frame_header->default_ysize()) {
    402       frame_header->custom_size_or_origin = true;
    403     }
    404   }
    405   // Upsampling.
    406   frame_header->upsampling = cparams.resampling;
    407   const std::vector<ExtraChannelInfo>& extra_channels =
    408       frame_header->nonserialized_metadata->m.extra_channel_info;
    409   frame_header->extra_channel_upsampling.clear();
    410   frame_header->extra_channel_upsampling.resize(extra_channels.size(),
    411                                                 cparams.ec_resampling);
    412   frame_header->save_as_reference = frame_info.save_as_reference;
    413 
    414   // Set blending-related information.
    415   if (frame_info.blend || frame_header->custom_size_or_origin) {
    416     // Set blend_channel to the first alpha channel. These values are only
    417     // encoded in case a blend mode involving alpha is used and there are more
    418     // than one extra channels.
    419     size_t index = 0;
    420     if (frame_info.alpha_channel == -1) {
    421       if (extra_channels.size() > 1) {
    422         for (size_t i = 0; i < extra_channels.size(); i++) {
    423           if (extra_channels[i].type == ExtraChannel::kAlpha) {
    424             index = i;
    425             break;
    426           }
    427         }
    428       }
    429     } else {
    430       index = static_cast<size_t>(frame_info.alpha_channel);
    431       JXL_ASSERT(index == 0 || index < extra_channels.size());
    432     }
    433     frame_header->blending_info.alpha_channel = index;
    434     frame_header->blending_info.mode =
    435         frame_info.blend ? frame_info.blendmode : BlendMode::kReplace;
    436     frame_header->blending_info.source = frame_info.source;
    437     frame_header->blending_info.clamp = frame_info.clamp;
    438     const auto& extra_channel_info = frame_info.extra_channel_blending_info;
    439     for (size_t i = 0; i < extra_channels.size(); i++) {
    440       if (i < extra_channel_info.size()) {
    441         frame_header->extra_channel_blending_info[i] = extra_channel_info[i];
    442       } else {
    443         frame_header->extra_channel_blending_info[i].alpha_channel = index;
    444         BlendMode default_blend = frame_info.blendmode;
    445         if (extra_channels[i].type != ExtraChannel::kBlack && i != index) {
    446           // K needs to be blended, spot colors and other stuff gets added
    447           default_blend = BlendMode::kAdd;
    448         }
    449         frame_header->extra_channel_blending_info[i].mode =
    450             frame_info.blend ? default_blend : BlendMode::kReplace;
    451         frame_header->extra_channel_blending_info[i].source = 1;
    452       }
    453     }
    454   }
    455 
    456   frame_header->animation_frame.duration = frame_info.duration;
    457   frame_header->animation_frame.timecode = frame_info.timecode;
    458 
    459   if (jpeg_data) {
    460     frame_header->UpdateFlag(false, FrameHeader::kUseDcFrame);
    461     frame_header->UpdateFlag(true, FrameHeader::kSkipAdaptiveDCSmoothing);
    462   }
    463 
    464   return true;
    465 }
    466 
    467 // Invisible (alpha = 0) pixels tend to be a mess in optimized PNGs.
    468 // Since they have no visual impact whatsoever, we can replace them with
    469 // something that compresses better and reduces artifacts near the edges. This
    470 // does some kind of smooth stuff that seems to work.
    471 // Replace invisible pixels with a weighted average of the pixel to the left,
    472 // the pixel to the topright, and non-invisible neighbours.
    473 // Produces downward-blurry smears, with in the upwards direction only a 1px
    474 // edge duplication but not more. It would probably be better to smear in all
    475 // directions. That requires an alpha-weighed convolution with a large enough
    476 // kernel though, which might be overkill...
    477 void SimplifyInvisible(Image3F* image, const ImageF& alpha, bool lossless) {
    478   for (size_t c = 0; c < 3; ++c) {
    479     for (size_t y = 0; y < image->ysize(); ++y) {
    480       float* JXL_RESTRICT row = image->PlaneRow(c, y);
    481       const float* JXL_RESTRICT prow =
    482           (y > 0 ? image->PlaneRow(c, y - 1) : nullptr);
    483       const float* JXL_RESTRICT nrow =
    484           (y + 1 < image->ysize() ? image->PlaneRow(c, y + 1) : nullptr);
    485       const float* JXL_RESTRICT a = alpha.Row(y);
    486       const float* JXL_RESTRICT pa = (y > 0 ? alpha.Row(y - 1) : nullptr);
    487       const float* JXL_RESTRICT na =
    488           (y + 1 < image->ysize() ? alpha.Row(y + 1) : nullptr);
    489       for (size_t x = 0; x < image->xsize(); ++x) {
    490         if (a[x] == 0) {
    491           if (lossless) {
    492             row[x] = 0;
    493             continue;
    494           }
    495           float d = 0.f;
    496           row[x] = 0;
    497           if (x > 0) {
    498             row[x] += row[x - 1];
    499             d++;
    500             if (a[x - 1] > 0.f) {
    501               row[x] += row[x - 1];
    502               d++;
    503             }
    504           }
    505           if (x + 1 < image->xsize()) {
    506             if (y > 0) {
    507               row[x] += prow[x + 1];
    508               d++;
    509             }
    510             if (a[x + 1] > 0.f) {
    511               row[x] += 2.f * row[x + 1];
    512               d += 2.f;
    513             }
    514             if (y > 0 && pa[x + 1] > 0.f) {
    515               row[x] += 2.f * prow[x + 1];
    516               d += 2.f;
    517             }
    518             if (y + 1 < image->ysize() && na[x + 1] > 0.f) {
    519               row[x] += 2.f * nrow[x + 1];
    520               d += 2.f;
    521             }
    522           }
    523           if (y > 0 && pa[x] > 0.f) {
    524             row[x] += 2.f * prow[x];
    525             d += 2.f;
    526           }
    527           if (y + 1 < image->ysize() && na[x] > 0.f) {
    528             row[x] += 2.f * nrow[x];
    529             d += 2.f;
    530           }
    531           if (d > 1.f) row[x] /= d;
    532         }
    533       }
    534     }
    535   }
    536 }
    537 
    538 struct PixelStatsForChromacityAdjustment {
    539   float dx = 0;
    540   float db = 0;
    541   float exposed_blue = 0;
    542   static float CalcPlane(const ImageF* JXL_RESTRICT plane, const Rect& rect) {
    543     float xmax = 0;
    544     float ymax = 0;
    545     for (size_t ty = 1; ty < rect.ysize(); ++ty) {
    546       for (size_t tx = 1; tx < rect.xsize(); ++tx) {
    547         float cur = rect.Row(plane, ty)[tx];
    548         float prev_row = rect.Row(plane, ty - 1)[tx];
    549         float prev = rect.Row(plane, ty)[tx - 1];
    550         xmax = std::max(xmax, std::abs(cur - prev));
    551         ymax = std::max(ymax, std::abs(cur - prev_row));
    552       }
    553     }
    554     return std::max(xmax, ymax);
    555   }
    556   void CalcExposedBlue(const ImageF* JXL_RESTRICT plane_y,
    557                        const ImageF* JXL_RESTRICT plane_b, const Rect& rect) {
    558     float eb = 0;
    559     float xmax = 0;
    560     float ymax = 0;
    561     for (size_t ty = 1; ty < rect.ysize(); ++ty) {
    562       for (size_t tx = 1; tx < rect.xsize(); ++tx) {
    563         float cur_y = rect.Row(plane_y, ty)[tx];
    564         float cur_b = rect.Row(plane_b, ty)[tx];
    565         float exposed_b = cur_b - cur_y * 1.2;
    566         float diff_b = cur_b - cur_y;
    567         float prev_row = rect.Row(plane_b, ty - 1)[tx];
    568         float prev = rect.Row(plane_b, ty)[tx - 1];
    569         float diff_prev_row = prev_row - rect.Row(plane_y, ty - 1)[tx];
    570         float diff_prev = prev - rect.Row(plane_y, ty)[tx - 1];
    571         xmax = std::max(xmax, std::abs(diff_b - diff_prev));
    572         ymax = std::max(ymax, std::abs(diff_b - diff_prev_row));
    573         if (exposed_b >= 0) {
    574           exposed_b *= fabs(cur_b - prev) + fabs(cur_b - prev_row);
    575           eb = std::max(eb, exposed_b);
    576         }
    577       }
    578     }
    579     exposed_blue = eb;
    580     db = std::max(xmax, ymax);
    581   }
    582   void Calc(const Image3F* JXL_RESTRICT opsin, const Rect& rect) {
    583     dx = CalcPlane(&opsin->Plane(0), rect);
    584     CalcExposedBlue(&opsin->Plane(1), &opsin->Plane(2), rect);
    585   }
    586   int HowMuchIsXChannelPixelized() const {
    587     if (dx >= 0.03) {
    588       return 2;
    589     }
    590     if (dx >= 0.017) {
    591       return 1;
    592     }
    593     return 0;
    594   }
    595   int HowMuchIsBChannelPixelized() const {
    596     int add = exposed_blue >= 0.13 ? 1 : 0;
    597     if (db > 0.38) {
    598       return 2 + add;
    599     }
    600     if (db > 0.33) {
    601       return 1 + add;
    602     }
    603     if (db > 0.28) {
    604       return add;
    605     }
    606     return 0;
    607   }
    608 };
    609 
    610 void ComputeChromacityAdjustments(const CompressParams& cparams,
    611                                   const Image3F& opsin, const Rect& rect,
    612                                   FrameHeader* frame_header) {
    613   if (frame_header->encoding != FrameEncoding::kVarDCT ||
    614       cparams.max_error_mode) {
    615     return;
    616   }
    617   // 1) Distance based approach for chromacity adjustment:
    618   float x_qm_scale_steps[4] = {1.25f, 7.0f, 15.0f, 24.0f};
    619   frame_header->x_qm_scale = 2;
    620   for (float x_qm_scale_step : x_qm_scale_steps) {
    621     if (cparams.original_butteraugli_distance > x_qm_scale_step) {
    622       frame_header->x_qm_scale++;
    623     }
    624   }
    625   if (cparams.butteraugli_distance < 0.299f) {
    626     // Favor chromacity preservation for making images appear more
    627     // faithful to original even with extreme (5-10x) zooming.
    628     frame_header->x_qm_scale++;
    629   }
    630   // 2) Pixel-based approach for chromacity adjustment:
    631   // look at the individual pixels and make a guess how difficult
    632   // the image would be based on the worst case pixel.
    633   PixelStatsForChromacityAdjustment pixel_stats;
    634   if (cparams.speed_tier <= SpeedTier::kSquirrel) {
    635     pixel_stats.Calc(&opsin, rect);
    636   }
    637   // For X take the most severe adjustment.
    638   frame_header->x_qm_scale = std::max<int>(
    639       frame_header->x_qm_scale, 2 + pixel_stats.HowMuchIsXChannelPixelized());
    640   // B only adjusted by pixel-based approach.
    641   frame_header->b_qm_scale = 2 + pixel_stats.HowMuchIsBChannelPixelized();
    642 }
    643 
    644 void ComputeNoiseParams(const CompressParams& cparams, bool streaming_mode,
    645                         bool color_is_jpeg, const Image3F& opsin,
    646                         const FrameDimensions& frame_dim,
    647                         FrameHeader* frame_header, NoiseParams* noise_params) {
    648   if (cparams.photon_noise_iso > 0) {
    649     *noise_params = SimulatePhotonNoise(frame_dim.xsize, frame_dim.ysize,
    650                                         cparams.photon_noise_iso);
    651   } else if (cparams.manual_noise.size() == NoiseParams::kNumNoisePoints) {
    652     for (size_t i = 0; i < NoiseParams::kNumNoisePoints; i++) {
    653       noise_params->lut[i] = cparams.manual_noise[i];
    654     }
    655   } else if (frame_header->encoding == FrameEncoding::kVarDCT &&
    656              frame_header->flags & FrameHeader::kNoise && !color_is_jpeg &&
    657              !streaming_mode) {
    658     // Don't start at zero amplitude since adding noise is expensive -- it
    659     // significantly slows down decoding, and this is unlikely to
    660     // completely go away even with advanced optimizations. After the
    661     // kNoiseModelingRampUpDistanceRange we have reached the full level,
    662     // i.e. noise is no longer represented by the compressed image, so we
    663     // can add full noise by the noise modeling itself.
    664     static const float kNoiseModelingRampUpDistanceRange = 0.6;
    665     static const float kNoiseLevelAtStartOfRampUp = 0.25;
    666     static const float kNoiseRampupStart = 1.0;
    667     // TODO(user) test and properly select quality_coef with smooth
    668     // filter
    669     float quality_coef = 1.0f;
    670     const float rampup = (cparams.butteraugli_distance - kNoiseRampupStart) /
    671                          kNoiseModelingRampUpDistanceRange;
    672     if (rampup < 1.0f) {
    673       quality_coef = kNoiseLevelAtStartOfRampUp +
    674                      (1.0f - kNoiseLevelAtStartOfRampUp) * rampup;
    675     }
    676     if (rampup < 0.0f) {
    677       quality_coef = kNoiseRampupStart;
    678     }
    679     if (!GetNoiseParameter(opsin, noise_params, quality_coef)) {
    680       frame_header->flags &= ~FrameHeader::kNoise;
    681     }
    682   }
    683 }
    684 
    685 Status DownsampleColorChannels(const CompressParams& cparams,
    686                                const FrameHeader& frame_header,
    687                                bool color_is_jpeg, Image3F* opsin) {
    688   if (color_is_jpeg || frame_header.upsampling == 1 ||
    689       cparams.already_downsampled) {
    690     return true;
    691   }
    692   if (frame_header.encoding == FrameEncoding::kVarDCT &&
    693       frame_header.upsampling == 2) {
    694     // TODO(lode): use the regular DownsampleImage, or adapt to the custom
    695     // coefficients, if there is are custom upscaling coefficients in
    696     // CustomTransformData
    697     if (cparams.speed_tier <= SpeedTier::kSquirrel) {
    698       // TODO(lode): DownsampleImage2_Iterative is currently too slow to
    699       // be used for squirrel, make it faster, and / or enable it only for
    700       // kitten.
    701       JXL_RETURN_IF_ERROR(DownsampleImage2_Iterative(opsin));
    702     } else {
    703       JXL_RETURN_IF_ERROR(DownsampleImage2_Sharper(opsin));
    704     }
    705   } else {
    706     JXL_ASSIGN_OR_RETURN(*opsin,
    707                          DownsampleImage(*opsin, frame_header.upsampling));
    708   }
    709   if (frame_header.encoding == FrameEncoding::kVarDCT) {
    710     PadImageToBlockMultipleInPlace(opsin);
    711   }
    712   return true;
    713 }
    714 
    715 template <typename V, typename R>
    716 void FindIndexOfSumMaximum(const V* array, const size_t len, R* idx, V* sum) {
    717   JXL_ASSERT(len > 0);
    718   V maxval = 0;
    719   V val = 0;
    720   R maxidx = 0;
    721   for (size_t i = 0; i < len; ++i) {
    722     val += array[i];
    723     if (val > maxval) {
    724       maxval = val;
    725       maxidx = i;
    726     }
    727   }
    728   *idx = maxidx;
    729   *sum = maxval;
    730 }
    731 
    // Fills the encoder state for JPEG recompression straight from the quantized
    // DCT coefficients stored in `jpeg_data`.
    //
    // In order, this function:
    //   * installs a no-op color correlation map, an all-DCT8 strategy and a
    //     zeroed EPF sharpness map;
    //   * allocates one coefficient image per pass;
    //   * turns the JPEG quantization tables into custom dequant matrices and a
    //     quantizer whose global scale is 1;
    //   * optionally (4:4:4, three components, force_cfl_jpeg_recompression)
    //     picks a chroma-from-luma factor per color tile;
    //   * copies DC values into a float image and AC blocks (transposed, and
    //     with the CfL prediction subtracted where enabled) into the per-pass
    //     coefficient images;
    //   * derives DC thresholds and the block context map from the DC histogram;
    //   * hands DC and AC metadata of every DC group to `enc_modular`.
    //
    // Returns an error when an allocation, the dequant matrix setup, or one of
    // the thread-pool tasks fails.
    732 Status ComputeJPEGTranscodingData(const jpeg::JPEGData& jpeg_data,
    733                                   const FrameHeader& frame_header,
    734                                   ThreadPool* pool,
    735                                   ModularFrameEncoder* enc_modular,
    736                                   PassesEncoderState* enc_state) {
    737   PassesSharedState& shared = enc_state->shared;
    738   const FrameDimensions& frame_dim = shared.frame_dim;
    739 
    740   const size_t xsize = frame_dim.xsize_padded;
    741   const size_t ysize = frame_dim.ysize_padded;
    742   const size_t xsize_blocks = frame_dim.xsize_blocks;
    743   const size_t ysize_blocks = frame_dim.ysize_blocks;
    744 
    745   // no-op chroma from luma
    746   JXL_ASSIGN_OR_RETURN(shared.cmap,
    747                        ColorCorrelationMap::Create(xsize, ysize, false));
    748   shared.ac_strategy.FillDCT8();
    749   FillImage(static_cast<uint8_t>(0), &shared.epf_sharpness);
    750 
              // Start from scratch and allocate one coefficient image per pass.
    751   enc_state->coeffs.clear();
    752   while (enc_state->coeffs.size() < enc_state->passes.size()) {
    753     JXL_ASSIGN_OR_RETURN(
    754         std::unique_ptr<ACImageT<int32_t>> coeffs,
    755         ACImageT<int32_t>::Make(kGroupDim * kGroupDim, frame_dim.num_groups));
    756     enc_state->coeffs.emplace_back(std::move(coeffs));
    757   }
    758 
    759   // convert JPEG quantization table to a Quantizer object
    760   float dcquantization[3];
    761   std::vector<QuantEncoding> qe(DequantMatrices::kNum,
    762                                 QuantEncoding::Library(0));
    763 
              // Maps the channel index used here (0..2) to the JPEG component index.
    764   auto jpeg_c_map =
    765       JpegOrder(frame_header.color_transform, jpeg_data.components.size() == 1);
    766 
              // Three 64-entry tables back to back, one per channel.
    767   std::vector<int> qt(192);
    768   for (size_t c = 0; c < 3; c++) {
    769     size_t jpeg_c = jpeg_c_map[c];
    770     const int32_t* quant =
    771         jpeg_data.quant[jpeg_data.components[jpeg_c].quant_idx].values.data();
    772 
                // DC factor derived from the first (DC) entry of the table.
    773     dcquantization[c] = 255 * 8.0f / quant[0];
    774     for (size_t y = 0; y < 8; y++) {
    775       for (size_t x = 0; x < 8; x++) {
    776         // JPEG XL transposes the DCT, JPEG doesn't.
    777         qt[c * 64 + 8 * x + y] = quant[8 * y + x];
    778       }
    779     }
    780   }
    781   DequantMatricesSetCustomDC(&shared.matrices, dcquantization);
              // Reciprocals, used below to scale integer DC values into the DC image.
    782   float dcquantization_r[3] = {1.0f / dcquantization[0],
    783                                1.0f / dcquantization[1],
    784                                1.0f / dcquantization[2]};
    785 
    786   qe[AcStrategy::Type::DCT] = QuantEncoding::RAW(qt);
    787   JXL_RETURN_IF_ERROR(
    788       DequantMatricesSetCustom(&shared.matrices, qe, enc_modular));
    789 
    790   // Ensure that InvGlobalScale() is 1.
    791   shared.quantizer = Quantizer(&shared.matrices, 1, kGlobalScaleDenom);
    792   // Recompute MulDC() and InvMulDC().
    793   shared.quantizer.RecomputeFromGlobalScale();
    794 
    795   // Per-block dequant scaling should be 1.
    796   FillImage(static_cast<int32_t>(shared.quantizer.InvGlobalScale()),
    797             &shared.raw_quant_field);
    798 
              // Fixed-point (kCFLFixedPointPrecision fractional bits) ratio between the
              // quant table of channel 1 and the table of channel c, per coefficient.
              // Used to bring channel-1 coefficients onto channel c's quantization scale.
    799   std::vector<int32_t> scaled_qtable(192);
    800   for (size_t c = 0; c < 3; c++) {
    801     for (size_t i = 0; i < 64; i++) {
    802       scaled_qtable[64 * c + i] =
    803           (1 << kCFLFixedPointPrecision) * qt[64 + i] / qt[64 * c + i];
    804     }
    805   }
    806 
              // Pointer to the first coefficient of block row `y` of channel `c`
              // (channel index is translated through jpeg_c_map).
    807   auto jpeg_row = [&](size_t c, size_t y) {
    808     return jpeg_data.components[jpeg_c_map[c]].coeffs.data() +
    809            jpeg_data.components[jpeg_c_map[c]].width_in_blocks * kDCTBlockSize *
    810                y;
    811   };
    812 
              // When false, a 1024 / qt offset is added to every DC value further below.
    813   bool DCzero = (frame_header.color_transform == ColorTransform::kYCbCr);
    814   // Compute chroma-from-luma for AC (doesn't seem to be useful for DC)
    815   if (frame_header.chroma_subsampling.Is444() &&
    816       enc_state->cparams.force_cfl_jpeg_recompression &&
    817       jpeg_data.components.size() == 3) {
    818     for (size_t c : {0, 2}) {
    819       ImageSB* map = (c == 0 ? &shared.cmap.ytox_map : &shared.cmap.ytob_map);
    820       const float kScale = kDefaultColorFactor;
    821       const int kOffset = 127;
    822       const float kBase =
    823           c == 0 ? shared.cmap.YtoXRatio(0) : shared.cmap.YtoBRatio(0);
    824       const float kZeroThresh =
    825           kScale * kZeroBiasDefault[c] *
    826           0.9999f;  // just epsilon less for better rounding
    827 
                  // One task per row of the correlation map; every map entry covers a tile
                  // of kColorTileDimInBlocks x kColorTileDimInBlocks blocks.
    828       auto process_row = [&](const uint32_t task, const size_t thread) {
    829         size_t ty = task;
    830         int8_t* JXL_RESTRICT row_out = map->Row(ty);
    831         for (size_t tx = 0; tx < map->xsize(); ++tx) {
    832           const size_t y0 = ty * kColorTileDimInBlocks;
    833           const size_t x0 = tx * kColorTileDimInBlocks;
    834           const size_t y1 = std::min(frame_dim.ysize_blocks,
    835                                      (ty + 1) * kColorTileDimInBlocks);
    836           const size_t x1 = std::min(frame_dim.xsize_blocks,
    837                                      (tx + 1) * kColorTileDimInBlocks);
                      // Difference array over candidate indices 0..255 (+1 at the start of
                      // a [from, to] range, -1 just past its end). Index 256 is needed
                      // because floor(to + 1) reaches 256 when `to` is clamped to 255.
                      // Presumably each range is the set of candidates for which the
                      // chroma residual of that coefficient falls inside the zero
                      // threshold -- inferred from the names, confirm.
    838           int32_t d_num_zeros[257] = {0};
    839           // TODO(veluca): this needs SIMD + fixed point adaptation, and/or
    840           // conversion to the new CfL algorithm.
    841           for (size_t y = y0; y < y1; ++y) {
    842             const int16_t* JXL_RESTRICT row_m = jpeg_row(1, y);
    843             const int16_t* JXL_RESTRICT row_s = jpeg_row(c, y);
    844             for (size_t x = x0; x < x1; ++x) {
                          // AC coefficients only (coeffpos 0 is the DC).
    845               for (size_t coeffpos = 1; coeffpos < kDCTBlockSize; coeffpos++) {
    846                 const float scaled_m = row_m[x * kDCTBlockSize + coeffpos] *
    847                                        scaled_qtable[64 * c + coeffpos] *
    848                                        (1.0f / (1 << kCFLFixedPointPrecision));
    849                 const float scaled_s =
    850                     kScale * row_s[x * kDCTBlockSize + coeffpos] +
    851                     (kOffset - kBase * kScale) * scaled_m;
                            // A (near) zero channel-1 coefficient gives no constraint.
    852                 if (std::abs(scaled_m) > 1e-8f) {
    853                   float from;
    854                   float to;
                              // Dividing by a negative scaled_m flips the interval ends.
    855                   if (scaled_m > 0) {
    856                     from = (scaled_s - kZeroThresh) / scaled_m;
    857                     to = (scaled_s + kZeroThresh) / scaled_m;
    858                   } else {
    859                     from = (scaled_s + kZeroThresh) / scaled_m;
    860                     to = (scaled_s - kZeroThresh) / scaled_m;
    861                   }
    862                   if (from < 0.0f) {
    863                     from = 0.0f;
    864                   }
    865                   if (to > 255.0f) {
    866                     to = 255.0f;
    867                   }
    868                   // Instead of clamping the both values
    869                   // we just check that range is sane.
    870                   if (from <= to) {
    871                     d_num_zeros[static_cast<int>(std::ceil(from))]++;
    872                     d_num_zeros[static_cast<int>(std::floor(to + 1))]--;
    873                   }
    874                 }
    875               }
    876             }
    877           }
                      // Largest prefix sum of the difference array = the candidate index
                      // covered by the most ranges.
    878           int best = 0;
    879           int32_t best_sum = 0;
    880           FindIndexOfSumMaximum(d_num_zeros, 256, &best, &best_sum);
                      // Prefix sum at index kOffset, i.e. the count for a stored value of 0.
    881           int32_t offset_sum = 0;
    882           for (int i = 0; i < 256; ++i) {
    883             if (i <= kOffset) {
    884               offset_sum += d_num_zeros[i];
    885             }
    886           }
                      // Keep 0 unless the best candidate beats it by more than one.
    887           row_out[tx] = 0;
    888           if (best_sum > offset_sum + 1) {
    889             row_out[tx] = best - kOffset;
    890           }
    891         }
    892       };
    893 
    894       JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, map->ysize(), ThreadPool::NoInit,
    895                                     process_row, "FindCorrelation"));
    896     }
    897   }
    898 
    899   JXL_ASSIGN_OR_RETURN(Image3F dc, Image3F::Create(xsize_blocks, ysize_blocks));
              // With subsampled chroma not every position gets written below, so start
              // from zero everywhere.
    900   if (!frame_header.chroma_subsampling.Is444()) {
    901     ZeroFillImage(&dc);
    902     for (auto& coeff : enc_state->coeffs) {
    903       coeff->ZeroFill();
    904     }
    905   }
    906   // JPEG DC is from -1024 to 1023.
              // Per-channel histogram with 2048 bins; bin index is the DC value + 1024.
    907   std::vector<size_t> dc_counts[3] = {};
    908   dc_counts[0].resize(2048);
    909   dc_counts[1].resize(2048);
    910   dc_counts[2].resize(2048);
    911   size_t total_dc[3] = {};
    912   for (size_t c : {1, 0, 2}) {
                // Single-component JPEG: only channel 1 carries data, the other two are
                // zero-filled and given a one-sample histogram.
    913     if (jpeg_data.components.size() == 1 && c != 1) {
    914       for (auto& coeff : enc_state->coeffs) {
    915         coeff->ZeroFillPlane(c);
    916       }
    917       ZeroFillImage(&dc.Plane(c));
    918       // Ensure no division by 0.
    919       dc_counts[c][1024] = 1;
    920       total_dc[c] = 1;
    921       continue;
    922     }
    923     size_t hshift = frame_header.chroma_subsampling.HShift(c);
    924     size_t vshift = frame_header.chroma_subsampling.VShift(c);
    925     ImageSB& map = (c == 0 ? shared.cmap.ytox_map : shared.cmap.ytob_map);
    926     for (size_t group_index = 0; group_index < frame_dim.num_groups;
    927          group_index++) {
    928       const size_t gx = group_index % frame_dim.xsize_groups;
    929       const size_t gy = group_index / frame_dim.xsize_groups;
                  // Per-pass write cursors into this group's coefficients of channel c;
                  // advanced by one block after every block written.
    930       int32_t* coeffs[kMaxNumPasses];
    931       for (size_t i = 0; i < enc_state->coeffs.size(); i++) {
    932         coeffs[i] = enc_state->coeffs[i]->PlaneRow(c, group_index, 0).ptr32;
    933       }
    934       int32_t block[64];
    935       for (size_t by = gy * kGroupDimInBlocks;
    936            by < ysize_blocks && by < (gy + 1) * kGroupDimInBlocks; ++by) {
                    // Skip block rows that are not a multiple of the vertical subsampling.
    937         if ((by >> vshift) << vshift != by) continue;
    938         const int16_t* JXL_RESTRICT inputjpeg = jpeg_row(c, by >> vshift);
    939         const int16_t* JXL_RESTRICT inputjpegY = jpeg_row(1, by);
    940         float* JXL_RESTRICT fdc = dc.PlaneRow(c, by >> vshift);
    941         const int8_t* JXL_RESTRICT cm =
    942             map.ConstRow(by / kColorTileDimInBlocks);
    943         for (size_t bx = gx * kGroupDimInBlocks;
    944              bx < xsize_blocks && bx < (gx + 1) * kGroupDimInBlocks; ++bx) {
                      // Same for block columns and the horizontal subsampling.
    945           if ((bx >> hshift) << hshift != bx) continue;
    946           size_t base = (bx >> hshift) * kDCTBlockSize;
    947           int idc;
    948           if (DCzero) {
    949             idc = inputjpeg[base];
    950           } else {
    951             idc = inputjpeg[base] + 1024 / qt[c * 64];
    952           }
                      // Clamp the bin to 2047; a negative idc + 1024 wraps to a huge
                      // uint32_t and therefore also lands in the last bin.
    953           dc_counts[c][std::min(static_cast<uint32_t>(idc + 1024),
    954                                 static_cast<uint32_t>(2047))]++;
    955           total_dc[c]++;
    956           fdc[bx >> hshift] = idc * dcquantization_r[c];
    957           if (c == 1 || !enc_state->cparams.force_cfl_jpeg_recompression ||
    958               !frame_header.chroma_subsampling.Is444()) {
                        // No CfL: plain copy, transposing the 8x8 block.
    959             for (size_t y = 0; y < 8; y++) {
    960               for (size_t x = 0; x < 8; x++) {
    961                 block[y * 8 + x] = inputjpeg[base + x * 8 + y];
    962               }
    963             }
    964           } else {
                        // CfL: store the chroma coefficient minus its prediction from the
                        // channel-1 coefficient. This branch only runs for 4:4:4, so
                        // indexing with bx instead of `base` addresses the same block.
    965             const int32_t scale =
    966                 shared.cmap.RatioJPEG(cm[bx / kColorTileDimInBlocks]);
    967 
    968             for (size_t y = 0; y < 8; y++) {
    969               for (size_t x = 0; x < 8; x++) {
    970                 int Y = inputjpegY[kDCTBlockSize * bx + x * 8 + y];
    971                 int QChroma = inputjpeg[kDCTBlockSize * bx + x * 8 + y];
    972                 // Fixed-point multiply of CfL scale with quant table ratio
    973                 // first, and Y value second.
    974                 int coeff_scale = (scale * scaled_qtable[64 * c + y * 8 + x] +
    975                                    (1 << (kCFLFixedPointPrecision - 1))) >>
    976                                   kCFLFixedPointPrecision;
    977                 int cfl_factor =
    978                     (Y * coeff_scale + (1 << (kCFLFixedPointPrecision - 1))) >>
    979                     kCFLFixedPointPrecision;
    980                 int QCR = QChroma - cfl_factor;
    981                 block[y * 8 + x] = QCR;
    982               }
    983             }
    984           }
    985           enc_state->progressive_splitter.SplitACCoefficients(
    986               block, AcStrategy::FromRawStrategy(AcStrategy::Type::DCT), bx, by,
    987               coeffs);
    988           for (size_t i = 0; i < enc_state->coeffs.size(); i++) {
    989             coeffs[i] += kDCTBlockSize;
    990           }
    991         }
    992       }
    993     }
    994   }
    995 
              // Derive up to two DC thresholds per channel from the DC histograms; the
              // number of DC contexts is the product of the per-channel bucket counts.
    996   auto& dct = enc_state->shared.block_ctx_map.dc_thresholds;
    997   auto& num_dc_ctxs = enc_state->shared.block_ctx_map.num_dc_ctxs;
    998   num_dc_ctxs = 1;
    999   for (size_t i = 0; i < 3; i++) {
   1000     dct[i].clear();
   1001     int num_thresholds = (CeilLog2Nonzero(total_dc[i]) - 12) / 2;
   1002     // up to 3 buckets per channel:
   1003     // dark/medium/bright, yellow/unsat/blue, green/unsat/red
   1004     num_thresholds = std::min(std::max(num_thresholds, 0), 2);
                // Walk the cumulative histogram and emit a threshold each time it passes
                // the next of the (num_thresholds + 1) equal-population cut points.
   1005     size_t cumsum = 0;
   1006     size_t cut = total_dc[i] / (num_thresholds + 1);
   1007     for (int j = 0; j < 2048; j++) {
   1008       cumsum += dc_counts[i][j];
   1009       if (cumsum > cut) {
   1010         dct[i].push_back(j - 1025);
   1011         cut = total_dc[i] * (dct[i].size() + 1) / (num_thresholds + 1);
   1012       }
   1013     }
   1014     num_dc_ctxs *= dct[i].size() + 1;
   1015   }
   1016 
              // Context map: three sections of kNumOrders * num_dc_ctxs entries, all
              // zero except the first num_dc_ctxs entries of each section set below.
   1017   auto& ctx_map = enc_state->shared.block_ctx_map.ctx_map;
   1018   ctx_map.clear();
   1019   ctx_map.resize(3 * kNumOrders * num_dc_ctxs, 0);
   1020 
   1021   int lbuckets = (dct[1].size() + 1);
   1022   for (size_t i = 0; i < num_dc_ctxs; i++) {
   1023     // up to 9 contexts for luma
   1024     ctx_map[i] = i / lbuckets;
   1025     // up to 3 contexts for chroma
   1026     ctx_map[kNumOrders * num_dc_ctxs + i] =
   1027         ctx_map[2 * kNumOrders * num_dc_ctxs + i] =
   1028             num_dc_ctxs / lbuckets + (i % lbuckets);
   1029   }
   1030   enc_state->shared.block_ctx_map.num_ctxs =
   1031       *std::max_element(ctx_map.begin(), ctx_map.end()) + 1;
   1032 
   1033   // disable DC frame for now
              // Per DC group: pass the DC image and the AC metadata to the modular
              // encoder. The first failure sets has_error, which makes tasks that start
              // afterwards return immediately.
   1034   std::atomic<bool> has_error{false};
   1035   auto compute_dc_coeffs = [&](const uint32_t group_index,
   1036                                size_t /* thread */) {
   1037     if (has_error) return;
   1038     const Rect r = enc_state->shared.frame_dim.DCGroupRect(group_index);
   1039     if (!enc_modular->AddVarDCTDC(frame_header, dc, r, group_index,
   1040                                   /*nl_dc=*/false, enc_state,
   1041                                   /*jpeg_transcode=*/true)) {
   1042       has_error = true;
   1043       return;
   1044     }
   1045     if (!enc_modular->AddACMetadata(r, group_index, /*jpeg_transcode=*/true,
   1046                                     enc_state)) {
   1047       has_error = true;
   1048       return;
   1049     }
   1050   };
   1051   JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, shared.frame_dim.num_dc_groups,
   1052                                 ThreadPool::NoInit, compute_dc_coeffs,
   1053                                 "Compute DC coeffs"));
   1054   if (has_error) return JXL_FAILURE("Compute DC coeffs failed");
   1055 
   1056   return true;
   1057 }
   1058 
   // Runs the two VarDCT analysis stages for the area `rect` of the frame:
   // LossyFrameHeuristics first, then InitializePassesEncoder on the resulting
   // `opsin` image.
   //
   // `rect` must have a width and height that are multiples of kBlockDim
   // (asserted). Returns the error of the first stage that fails, true otherwise.
   1059 Status ComputeVarDCTEncodingData(const FrameHeader& frame_header,
   1060                                  const Image3F* linear,
   1061                                  Image3F* JXL_RESTRICT opsin, const Rect& rect,
   1062                                  const JxlCmsInterface& cms, ThreadPool* pool,
   1063                                  ModularFrameEncoder* enc_modular,
   1064                                  PassesEncoderState* enc_state,
   1065                                  AuxOut* aux_out) {
   1066   JXL_ASSERT((rect.xsize() % kBlockDim) == 0 &&
   1067              (rect.ysize() % kBlockDim) == 0);
   1068   JXL_RETURN_IF_ERROR(LossyFrameHeuristics(frame_header, enc_state, enc_modular,
   1069                                            linear, opsin, rect, cms, pool,
   1070                                            aux_out));
   1071 
   1072   JXL_RETURN_IF_ERROR(InitializePassesEncoder(
   1073       frame_header, *opsin, rect, cms, pool, enc_state, enc_modular, aux_out));
   1074   return true;
   1075 }
   1076 
   1077 void ComputeAllCoeffOrders(PassesEncoderState& enc_state,
   1078                            const FrameDimensions& frame_dim) {
   1079   auto used_orders_info = ComputeUsedOrders(
   1080       enc_state.cparams.speed_tier, enc_state.shared.ac_strategy,
   1081       Rect(enc_state.shared.raw_quant_field));
   1082   enc_state.used_orders.resize(enc_state.progressive_splitter.GetNumPasses());
   1083   for (size_t i = 0; i < enc_state.progressive_splitter.GetNumPasses(); i++) {
   1084     ComputeCoeffOrder(
   1085         enc_state.cparams.speed_tier, *enc_state.coeffs[i],
   1086         enc_state.shared.ac_strategy, frame_dim, enc_state.used_orders[i],
   1087         enc_state.used_acs, used_orders_info.first, used_orders_info.second,
   1088         &enc_state.shared.coeff_orders[i * enc_state.shared.coeff_order_size]);
   1089   }
   1090   enc_state.used_acs |= used_orders_info.first;
   1091 }
   1092 
   1093 // Working area for TokenizeCoefficients (per-group!)
   1094 struct EncCache {
   1095   // Allocates memory when first called.
   1096   Status InitOnce() {
   1097     if (num_nzeroes.xsize() == 0) {
   1098       JXL_ASSIGN_OR_RETURN(
   1099           num_nzeroes, Image3I::Create(kGroupDimInBlocks, kGroupDimInBlocks));
   1100     }
   1101     return true;
   1102   }
   1103   // TokenizeCoefficients
   1104   Image3I num_nzeroes;
   1105 };
   1106 
   1107 Status TokenizeAllCoefficients(const FrameHeader& frame_header,
   1108                                ThreadPool* pool,
   1109                                PassesEncoderState* enc_state) {
   1110   PassesSharedState& shared = enc_state->shared;
   1111   std::vector<EncCache> group_caches;
   1112   const auto tokenize_group_init = [&](const size_t num_threads) {
   1113     group_caches.resize(num_threads);
   1114     return true;
   1115   };
   1116   std::atomic<bool> has_error{false};
   1117   const auto tokenize_group = [&](const uint32_t group_index,
   1118                                   const size_t thread) {
   1119     if (has_error) return;
   1120     // Tokenize coefficients.
   1121     const Rect rect = shared.frame_dim.BlockGroupRect(group_index);
   1122     for (size_t idx_pass = 0; idx_pass < enc_state->passes.size(); idx_pass++) {
   1123       JXL_ASSERT(enc_state->coeffs[idx_pass]->Type() == ACType::k32);
   1124       const int32_t* JXL_RESTRICT ac_rows[3] = {
   1125           enc_state->coeffs[idx_pass]->PlaneRow(0, group_index, 0).ptr32,
   1126           enc_state->coeffs[idx_pass]->PlaneRow(1, group_index, 0).ptr32,
   1127           enc_state->coeffs[idx_pass]->PlaneRow(2, group_index, 0).ptr32,
   1128       };
   1129       // Ensure group cache is initialized.
   1130       if (!group_caches[thread].InitOnce()) {
   1131         has_error = true;
   1132         return;
   1133       }
   1134       TokenizeCoefficients(
   1135           &shared.coeff_orders[idx_pass * shared.coeff_order_size], rect,
   1136           ac_rows, shared.ac_strategy, frame_header.chroma_subsampling,
   1137           &group_caches[thread].num_nzeroes,
   1138           &enc_state->passes[idx_pass].ac_tokens[group_index], shared.quant_dc,
   1139           shared.raw_quant_field, shared.block_ctx_map);
   1140     }
   1141   };
   1142   JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, shared.frame_dim.num_groups,
   1143                                 tokenize_group_init, tokenize_group,
   1144                                 "TokenizeGroup"));
   1145   if (has_error) return JXL_FAILURE("TokenizeGroup failed");
   1146   return true;
   1147 }
   1148 
   // Writes the global DC information to `writer`, in this order: the quantizer
   // parameters, the block context map, and the DC part of the color correlation
   // map. `aux_out` is passed through to each writer call.
   //
   // Only the status of WriteQuantizerParams is checked here; the other two
   // calls are made as plain statements.
   1149 Status EncodeGlobalDCInfo(const PassesSharedState& shared, BitWriter* writer,
   1150                           AuxOut* aux_out) {
   1151   // Encode quantizer DC and global scale.
   1152   QuantizerParams params = shared.quantizer.GetParams();
   1153   JXL_RETURN_IF_ERROR(
   1154       WriteQuantizerParams(params, writer, kLayerQuant, aux_out));
   1155   EncodeBlockCtxMap(shared.block_ctx_map, writer, aux_out);
   1156   ColorCorrelationMapEncodeDC(shared.cmap, writer, kLayerDC, aux_out);
   1157   return true;
   1158 }
   1159 
   1160 // In streaming mode, this function only performs the histogram clustering and
   1161 // saves the histogram bitstreams in enc_state, the actual AC global bitstream
   1162 // is written in OutputAcGlobal() function after all the groups are processed.
           //
           // Steps performed here: encode the dequantization matrices; outside
           // streaming mode write the number of histograms and, per pass, the used
           // coefficient orders; then, per pass, build and encode the entropy-coding
           // histograms from that pass's AC tokens.
           // Returns an error if encoding the dequant matrices fails or if a pass's
           // used-orders value cannot be encoded with kOrderEnc.
   1163 Status EncodeGlobalACInfo(PassesEncoderState* enc_state, BitWriter* writer,
   1164                           ModularFrameEncoder* enc_modular, AuxOut* aux_out) {
   1165   PassesSharedState& shared = enc_state->shared;
   1166   JXL_RETURN_IF_ERROR(DequantMatricesEncode(shared.matrices, writer,
   1167                                             kLayerQuant, aux_out, enc_modular));
              // Width of the field holding num_histograms - 1; the field is omitted
              // when the width is zero or when in streaming mode.
   1168   size_t num_histo_bits = CeilLog2Nonzero(shared.frame_dim.num_groups);
   1169   if (!enc_state->streaming_mode && num_histo_bits != 0) {
                // The allotment is sized to exactly the bits written next and is
                // reclaimed right after the write.
   1170     BitWriter::Allotment allotment(writer, num_histo_bits);
   1171     writer->Write(num_histo_bits, shared.num_histograms - 1);
   1172     allotment.ReclaimAndCharge(writer, kLayerAC, aux_out);
   1173   }
   1174 
   1175   for (size_t i = 0; i < enc_state->progressive_splitter.GetNumPasses(); i++) {
   1176     // Encode coefficient orders.
   1177     if (!enc_state->streaming_mode) {
   1178       size_t order_bits = 0;
   1179       JXL_RETURN_IF_ERROR(U32Coder::CanEncode(
   1180           kOrderEnc, enc_state->used_orders[i], &order_bits));
   1181       BitWriter::Allotment allotment(writer, order_bits);
                  // CanEncode succeeded above, so this write is expected to succeed;
                  // NOTE(review): JXL_CHECK presumably aborts instead of returning an
                  // error -- confirm.
   1182       JXL_CHECK(U32Coder::Write(kOrderEnc, enc_state->used_orders[i], writer));
   1183       allotment.ReclaimAndCharge(writer, kLayerOrder, aux_out);
   1184       EncodeCoeffOrders(enc_state->used_orders[i],
   1185                         &shared.coeff_orders[i * shared.coeff_order_size],
   1186                         writer, kLayerOrder, aux_out);
   1187     }
   1188 
   1189     // Encode histograms.
   1190     HistogramParams hist_params(enc_state->cparams.speed_tier,
   1191                                 shared.block_ctx_map.NumACContexts());
                // LZ77 is switched off for every speed tier above kTortoise.
   1192     if (enc_state->cparams.speed_tier > SpeedTier::kTortoise) {
   1193       hist_params.lz77_method = HistogramParams::LZ77Method::kNone;
   1194     }
                // Any non-zero decoding speed tier caps the number of histograms at 6.
   1195     if (enc_state->cparams.decoding_speed_tier >= 1) {
   1196       hist_params.max_histograms = 6;
   1197     }
   1198     size_t num_histogram_groups = shared.num_histograms;
   1199     if (enc_state->streaming_mode) {
                  // Histograms already stored for this pass count against
                  // kClustersLimit; when initializing global state the fixed histograms
                  // are counted as well.
   1200       size_t prev_num_histograms =
   1201           enc_state->passes[i].codes.encoding_info.size();
   1202       if (enc_state->initialize_global_state) {
   1203         prev_num_histograms += kNumFixedHistograms;
   1204         hist_params.add_fixed_histograms = true;
   1205       }
                  // NOTE(review): unsigned subtraction; assumes prev_num_histograms never
                  // exceeds kClustersLimit -- confirm.
   1206       size_t remaining_histograms = kClustersLimit - prev_num_histograms;
   1207       // Heuristic to assign budget of new histograms to DC groups.
   1208       // TODO(szabadka) Tune this together with the DC group ordering.
   1209       size_t max_histograms = remaining_histograms < 20
   1210                                   ? std::min<size_t>(remaining_histograms, 4)
   1211                                   : remaining_histograms / 4;
   1212       hist_params.max_histograms =
   1213           std::min(max_histograms, hist_params.max_histograms);
                  // In streaming mode a single histogram group is used per call.
   1214       num_histogram_groups = 1;
   1215     }
   1216     hist_params.streaming_mode = enc_state->streaming_mode;
   1217     hist_params.initialize_global_state = enc_state->initialize_global_state;
   1218     BuildAndEncodeHistograms(
   1219         hist_params,
   1220         num_histogram_groups * shared.block_ctx_map.NumACContexts(),
   1221         enc_state->passes[i].ac_tokens, &enc_state->passes[i].codes,
   1222         &enc_state->passes[i].context_map, writer, kLayerAC, aux_out);
   1223   }
   1224 
   1225   return true;
   1226 }
   1227 
   1228 Status EncodeGroups(const FrameHeader& frame_header,
   1229                     PassesEncoderState* enc_state,
   1230                     ModularFrameEncoder* enc_modular, ThreadPool* pool,
   1231                     std::vector<BitWriter>* group_codes, AuxOut* aux_out) {
   1232   const PassesSharedState& shared = enc_state->shared;
   1233   const FrameDimensions& frame_dim = shared.frame_dim;
   1234   const size_t num_groups = frame_dim.num_groups;
   1235   const size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
   1236   const size_t global_ac_index = frame_dim.num_dc_groups + 1;
   1237   const bool is_small_image =
   1238       !enc_state->streaming_mode && num_groups == 1 && num_passes == 1;
   1239   const size_t num_toc_entries =
   1240       is_small_image ? 1
   1241                      : AcGroupIndex(0, 0, num_groups, frame_dim.num_dc_groups) +
   1242                            num_groups * num_passes;
   1243   group_codes->resize(num_toc_entries);
   1244 
   1245   const auto get_output = [&](const size_t index) {
   1246     return &(*group_codes)[is_small_image ? 0 : index];
   1247   };
   1248   auto ac_group_code = [&](size_t pass, size_t group) {
   1249     return get_output(AcGroupIndex(pass, group, frame_dim.num_groups,
   1250                                    frame_dim.num_dc_groups));
   1251   };
   1252 
   1253   if (enc_state->initialize_global_state) {
   1254     if (frame_header.flags & FrameHeader::kPatches) {
   1255       PatchDictionaryEncoder::Encode(shared.image_features.patches,
   1256                                      get_output(0), kLayerDictionary, aux_out);
   1257     }
   1258     if (frame_header.flags & FrameHeader::kSplines) {
   1259       EncodeSplines(shared.image_features.splines, get_output(0), kLayerSplines,
   1260                     HistogramParams(), aux_out);
   1261     }
   1262     if (frame_header.flags & FrameHeader::kNoise) {
   1263       EncodeNoise(shared.image_features.noise_params, get_output(0),
   1264                   kLayerNoise, aux_out);
   1265     }
   1266 
   1267     JXL_RETURN_IF_ERROR(DequantMatricesEncodeDC(shared.matrices, get_output(0),
   1268                                                 kLayerQuant, aux_out));
   1269     if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1270       JXL_RETURN_IF_ERROR(EncodeGlobalDCInfo(shared, get_output(0), aux_out));
   1271     }
   1272     JXL_RETURN_IF_ERROR(enc_modular->EncodeGlobalInfo(enc_state->streaming_mode,
   1273                                                       get_output(0), aux_out));
   1274     JXL_RETURN_IF_ERROR(enc_modular->EncodeStream(get_output(0), aux_out,
   1275                                                   kLayerModularGlobal,
   1276                                                   ModularStreamId::Global()));
   1277   }
   1278 
   1279   std::vector<std::unique_ptr<AuxOut>> aux_outs;
   1280   auto resize_aux_outs = [&aux_outs,
   1281                           aux_out](const size_t num_threads) -> Status {
   1282     if (aux_out == nullptr) {
   1283       aux_outs.resize(num_threads);
   1284     } else {
   1285       while (aux_outs.size() > num_threads) {
   1286         aux_out->Assimilate(*aux_outs.back());
   1287         aux_outs.pop_back();
   1288       }
   1289       while (num_threads > aux_outs.size()) {
   1290         aux_outs.emplace_back(jxl::make_unique<AuxOut>());
   1291       }
   1292     }
   1293     return true;
   1294   };
   1295 
   1296   const auto process_dc_group = [&](const uint32_t group_index,
   1297                                     const size_t thread) {
   1298     AuxOut* my_aux_out = aux_outs[thread].get();
   1299     BitWriter* output = get_output(group_index + 1);
   1300     int modular_group_index = group_index;
   1301     if (enc_state->streaming_mode) {
   1302       JXL_ASSERT(group_index == 0);
   1303       modular_group_index = enc_state->dc_group_index;
   1304     }
   1305     if (frame_header.encoding == FrameEncoding::kVarDCT &&
   1306         !(frame_header.flags & FrameHeader::kUseDcFrame)) {
   1307       BitWriter::Allotment allotment(output, 2);
   1308       output->Write(2, enc_modular->extra_dc_precision[modular_group_index]);
   1309       allotment.ReclaimAndCharge(output, kLayerDC, my_aux_out);
   1310       JXL_CHECK(enc_modular->EncodeStream(
   1311           output, my_aux_out, kLayerDC,
   1312           ModularStreamId::VarDCTDC(modular_group_index)));
   1313     }
   1314     JXL_CHECK(enc_modular->EncodeStream(
   1315         output, my_aux_out, kLayerModularDcGroup,
   1316         ModularStreamId::ModularDC(modular_group_index)));
   1317     if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1318       const Rect& rect = enc_state->shared.frame_dim.DCGroupRect(group_index);
   1319       size_t nb_bits = CeilLog2Nonzero(rect.xsize() * rect.ysize());
   1320       if (nb_bits != 0) {
   1321         BitWriter::Allotment allotment(output, nb_bits);
   1322         output->Write(nb_bits,
   1323                       enc_modular->ac_metadata_size[modular_group_index] - 1);
   1324         allotment.ReclaimAndCharge(output, kLayerControlFields, my_aux_out);
   1325       }
   1326       JXL_CHECK(enc_modular->EncodeStream(
   1327           output, my_aux_out, kLayerControlFields,
   1328           ModularStreamId::ACMetadata(modular_group_index)));
   1329     }
   1330   };
   1331   JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, frame_dim.num_dc_groups,
   1332                                 resize_aux_outs, process_dc_group,
   1333                                 "EncodeDCGroup"));
   1334 
   1335   if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1336     JXL_RETURN_IF_ERROR(EncodeGlobalACInfo(
   1337         enc_state, get_output(global_ac_index), enc_modular, aux_out));
   1338   }
   1339 
   1340   std::atomic<bool> has_error{false};
   1341   const auto process_group = [&](const uint32_t group_index,
   1342                                  const size_t thread) {
   1343     if (has_error) return;
   1344     AuxOut* my_aux_out = aux_outs[thread].get();
   1345 
   1346     size_t ac_group_id =
   1347         enc_state->streaming_mode
   1348             ? enc_modular->ComputeStreamingAbsoluteAcGroupId(
   1349                   enc_state->dc_group_index, group_index, shared.frame_dim)
   1350             : group_index;
   1351 
   1352     for (size_t i = 0; i < num_passes; i++) {
   1353       JXL_DEBUG_V(2, "Encoding AC group %u [abs %" PRIuS "] pass %" PRIuS,
   1354                   group_index, ac_group_id, i);
   1355       if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1356         if (!EncodeGroupTokenizedCoefficients(
   1357                 group_index, i, enc_state->histogram_idx[group_index],
   1358                 *enc_state, ac_group_code(i, group_index), my_aux_out)) {
   1359           has_error = true;
   1360           return;
   1361         }
   1362       }
   1363       // Write all modular encoded data (color?, alpha, depth, extra channels)
   1364       if (!enc_modular->EncodeStream(
   1365               ac_group_code(i, group_index), my_aux_out, kLayerModularAcGroup,
   1366               ModularStreamId::ModularAC(ac_group_id, i))) {
   1367         has_error = true;
   1368         return;
   1369       }
   1370       JXL_DEBUG_V(2,
   1371                   "AC group %u [abs %" PRIuS "] pass %" PRIuS
   1372                   " encoded size is %" PRIuS " bits",
   1373                   group_index, ac_group_id, i,
   1374                   ac_group_code(i, group_index)->BitsWritten());
   1375     }
   1376   };
   1377   JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_groups, resize_aux_outs,
   1378                                 process_group, "EncodeGroupCoefficients"));
   1379   if (has_error) return JXL_FAILURE("EncodeGroupCoefficients failed");
   1380   // Resizing aux_outs to 0 also Assimilates the array.
   1381   static_cast<void>(resize_aux_outs(0));
   1382 
   1383   for (BitWriter& bw : *group_codes) {
   1384     BitWriter::Allotment allotment(&bw, 8);
   1385     bw.ZeroPadToByte();  // end of group.
   1386     allotment.ReclaimAndCharge(&bw, kLayerAC, aux_out);
   1387   }
   1388   return true;
   1389 }
   1390 
   1391 Status ComputeEncodingData(
   1392     const CompressParams& cparams, const FrameInfo& frame_info,
   1393     const CodecMetadata* metadata, JxlEncoderChunkedFrameAdapter& frame_data,
   1394     const jpeg::JPEGData* jpeg_data, size_t x0, size_t y0, size_t xsize,
   1395     size_t ysize, const JxlCmsInterface& cms, ThreadPool* pool,
   1396     FrameHeader& mutable_frame_header, ModularFrameEncoder& enc_modular,
   1397     PassesEncoderState& enc_state, std::vector<BitWriter>* group_codes,
   1398     AuxOut* aux_out) {
   1399   JXL_ASSERT(x0 + xsize <= frame_data.xsize);
   1400   JXL_ASSERT(y0 + ysize <= frame_data.ysize);
   1401   const FrameHeader& frame_header = mutable_frame_header;
   1402   PassesSharedState& shared = enc_state.shared;
   1403   shared.metadata = metadata;
   1404   if (enc_state.streaming_mode) {
   1405     shared.frame_dim.Set(
   1406         xsize, ysize, frame_header.group_size_shift,
   1407         /*max_hshift=*/0, /*max_vshift=*/0,
   1408         mutable_frame_header.encoding == FrameEncoding::kModular,
   1409         /*upsampling=*/1);
   1410   } else {
   1411     shared.frame_dim = frame_header.ToFrameDimensions();
   1412   }
   1413 
   1414   shared.image_features.patches.SetPassesSharedState(&shared);
   1415   const FrameDimensions& frame_dim = shared.frame_dim;
   1416   JXL_ASSIGN_OR_RETURN(
   1417       shared.ac_strategy,
   1418       AcStrategyImage::Create(frame_dim.xsize_blocks, frame_dim.ysize_blocks));
   1419   JXL_ASSIGN_OR_RETURN(
   1420       shared.raw_quant_field,
   1421       ImageI::Create(frame_dim.xsize_blocks, frame_dim.ysize_blocks));
   1422   JXL_ASSIGN_OR_RETURN(
   1423       shared.epf_sharpness,
   1424       ImageB::Create(frame_dim.xsize_blocks, frame_dim.ysize_blocks));
   1425   JXL_ASSIGN_OR_RETURN(shared.cmap, ColorCorrelationMap::Create(
   1426                                         frame_dim.xsize, frame_dim.ysize));
   1427   shared.coeff_order_size = kCoeffOrderMaxSize;
   1428   if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1429     shared.coeff_orders.resize(frame_header.passes.num_passes *
   1430                                kCoeffOrderMaxSize);
   1431   }
   1432 
   1433   JXL_ASSIGN_OR_RETURN(shared.quant_dc, ImageB::Create(frame_dim.xsize_blocks,
   1434                                                        frame_dim.ysize_blocks));
   1435   JXL_ASSIGN_OR_RETURN(
   1436       shared.dc_storage,
   1437       Image3F::Create(frame_dim.xsize_blocks, frame_dim.ysize_blocks));
   1438   shared.dc = &shared.dc_storage;
   1439 
   1440   const size_t num_extra_channels = metadata->m.num_extra_channels;
   1441   const ExtraChannelInfo* alpha_eci = metadata->m.Find(ExtraChannel::kAlpha);
   1442   const ExtraChannelInfo* black_eci = metadata->m.Find(ExtraChannel::kBlack);
   1443   const size_t alpha_idx = alpha_eci - metadata->m.extra_channel_info.data();
   1444   const size_t black_idx = black_eci - metadata->m.extra_channel_info.data();
   1445   const ColorEncoding c_enc = metadata->m.color_encoding;
   1446 
   1447   // Make the image patch bigger than the currently processed group in streaming
   1448   // mode so that we can take into account border pixels around the group when
   1449   // computing inverse Gaborish and adaptive quantization map.
   1450   int max_border = enc_state.streaming_mode ? kBlockDim : 0;
   1451   Rect frame_rect(0, 0, frame_data.xsize, frame_data.ysize);
   1452   Rect frame_area_rect = Rect(x0, y0, xsize, ysize);
   1453   Rect patch_rect = frame_area_rect.Extend(max_border, frame_rect);
   1454   JXL_ASSERT(patch_rect.IsInside(frame_rect));
   1455 
   1456   // Allocating a large enough image avoids a copy when padding.
   1457   JXL_ASSIGN_OR_RETURN(Image3F color,
   1458                        Image3F::Create(RoundUpToBlockDim(patch_rect.xsize()),
   1459                                        RoundUpToBlockDim(patch_rect.ysize())));
   1460   color.ShrinkTo(patch_rect.xsize(), patch_rect.ysize());
   1461   std::vector<ImageF> extra_channels(num_extra_channels);
   1462   for (auto& extra_channel : extra_channels) {
   1463     JXL_ASSIGN_OR_RETURN(
   1464         extra_channel, ImageF::Create(patch_rect.xsize(), patch_rect.ysize()));
   1465   }
   1466   ImageF* alpha = alpha_eci ? &extra_channels[alpha_idx] : nullptr;
   1467   ImageF* black = black_eci ? &extra_channels[black_idx] : nullptr;
   1468   bool has_interleaved_alpha = false;
   1469   JxlChunkedFrameInputSource input = frame_data.GetInputSource();
   1470   if (!frame_data.IsJPEG()) {
   1471     JXL_RETURN_IF_ERROR(CopyColorChannels(input, patch_rect, frame_info,
   1472                                           metadata->m, pool, &color, alpha,
   1473                                           &has_interleaved_alpha));
   1474   }
   1475   JXL_RETURN_IF_ERROR(CopyExtraChannels(input, patch_rect, frame_info,
   1476                                         metadata->m, has_interleaved_alpha,
   1477                                         pool, &extra_channels));
   1478 
   1479   shared.image_features.patches.SetPassesSharedState(&shared);
   1480   enc_state.cparams = cparams;
   1481 
   1482   Image3F linear_storage;
   1483   Image3F* linear = nullptr;
   1484 
   1485   if (!jpeg_data) {
   1486     if (frame_header.color_transform == ColorTransform::kXYB &&
   1487         frame_info.ib_needs_color_transform) {
   1488       if (frame_header.encoding == FrameEncoding::kVarDCT &&
   1489           cparams.speed_tier <= SpeedTier::kKitten) {
   1490         JXL_ASSIGN_OR_RETURN(
   1491             linear_storage,
   1492             Image3F::Create(patch_rect.xsize(), patch_rect.ysize()));
   1493         linear = &linear_storage;
   1494       }
   1495       ToXYB(c_enc, metadata->m.IntensityTarget(), black, pool, &color, cms,
   1496             linear);
   1497     } else {
   1498       // Nothing to do.
   1499       // RGB or YCbCr: forward YCbCr is not implemented, this is only used when
   1500       // the input is already in YCbCr
   1501       // If encoding a special DC or reference frame: input is already in XYB.
   1502     }
   1503     bool lossless = cparams.IsLossless();
   1504     if (alpha && !alpha_eci->alpha_associated &&
   1505         frame_header.frame_type == FrameType::kRegularFrame &&
   1506         !ApplyOverride(cparams.keep_invisible, true) &&
   1507         cparams.ec_resampling == cparams.resampling) {
   1508       // simplify invisible pixels
   1509       SimplifyInvisible(&color, *alpha, lossless);
   1510       if (linear) {
   1511         SimplifyInvisible(linear, *alpha, lossless);
   1512       }
   1513     }
   1514     PadImageToBlockMultipleInPlace(&color);
   1515   }
   1516 
   1517   // Rectangle within color that corresponds to the currently processed group in
   1518   // streaming mode.
   1519   Rect group_rect(x0 - patch_rect.x0(), y0 - patch_rect.y0(),
   1520                   RoundUpToBlockDim(xsize), RoundUpToBlockDim(ysize));
   1521 
   1522   if (enc_state.initialize_global_state && !jpeg_data) {
   1523     ComputeChromacityAdjustments(cparams, color, group_rect,
   1524                                  &mutable_frame_header);
   1525   }
   1526 
   1527   bool has_jpeg_data = (jpeg_data != nullptr);
   1528   ComputeNoiseParams(cparams, enc_state.streaming_mode, has_jpeg_data, color,
   1529                      frame_dim, &mutable_frame_header,
   1530                      &shared.image_features.noise_params);
   1531 
   1532   JXL_RETURN_IF_ERROR(
   1533       DownsampleColorChannels(cparams, frame_header, has_jpeg_data, &color));
   1534 
   1535   if (cparams.ec_resampling != 1 && !cparams.already_downsampled) {
   1536     for (ImageF& ec : extra_channels) {
   1537       JXL_ASSIGN_OR_RETURN(ec, DownsampleImage(ec, cparams.ec_resampling));
   1538     }
   1539   }
   1540 
   1541   if (!enc_state.streaming_mode) {
   1542     group_rect = Rect(color);
   1543   }
   1544 
   1545   if (frame_header.encoding == FrameEncoding::kVarDCT) {
   1546     enc_state.passes.resize(enc_state.progressive_splitter.GetNumPasses());
   1547     for (PassesEncoderState::PassData& pass : enc_state.passes) {
   1548       pass.ac_tokens.resize(shared.frame_dim.num_groups);
   1549     }
   1550     if (jpeg_data) {
   1551       JXL_RETURN_IF_ERROR(ComputeJPEGTranscodingData(
   1552           *jpeg_data, frame_header, pool, &enc_modular, &enc_state));
   1553     } else {
   1554       JXL_RETURN_IF_ERROR(ComputeVarDCTEncodingData(
   1555           frame_header, linear, &color, group_rect, cms, pool, &enc_modular,
   1556           &enc_state, aux_out));
   1557     }
   1558     ComputeAllCoeffOrders(enc_state, frame_dim);
   1559     if (!enc_state.streaming_mode) {
   1560       shared.num_histograms = 1;
   1561       enc_state.histogram_idx.resize(frame_dim.num_groups);
   1562     }
   1563     JXL_RETURN_IF_ERROR(
   1564         TokenizeAllCoefficients(frame_header, pool, &enc_state));
   1565   }
   1566 
   1567   if (cparams.modular_mode || !extra_channels.empty()) {
   1568     JXL_RETURN_IF_ERROR(enc_modular.ComputeEncodingData(
   1569         frame_header, metadata->m, &color, extra_channels, group_rect,
   1570         frame_dim, frame_area_rect, &enc_state, cms, pool, aux_out,
   1571         /*do_color=*/cparams.modular_mode));
   1572   }
   1573 
   1574   if (!enc_state.streaming_mode) {
   1575     if (cparams.speed_tier < SpeedTier::kTortoise ||
   1576         !cparams.ModularPartIsLossless() || cparams.responsive ||
   1577         !cparams.custom_fixed_tree.empty()) {
   1578       // Use local trees if doing lossless modular, unless at very slow speeds.
   1579       JXL_RETURN_IF_ERROR(enc_modular.ComputeTree(pool));
   1580       JXL_RETURN_IF_ERROR(enc_modular.ComputeTokens(pool));
   1581     }
   1582     mutable_frame_header.UpdateFlag(shared.image_features.patches.HasAny(),
   1583                                     FrameHeader::kPatches);
   1584     mutable_frame_header.UpdateFlag(shared.image_features.splines.HasAny(),
   1585                                     FrameHeader::kSplines);
   1586   }
   1587 
   1588   JXL_RETURN_IF_ERROR(EncodeGroups(frame_header, &enc_state, &enc_modular, pool,
   1589                                    group_codes, aux_out));
   1590   if (enc_state.streaming_mode) {
   1591     const size_t group_index = enc_state.dc_group_index;
   1592     enc_modular.ClearStreamData(ModularStreamId::VarDCTDC(group_index));
   1593     enc_modular.ClearStreamData(ModularStreamId::ACMetadata(group_index));
   1594     enc_modular.ClearModularStreamData();
   1595   }
   1596   return true;
   1597 }
   1598 
   1599 Status PermuteGroups(const CompressParams& cparams,
   1600                      const FrameDimensions& frame_dim, size_t num_passes,
   1601                      std::vector<coeff_order_t>* permutation,
   1602                      std::vector<BitWriter>* group_codes) {
   1603   const size_t num_groups = frame_dim.num_groups;
   1604   if (!cparams.centerfirst || (num_passes == 1 && num_groups == 1)) {
   1605     return true;
   1606   }
   1607   // Don't permute global DC/AC or DC.
   1608   permutation->resize(frame_dim.num_dc_groups + 2);
   1609   std::iota(permutation->begin(), permutation->end(), 0);
   1610   std::vector<coeff_order_t> ac_group_order(num_groups);
   1611   std::iota(ac_group_order.begin(), ac_group_order.end(), 0);
   1612   size_t group_dim = frame_dim.group_dim;
   1613 
   1614   // The center of the image is either given by parameters or chosen
   1615   // to be the middle of the image by default if center_x, center_y resp.
   1616   // are not provided.
   1617 
   1618   int64_t imag_cx;
   1619   if (cparams.center_x != static_cast<size_t>(-1)) {
   1620     JXL_RETURN_IF_ERROR(cparams.center_x < frame_dim.xsize);
   1621     imag_cx = cparams.center_x;
   1622   } else {
   1623     imag_cx = frame_dim.xsize / 2;
   1624   }
   1625 
   1626   int64_t imag_cy;
   1627   if (cparams.center_y != static_cast<size_t>(-1)) {
   1628     JXL_RETURN_IF_ERROR(cparams.center_y < frame_dim.ysize);
   1629     imag_cy = cparams.center_y;
   1630   } else {
   1631     imag_cy = frame_dim.ysize / 2;
   1632   }
   1633 
   1634   // The center of the group containing the center of the image.
   1635   int64_t cx = (imag_cx / group_dim) * group_dim + group_dim / 2;
   1636   int64_t cy = (imag_cy / group_dim) * group_dim + group_dim / 2;
   1637   // This identifies in what area of the central group the center of the image
   1638   // lies in.
   1639   double direction = -std::atan2(imag_cy - cy, imag_cx - cx);
   1640   // This identifies the side of the central group the center of the image
   1641   // lies closest to. This can take values 0, 1, 2, 3 corresponding to left,
   1642   // bottom, right, top.
   1643   int64_t side = std::fmod((direction + 5 * kPi / 4), 2 * kPi) * 2 / kPi;
   1644   auto get_distance_from_center = [&](size_t gid) {
   1645     Rect r = frame_dim.GroupRect(gid);
   1646     int64_t gcx = r.x0() + group_dim / 2;
   1647     int64_t gcy = r.y0() + group_dim / 2;
   1648     int64_t dx = gcx - cx;
   1649     int64_t dy = gcy - cy;
   1650     // The angle is determined by taking atan2 and adding an appropriate
   1651     // starting point depending on the side we want to start on.
   1652     double angle = std::remainder(
   1653         std::atan2(dy, dx) + kPi / 4 + side * (kPi / 2), 2 * kPi);
   1654     // Concentric squares in clockwise order.
   1655     return std::make_pair(std::max(std::abs(dx), std::abs(dy)), angle);
   1656   };
   1657   std::sort(ac_group_order.begin(), ac_group_order.end(),
   1658             [&](coeff_order_t a, coeff_order_t b) {
   1659               return get_distance_from_center(a) < get_distance_from_center(b);
   1660             });
   1661   std::vector<coeff_order_t> inv_ac_group_order(ac_group_order.size(), 0);
   1662   for (size_t i = 0; i < ac_group_order.size(); i++) {
   1663     inv_ac_group_order[ac_group_order[i]] = i;
   1664   }
   1665   for (size_t i = 0; i < num_passes; i++) {
   1666     size_t pass_start = permutation->size();
   1667     for (coeff_order_t v : inv_ac_group_order) {
   1668       permutation->push_back(pass_start + v);
   1669     }
   1670   }
   1671   std::vector<BitWriter> new_group_codes(group_codes->size());
   1672   for (size_t i = 0; i < permutation->size(); i++) {
   1673     new_group_codes[(*permutation)[i]] = std::move((*group_codes)[i]);
   1674   }
   1675   *group_codes = std::move(new_group_codes);
   1676   return true;
   1677 }
   1678 
   1679 bool CanDoStreamingEncoding(const CompressParams& cparams,
   1680                             const FrameInfo& frame_info,
   1681                             const CodecMetadata& metadata,
   1682                             const JxlEncoderChunkedFrameAdapter& frame_data) {
   1683   if (cparams.buffering == 0) {
   1684     return false;
   1685   }
   1686   if (cparams.buffering == -1) {
   1687     if (cparams.speed_tier < SpeedTier::kTortoise) return false;
   1688     if (cparams.speed_tier < SpeedTier::kSquirrel &&
   1689         cparams.butteraugli_distance > 0.5f) {
   1690       return false;
   1691     }
   1692     if (cparams.speed_tier == SpeedTier::kSquirrel &&
   1693         cparams.butteraugli_distance >= 3.f) {
   1694       return false;
   1695     }
   1696   }
   1697 
   1698   // TODO(veluca): handle different values of `buffering`.
   1699   if (frame_data.xsize <= 2048 && frame_data.ysize <= 2048) {
   1700     return false;
   1701   }
   1702   if (frame_data.IsJPEG()) {
   1703     return false;
   1704   }
   1705   if (cparams.noise == Override::kOn || cparams.patches == Override::kOn) {
   1706     return false;
   1707   }
   1708   if (cparams.progressive_dc != 0 || frame_info.dc_level != 0) {
   1709     return false;
   1710   }
   1711   if (cparams.resampling != 1 || cparams.ec_resampling != 1) {
   1712     return false;
   1713   }
   1714   if (cparams.max_error_mode) {
   1715     return false;
   1716   }
   1717   if (!cparams.ModularPartIsLossless() || cparams.responsive > 0) {
   1718     if (metadata.m.num_extra_channels > 0 || cparams.modular_mode) {
   1719       return false;
   1720     }
   1721   }
   1722   ColorTransform ok_color_transform =
   1723       cparams.modular_mode ? ColorTransform::kNone : ColorTransform::kXYB;
   1724   if (cparams.color_transform != ok_color_transform) {
   1725     return false;
   1726   }
   1727   return true;
   1728 }
   1729 
   1730 void ComputePermutationForStreaming(size_t xsize, size_t ysize,
   1731                                     size_t group_size, size_t num_passes,
   1732                                     std::vector<coeff_order_t>& permutation,
   1733                                     std::vector<size_t>& dc_group_order) {
   1734   // This is only valid in VarDCT mode, otherwise there can be group shift.
   1735   const size_t dc_group_size = group_size * kBlockDim;
   1736   const size_t group_xsize = DivCeil(xsize, group_size);
   1737   const size_t group_ysize = DivCeil(ysize, group_size);
   1738   const size_t dc_group_xsize = DivCeil(xsize, dc_group_size);
   1739   const size_t dc_group_ysize = DivCeil(ysize, dc_group_size);
   1740   const size_t num_groups = group_xsize * group_ysize;
   1741   const size_t num_dc_groups = dc_group_xsize * dc_group_ysize;
   1742   const size_t num_sections = 2 + num_dc_groups + num_passes * num_groups;
   1743   permutation.resize(num_sections);
   1744   size_t new_ix = 0;
   1745   // DC Global is first
   1746   permutation[0] = new_ix++;
   1747   // TODO(szabadka) Change the dc group order to center-first.
   1748   for (size_t dc_y = 0; dc_y < dc_group_ysize; ++dc_y) {
   1749     for (size_t dc_x = 0; dc_x < dc_group_xsize; ++dc_x) {
   1750       size_t dc_ix = dc_y * dc_group_xsize + dc_x;
   1751       dc_group_order.push_back(dc_ix);
   1752       permutation[1 + dc_ix] = new_ix++;
   1753       size_t ac_y0 = dc_y * kBlockDim;
   1754       size_t ac_x0 = dc_x * kBlockDim;
   1755       size_t ac_y1 = std::min<size_t>(group_ysize, ac_y0 + kBlockDim);
   1756       size_t ac_x1 = std::min<size_t>(group_xsize, ac_x0 + kBlockDim);
   1757       for (size_t pass = 0; pass < num_passes; ++pass) {
   1758         for (size_t ac_y = ac_y0; ac_y < ac_y1; ++ac_y) {
   1759           for (size_t ac_x = ac_x0; ac_x < ac_x1; ++ac_x) {
   1760             size_t group_ix = ac_y * group_xsize + ac_x;
   1761             size_t old_ix =
   1762                 AcGroupIndex(pass, group_ix, num_groups, num_dc_groups);
   1763             permutation[old_ix] = new_ix++;
   1764           }
   1765         }
   1766       }
   1767     }
   1768   }
   1769   // AC Global is last
   1770   permutation[1 + num_dc_groups] = new_ix++;
   1771   JXL_ASSERT(new_ix == num_sections);
   1772 }
   1773 
   1774 constexpr size_t kGroupSizeOffset[4] = {
   1775     static_cast<size_t>(0),
   1776     static_cast<size_t>(1024),
   1777     static_cast<size_t>(17408),
   1778     static_cast<size_t>(4211712),
   1779 };
   1780 constexpr size_t kTOCBits[4] = {12, 16, 24, 32};
   1781 
   1782 size_t TOCBucket(size_t group_size) {
   1783   size_t bucket = 0;
   1784   while (bucket < 3 && group_size >= kGroupSizeOffset[bucket + 1]) ++bucket;
   1785   return bucket;
   1786 }
   1787 
   1788 size_t TOCSize(const std::vector<size_t>& group_sizes) {
   1789   size_t toc_bits = 0;
   1790   for (size_t i = 0; i < group_sizes.size(); i++) {
   1791     toc_bits += kTOCBits[TOCBucket(group_sizes[i])];
   1792   }
   1793   return (toc_bits + 7) / 8;
   1794 }
   1795 
   1796 PaddedBytes EncodeTOC(const std::vector<size_t>& group_sizes, AuxOut* aux_out) {
   1797   BitWriter writer;
   1798   BitWriter::Allotment allotment(&writer, 32 * group_sizes.size());
   1799   for (size_t i = 0; i < group_sizes.size(); i++) {
   1800     JXL_CHECK(U32Coder::Write(kTocDist, group_sizes[i], &writer));
   1801   }
   1802   writer.ZeroPadToByte();  // before first group
   1803   allotment.ReclaimAndCharge(&writer, kLayerTOC, aux_out);
   1804   return std::move(writer).TakeBytes();
   1805 }
   1806 
   1807 void ComputeGroupDataOffset(size_t frame_header_size, size_t dc_global_size,
   1808                             size_t num_sections, size_t& min_dc_global_size,
   1809                             size_t& group_offset) {
   1810   size_t max_toc_bits = (num_sections - 1) * 32;
   1811   size_t min_toc_bits = (num_sections - 1) * 12;
   1812   size_t max_padding = (max_toc_bits - min_toc_bits + 7) / 8;
   1813   min_dc_global_size = dc_global_size;
   1814   size_t dc_global_bucket = TOCBucket(min_dc_global_size);
   1815   while (TOCBucket(min_dc_global_size + max_padding) > dc_global_bucket) {
   1816     dc_global_bucket = TOCBucket(min_dc_global_size + max_padding);
   1817     min_dc_global_size = kGroupSizeOffset[dc_global_bucket];
   1818   }
   1819   JXL_ASSERT(TOCBucket(min_dc_global_size) == dc_global_bucket);
   1820   JXL_ASSERT(TOCBucket(min_dc_global_size + max_padding) == dc_global_bucket);
   1821   max_toc_bits += kTOCBits[dc_global_bucket];
   1822   size_t max_toc_size = (max_toc_bits + 7) / 8;
   1823   group_offset = frame_header_size + max_toc_size + min_dc_global_size;
   1824 }
   1825 
   1826 size_t ComputeDcGlobalPadding(const std::vector<size_t>& group_sizes,
   1827                               size_t frame_header_size,
   1828                               size_t group_data_offset,
   1829                               size_t min_dc_global_size) {
   1830   std::vector<size_t> new_group_sizes = group_sizes;
   1831   new_group_sizes[0] = min_dc_global_size;
   1832   size_t toc_size = TOCSize(new_group_sizes);
   1833   size_t actual_offset = frame_header_size + toc_size + group_sizes[0];
   1834   return group_data_offset - actual_offset;
   1835 }
   1836 
   1837 Status OutputGroups(std::vector<BitWriter>&& group_codes,
   1838                     std::vector<size_t>* group_sizes,
   1839                     JxlEncoderOutputProcessorWrapper* output_processor) {
   1840   JXL_ASSERT(group_codes.size() >= 4);
   1841   {
   1842     PaddedBytes dc_group = std::move(group_codes[1]).TakeBytes();
   1843     group_sizes->push_back(dc_group.size());
   1844     JXL_RETURN_IF_ERROR(AppendData(*output_processor, dc_group));
   1845   }
   1846   for (size_t i = 3; i < group_codes.size(); ++i) {
   1847     PaddedBytes ac_group = std::move(group_codes[i]).TakeBytes();
   1848     group_sizes->push_back(ac_group.size());
   1849     JXL_RETURN_IF_ERROR(AppendData(*output_processor, ac_group));
   1850   }
   1851   return true;
   1852 }
   1853 
   1854 void RemoveUnusedHistograms(std::vector<uint8_t>& context_map,
   1855                             EntropyEncodingData& codes) {
   1856   std::vector<int> remap(256, -1);
   1857   std::vector<uint8_t> inv_remap;
   1858   for (size_t i = 0; i < context_map.size(); ++i) {
   1859     const uint8_t histo_ix = context_map[i];
   1860     if (remap[histo_ix] == -1) {
   1861       remap[histo_ix] = inv_remap.size();
   1862       inv_remap.push_back(histo_ix);
   1863     }
   1864     context_map[i] = remap[histo_ix];
   1865   }
   1866   EntropyEncodingData new_codes;
   1867   new_codes.use_prefix_code = codes.use_prefix_code;
   1868   new_codes.lz77 = codes.lz77;
   1869   for (uint8_t histo_idx : inv_remap) {
   1870     new_codes.encoding_info.emplace_back(
   1871         std::move(codes.encoding_info[histo_idx]));
   1872     new_codes.uint_config.emplace_back(codes.uint_config[histo_idx]);
   1873     new_codes.encoded_histograms.emplace_back(
   1874         std::move(codes.encoded_histograms[histo_idx]));
   1875   }
   1876   codes = std::move(new_codes);
   1877 }
   1878 
   1879 Status OutputAcGlobal(PassesEncoderState& enc_state,
   1880                       const FrameDimensions& frame_dim,
   1881                       std::vector<size_t>* group_sizes,
   1882                       JxlEncoderOutputProcessorWrapper* output_processor,
   1883                       AuxOut* aux_out) {
   1884   JXL_ASSERT(frame_dim.num_groups > 1);
   1885   BitWriter writer;
   1886   {
   1887     size_t num_histo_bits = CeilLog2Nonzero(frame_dim.num_groups);
   1888     BitWriter::Allotment allotment(&writer, num_histo_bits + 1);
   1889     writer.Write(1, 1);  // default dequant matrices
   1890     writer.Write(num_histo_bits, frame_dim.num_dc_groups - 1);
   1891     allotment.ReclaimAndCharge(&writer, kLayerAC, aux_out);
   1892   }
   1893   const PassesSharedState& shared = enc_state.shared;
   1894   for (size_t i = 0; i < enc_state.progressive_splitter.GetNumPasses(); i++) {
   1895     // Encode coefficient orders.
   1896     size_t order_bits = 0;
   1897     JXL_RETURN_IF_ERROR(
   1898         U32Coder::CanEncode(kOrderEnc, enc_state.used_orders[i], &order_bits));
   1899     BitWriter::Allotment allotment(&writer, order_bits);
   1900     JXL_CHECK(U32Coder::Write(kOrderEnc, enc_state.used_orders[i], &writer));
   1901     allotment.ReclaimAndCharge(&writer, kLayerOrder, aux_out);
   1902     EncodeCoeffOrders(enc_state.used_orders[i],
   1903                       &shared.coeff_orders[i * shared.coeff_order_size],
   1904                       &writer, kLayerOrder, aux_out);
   1905     // Fix up context map and entropy codes to remove any fix histograms that
   1906     // were not selected by clustering.
   1907     RemoveUnusedHistograms(enc_state.passes[i].context_map,
   1908                            enc_state.passes[i].codes);
   1909     EncodeHistograms(enc_state.passes[i].context_map, enc_state.passes[i].codes,
   1910                      &writer, kLayerAC, aux_out);
   1911   }
   1912   {
   1913     BitWriter::Allotment allotment(&writer, 8);
   1914     writer.ZeroPadToByte();  // end of group.
   1915     allotment.ReclaimAndCharge(&writer, kLayerAC, aux_out);
   1916   }
   1917   PaddedBytes ac_global = std::move(writer).TakeBytes();
   1918   group_sizes->push_back(ac_global.size());
   1919   JXL_RETURN_IF_ERROR(AppendData(*output_processor, ac_global));
   1920   return true;
   1921 }
   1922 
   1923 Status EncodeFrameStreaming(const CompressParams& cparams,
   1924                             const FrameInfo& frame_info,
   1925                             const CodecMetadata* metadata,
   1926                             JxlEncoderChunkedFrameAdapter& frame_data,
   1927                             const JxlCmsInterface& cms, ThreadPool* pool,
   1928                             JxlEncoderOutputProcessorWrapper* output_processor,
   1929                             AuxOut* aux_out) {
   1930   PassesEncoderState enc_state;
   1931   SetProgressiveMode(cparams, &enc_state.progressive_splitter);
   1932   FrameHeader frame_header(metadata);
   1933   std::unique_ptr<jpeg::JPEGData> jpeg_data;
   1934   if (frame_data.IsJPEG()) {
   1935     jpeg_data = make_unique<jpeg::JPEGData>(frame_data.TakeJPEGData());
   1936   }
   1937   JXL_RETURN_IF_ERROR(MakeFrameHeader(frame_data.xsize, frame_data.ysize,
   1938                                       cparams, enc_state.progressive_splitter,
   1939                                       frame_info, jpeg_data.get(), true,
   1940                                       &frame_header));
   1941   const size_t num_passes = enc_state.progressive_splitter.GetNumPasses();
   1942   ModularFrameEncoder enc_modular(frame_header, cparams, true);
   1943   std::vector<coeff_order_t> permutation;
   1944   std::vector<size_t> dc_group_order;
   1945   size_t group_size = frame_header.ToFrameDimensions().group_dim;
   1946   ComputePermutationForStreaming(frame_data.xsize, frame_data.ysize, group_size,
   1947                                  num_passes, permutation, dc_group_order);
   1948   enc_state.shared.num_histograms = dc_group_order.size();
   1949   size_t dc_group_size = group_size * kBlockDim;
   1950   size_t dc_group_xsize = DivCeil(frame_data.xsize, dc_group_size);
   1951   size_t min_dc_global_size = 0;
   1952   size_t group_data_offset = 0;
   1953   PaddedBytes frame_header_bytes;
   1954   PaddedBytes dc_global_bytes;
   1955   std::vector<size_t> group_sizes;
   1956   size_t start_pos = output_processor->CurrentPosition();
   1957   for (size_t i = 0; i < dc_group_order.size(); ++i) {
   1958     size_t dc_ix = dc_group_order[i];
   1959     size_t dc_y = dc_ix / dc_group_xsize;
   1960     size_t dc_x = dc_ix % dc_group_xsize;
   1961     size_t y0 = dc_y * dc_group_size;
   1962     size_t x0 = dc_x * dc_group_size;
   1963     size_t ysize = std::min<size_t>(dc_group_size, frame_data.ysize - y0);
   1964     size_t xsize = std::min<size_t>(dc_group_size, frame_data.xsize - x0);
   1965     size_t group_xsize = DivCeil(xsize, group_size);
   1966     size_t group_ysize = DivCeil(ysize, group_size);
   1967     JXL_DEBUG_V(2,
   1968                 "Encoding DC group #%" PRIuS " dc_y = %" PRIuS " dc_x = %" PRIuS
   1969                 " (x0, y0) = (%" PRIuS ", %" PRIuS ") (xsize, ysize) = (%" PRIuS
   1970                 ", %" PRIuS ")",
   1971                 dc_ix, dc_y, dc_x, x0, y0, xsize, ysize);
   1972     enc_state.streaming_mode = true;
   1973     enc_state.initialize_global_state = (i == 0);
   1974     enc_state.dc_group_index = dc_ix;
   1975     enc_state.histogram_idx = std::vector<size_t>(group_xsize * group_ysize, i);
   1976     std::vector<BitWriter> group_codes;
   1977     JXL_RETURN_IF_ERROR(ComputeEncodingData(
   1978         cparams, frame_info, metadata, frame_data, jpeg_data.get(), x0, y0,
   1979         xsize, ysize, cms, pool, frame_header, enc_modular, enc_state,
   1980         &group_codes, aux_out));
   1981     JXL_ASSERT(enc_state.special_frames.empty());
   1982     if (i == 0) {
   1983       BitWriter writer;
   1984       JXL_RETURN_IF_ERROR(WriteFrameHeader(frame_header, &writer, aux_out));
   1985       BitWriter::Allotment allotment(&writer, 8);
   1986       writer.Write(1, 1);  // write permutation
   1987       EncodePermutation(permutation.data(), /*skip=*/0, permutation.size(),
   1988                         &writer, kLayerHeader, aux_out);
   1989       writer.ZeroPadToByte();
   1990       allotment.ReclaimAndCharge(&writer, kLayerHeader, aux_out);
   1991       frame_header_bytes = std::move(writer).TakeBytes();
   1992       dc_global_bytes = std::move(group_codes[0]).TakeBytes();
   1993       ComputeGroupDataOffset(frame_header_bytes.size(), dc_global_bytes.size(),
   1994                              permutation.size(), min_dc_global_size,
   1995                              group_data_offset);
   1996       JXL_DEBUG_V(2, "Frame header size: %" PRIuS, frame_header_bytes.size());
   1997       JXL_DEBUG_V(2, "DC global size: %" PRIuS ", min size for TOC: %" PRIuS,
   1998                   dc_global_bytes.size(), min_dc_global_size);
   1999       JXL_DEBUG_V(2, "Num groups: %" PRIuS " group data offset: %" PRIuS,
   2000                   permutation.size(), group_data_offset);
   2001       group_sizes.push_back(dc_global_bytes.size());
   2002       output_processor->Seek(start_pos + group_data_offset);
   2003     }
   2004     JXL_RETURN_IF_ERROR(
   2005         OutputGroups(std::move(group_codes), &group_sizes, output_processor));
   2006   }
   2007   if (frame_header.encoding == FrameEncoding::kVarDCT) {
   2008     JXL_RETURN_IF_ERROR(
   2009         OutputAcGlobal(enc_state, frame_header.ToFrameDimensions(),
   2010                        &group_sizes, output_processor, aux_out));
   2011   } else {
   2012     group_sizes.push_back(0);
   2013   }
   2014   JXL_ASSERT(group_sizes.size() == permutation.size());
   2015   size_t end_pos = output_processor->CurrentPosition();
   2016   output_processor->Seek(start_pos);
   2017   size_t padding_size =
   2018       ComputeDcGlobalPadding(group_sizes, frame_header_bytes.size(),
   2019                              group_data_offset, min_dc_global_size);
   2020   group_sizes[0] += padding_size;
   2021   PaddedBytes toc_bytes = EncodeTOC(group_sizes, aux_out);
   2022   std::vector<uint8_t> padding_bytes(padding_size);
   2023   JXL_RETURN_IF_ERROR(AppendData(*output_processor, frame_header_bytes));
   2024   JXL_RETURN_IF_ERROR(AppendData(*output_processor, toc_bytes));
   2025   JXL_RETURN_IF_ERROR(AppendData(*output_processor, dc_global_bytes));
   2026   JXL_RETURN_IF_ERROR(AppendData(*output_processor, padding_bytes));
   2027   JXL_DEBUG_V(2, "TOC size: %" PRIuS " padding bytes after DC global: %" PRIuS,
   2028               toc_bytes.size(), padding_size);
   2029   JXL_ASSERT(output_processor->CurrentPosition() ==
   2030              start_pos + group_data_offset);
   2031   output_processor->Seek(end_pos);
   2032   return true;
   2033 }
   2034 
   2035 Status EncodeFrameOneShot(const CompressParams& cparams,
   2036                           const FrameInfo& frame_info,
   2037                           const CodecMetadata* metadata,
   2038                           JxlEncoderChunkedFrameAdapter& frame_data,
   2039                           const JxlCmsInterface& cms, ThreadPool* pool,
   2040                           JxlEncoderOutputProcessorWrapper* output_processor,
   2041                           AuxOut* aux_out) {
   2042   PassesEncoderState enc_state;
   2043   SetProgressiveMode(cparams, &enc_state.progressive_splitter);
   2044   std::vector<BitWriter> group_codes;
   2045   FrameHeader frame_header(metadata);
   2046   std::unique_ptr<jpeg::JPEGData> jpeg_data;
   2047   if (frame_data.IsJPEG()) {
   2048     jpeg_data = make_unique<jpeg::JPEGData>(frame_data.TakeJPEGData());
   2049   }
   2050   JXL_RETURN_IF_ERROR(MakeFrameHeader(frame_data.xsize, frame_data.ysize,
   2051                                       cparams, enc_state.progressive_splitter,
   2052                                       frame_info, jpeg_data.get(), false,
   2053                                       &frame_header));
   2054   const size_t num_passes = enc_state.progressive_splitter.GetNumPasses();
   2055   ModularFrameEncoder enc_modular(frame_header, cparams, false);
   2056   JXL_RETURN_IF_ERROR(ComputeEncodingData(
   2057       cparams, frame_info, metadata, frame_data, jpeg_data.get(), 0, 0,
   2058       frame_data.xsize, frame_data.ysize, cms, pool, frame_header, enc_modular,
   2059       enc_state, &group_codes, aux_out));
   2060 
   2061   BitWriter writer;
   2062   writer.AppendByteAligned(enc_state.special_frames);
   2063   JXL_RETURN_IF_ERROR(WriteFrameHeader(frame_header, &writer, aux_out));
   2064 
   2065   std::vector<coeff_order_t> permutation;
   2066   JXL_RETURN_IF_ERROR(PermuteGroups(cparams, enc_state.shared.frame_dim,
   2067                                     num_passes, &permutation, &group_codes));
   2068 
   2069   JXL_RETURN_IF_ERROR(
   2070       WriteGroupOffsets(group_codes, permutation, &writer, aux_out));
   2071 
   2072   writer.AppendByteAligned(group_codes);
   2073   PaddedBytes frame_bytes = std::move(writer).TakeBytes();
   2074   JXL_RETURN_IF_ERROR(AppendData(*output_processor, frame_bytes));
   2075 
   2076   return true;
   2077 }
   2078 
   2079 }  // namespace
   2080 
   2081 Status EncodeFrame(const CompressParams& cparams_orig,
   2082                    const FrameInfo& frame_info, const CodecMetadata* metadata,
   2083                    JxlEncoderChunkedFrameAdapter& frame_data,
   2084                    const JxlCmsInterface& cms, ThreadPool* pool,
   2085                    JxlEncoderOutputProcessorWrapper* output_processor,
   2086                    AuxOut* aux_out) {
   2087   CompressParams cparams = cparams_orig;
   2088   if (cparams.speed_tier == SpeedTier::kTectonicPlate &&
   2089       !cparams.IsLossless()) {
   2090     cparams.speed_tier = SpeedTier::kGlacier;
   2091   }
   2092   // Lightning mode is handled externally, so switch to Thunder mode to handle
   2093   // potentially weird cases.
   2094   if (cparams.speed_tier == SpeedTier::kLightning) {
   2095     cparams.speed_tier = SpeedTier::kThunder;
   2096   }
   2097   if (cparams.speed_tier == SpeedTier::kTectonicPlate) {
   2098     std::vector<CompressParams> all_params;
   2099     std::vector<int> shifts{0};
   2100     if (cparams_orig.responsive != 1) {
   2101       shifts.push_back(-1);
   2102       shifts.push_back(3);
   2103     }
   2104 
   2105     CompressParams cparams_attempt = cparams_orig;
   2106     cparams_attempt.speed_tier = SpeedTier::kGlacier;
   2107     cparams_attempt.options.max_properties = 4;
   2108 
   2109     for (float x : {0.0f, 80.f}) {
   2110       cparams_attempt.channel_colors_percent = x;
   2111       for (float y : {0.0f, 95.0f}) {
   2112         cparams_attempt.channel_colors_pre_transform_percent = y;
   2113         // 70000 ensures that the number of palette colors is representable in
   2114         // modular headers.
   2115         for (int K : {0, 1 << 10, 70000}) {
   2116           cparams_attempt.palette_colors = K;
   2117           for (int tree_mode :
   2118                {-1, static_cast<int>(ModularOptions::TreeMode::kNoWP),
   2119                 static_cast<int>(ModularOptions::TreeMode::kDefault)}) {
   2120             if (tree_mode == -1) {
   2121               // LZ77 only
   2122               cparams_attempt.options.nb_repeats = 0;
   2123             } else {
   2124               cparams_attempt.options.nb_repeats = 1;
   2125               cparams_attempt.options.wp_tree_mode =
   2126                   static_cast<ModularOptions::TreeMode>(tree_mode);
   2127             }
   2128             for (Predictor pred : {Predictor::Zero, Predictor::Variable}) {
   2129               cparams_attempt.options.predictor = pred;
   2130               for (int g : shifts) {
   2131                 cparams_attempt.modular_group_size_shift = g;
   2132                 for (Override patches : {Override::kDefault, Override::kOff}) {
   2133                   cparams_attempt.patches = patches;
   2134                   all_params.push_back(cparams_attempt);
   2135                 }
   2136               }
   2137             }
   2138           }
   2139         }
   2140       }
   2141     }
   2142 
   2143     fprintf(stderr, "Trying %zu variations\n", all_params.size());
   2144 
   2145     std::atomic<bool> has_error{false};
   2146     std::mutex mut;
   2147     CompressParams best_params;
   2148     std::vector<uint8_t> best_out;
   2149 
   2150     JXL_RETURN_IF_ERROR(RunOnPool(
   2151         pool, 0, all_params.size(), ThreadPool::NoInit,
   2152         [&](size_t task, size_t) {
   2153           if (has_error) return;
   2154           std::vector<uint8_t> output(64);
   2155           uint8_t* next_out = output.data();
   2156           size_t avail_out = output.size();
   2157           JxlEncoderOutputProcessorWrapper local_output;
   2158           local_output.SetAvailOut(&next_out, &avail_out);
   2159           if (!EncodeFrame(all_params[task], frame_info, metadata, frame_data,
   2160                            cms, nullptr, &local_output, aux_out)) {
   2161             has_error = true;
   2162             return;
   2163           }
   2164           local_output.SetFinalizedPosition();
   2165           local_output.CopyOutput(output, next_out, avail_out);
   2166 
   2167           fprintf(stderr, "%3zu: channel_colors_percent=%f, "
   2168                   "channel_colors_pre_transform_percent=%f, palette_colors=%d, "
   2169                   "nb_repeats=%f, wp_tree_mode=%d, predictor=%d, "
   2170                   "modular_group_size_shift=%d, patches=%d -> %zu\n", task,
   2171                   all_params[task].channel_colors_percent,
   2172                   all_params[task].channel_colors_pre_transform_percent,
   2173                   all_params[task].palette_colors, all_params[task].options.nb_repeats,
   2174                   static_cast<int>(all_params[task].options.wp_tree_mode),
   2175                   static_cast<int>(all_params[task].options.predictor),
   2176                   all_params[task].modular_group_size_shift,
   2177                   static_cast<int>(all_params[task].patches),
   2178                   output.size());
   2179           std::unique_lock<std::mutex> lock(mut);
   2180           if (best_out.empty() || best_out.size() > output.size()) {
   2181             best_out = std::move(output);
   2182             best_params = all_params[task];
   2183           }
   2184         },
   2185         "Compress kTectonicPlate"));
   2186     if (has_error) return JXL_FAILURE("Compress kTectonicPlate failed");
   2187 
   2188     cparams = best_params;
   2189     fprintf(stderr, "Selected: channel_colors_percent=%f, "
   2190             "channel_colors_pre_transform_percent=%f, palette_colors=%d, "
   2191             "nb_repeats=%f, wp_tree_mode=%d, predictor=%d, "
   2192             "modular_group_size_shift=%d, patches=%d -> %zu\n",
   2193             cparams.channel_colors_percent,
   2194             cparams.channel_colors_pre_transform_percent,
   2195             cparams.palette_colors, cparams.options.nb_repeats,
   2196             static_cast<int>(cparams.options.wp_tree_mode),
   2197             static_cast<int>(cparams.options.predictor),
   2198             cparams.modular_group_size_shift,
   2199             static_cast<int>(cparams.patches),
   2200             best_out.size());
   2201     JXL_RETURN_IF_ERROR(AppendData(*output_processor, best_out));
   2202     return true;
   2203   }
   2204 
   2205   JXL_RETURN_IF_ERROR(ParamsPostInit(&cparams));
   2206 
   2207   if (cparams.butteraugli_distance < 0) {
   2208     return JXL_FAILURE("Expected non-negative distance");
   2209   }
   2210 
   2211   if (cparams.progressive_dc < 0) {
   2212     if (cparams.progressive_dc != -1) {
   2213       return JXL_FAILURE("Invalid progressive DC setting value (%d)",
   2214                          cparams.progressive_dc);
   2215     }
   2216     cparams.progressive_dc = 0;
   2217   }
   2218   if (cparams.ec_resampling < cparams.resampling) {
   2219     cparams.ec_resampling = cparams.resampling;
   2220   }
   2221   if (cparams.resampling > 1 || frame_info.is_preview) {
   2222     cparams.progressive_dc = 0;
   2223   }
   2224 
   2225   if (frame_info.dc_level + cparams.progressive_dc > 4) {
   2226     return JXL_FAILURE("Too many levels of progressive DC");
   2227   }
   2228 
   2229   if (cparams.butteraugli_distance != 0 &&
   2230       cparams.butteraugli_distance < kMinButteraugliDistance) {
   2231     return JXL_FAILURE("Butteraugli distance is too low (%f)",
   2232                        cparams.butteraugli_distance);
   2233   }
   2234 
   2235   if (frame_data.IsJPEG()) {
   2236     cparams.gaborish = Override::kOff;
   2237     cparams.epf = 0;
   2238     cparams.modular_mode = false;
   2239   }
   2240 
   2241   if (frame_data.xsize == 0 || frame_data.ysize == 0) {
   2242     return JXL_FAILURE("Empty image");
   2243   }
   2244 
   2245   // Assert that this metadata is correctly set up for the compression params,
   2246   // this should have been done by enc_file.cc
   2247   JXL_ASSERT(metadata->m.xyb_encoded ==
   2248              (cparams.color_transform == ColorTransform::kXYB));
   2249 
   2250   if (frame_data.IsJPEG() && cparams.color_transform == ColorTransform::kXYB) {
   2251     return JXL_FAILURE("Can't add JPEG frame to XYB codestream");
   2252   }
   2253 
   2254   if (CanDoStreamingEncoding(cparams, frame_info, *metadata, frame_data)) {
   2255     return EncodeFrameStreaming(cparams, frame_info, metadata, frame_data, cms,
   2256                                 pool, output_processor, aux_out);
   2257   } else {
   2258     return EncodeFrameOneShot(cparams, frame_info, metadata, frame_data, cms,
   2259                               pool, output_processor, aux_out);
   2260   }
   2261 }
   2262 
   2263 Status EncodeFrame(const CompressParams& cparams_orig,
   2264                    const FrameInfo& frame_info, const CodecMetadata* metadata,
   2265                    const ImageBundle& ib, const JxlCmsInterface& cms,
   2266                    ThreadPool* pool, BitWriter* writer, AuxOut* aux_out) {
   2267   JxlEncoderChunkedFrameAdapter frame_data(ib.xsize(), ib.ysize(),
   2268                                            ib.extra_channels().size());
   2269   std::vector<uint8_t> color;
   2270   if (ib.IsJPEG()) {
   2271     frame_data.SetJPEGData(*ib.jpeg_data);
   2272   } else {
   2273     uint32_t num_channels =
   2274         ib.IsGray() && frame_info.ib_needs_color_transform ? 1 : 3;
   2275     size_t stride = ib.xsize() * num_channels * 4;
   2276     color.resize(ib.ysize() * stride);
   2277     JXL_RETURN_IF_ERROR(ConvertToExternal(
   2278         ib, /*bits_per_sample=*/32, /*float_out=*/true, num_channels,
   2279         JXL_NATIVE_ENDIAN, stride, pool, color.data(), color.size(),
   2280         /*out_callback=*/{}, Orientation::kIdentity));
   2281     JxlPixelFormat format{num_channels, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0};
   2282     frame_data.SetFromBuffer(0, color.data(), color.size(), format);
   2283   }
   2284   for (size_t ec = 0; ec < ib.extra_channels().size(); ++ec) {
   2285     JxlPixelFormat ec_format{1, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0};
   2286     size_t ec_stride = ib.xsize() * 4;
   2287     std::vector<uint8_t> ec_data(ib.ysize() * ec_stride);
   2288     const ImageF* channel = &ib.extra_channels()[ec];
   2289     JXL_RETURN_IF_ERROR(ConvertChannelsToExternal(
   2290         &channel, 1,
   2291         /*bits_per_sample=*/32,
   2292         /*float_out=*/true, JXL_NATIVE_ENDIAN, ec_stride, pool, ec_data.data(),
   2293         ec_data.size(), /*out_callback=*/{}, Orientation::kIdentity));
   2294     frame_data.SetFromBuffer(1 + ec, ec_data.data(), ec_data.size(), ec_format);
   2295   }
   2296   FrameInfo fi = frame_info;
   2297   fi.origin = ib.origin;
   2298   fi.blend = ib.blend;
   2299   fi.blendmode = ib.blendmode;
   2300   fi.duration = ib.duration;
   2301   fi.timecode = ib.timecode;
   2302   fi.name = ib.name;
   2303   std::vector<uint8_t> output(64);
   2304   uint8_t* next_out = output.data();
   2305   size_t avail_out = output.size();
   2306   JxlEncoderOutputProcessorWrapper output_processor;
   2307   output_processor.SetAvailOut(&next_out, &avail_out);
   2308   JXL_RETURN_IF_ERROR(EncodeFrame(cparams_orig, fi, metadata, frame_data, cms,
   2309                                   pool, &output_processor, aux_out));
   2310   output_processor.SetFinalizedPosition();
   2311   output_processor.CopyOutput(output, next_out, avail_out);
   2312   writer->AppendByteAligned(Bytes(output));
   2313   return true;
   2314 }
   2315 
   2316 }  // namespace jxl