libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

dec_jpeg_data_writer.cc (33584B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/jpeg/dec_jpeg_data_writer.h"
      7 
      8 #include <stdlib.h>
      9 #include <string.h> /* for memset, memcpy */
     10 
     11 #include <algorithm>
     12 #include <cstddef>
     13 #include <cstdint>
     14 #include <deque>
     15 #include <string>
     16 #include <vector>
     17 
     18 #include "lib/jxl/base/bits.h"
     19 #include "lib/jxl/base/byte_order.h"
     20 #include "lib/jxl/base/common.h"
     21 #include "lib/jxl/base/status.h"
     22 #include "lib/jxl/frame_dimensions.h"
     23 #include "lib/jxl/image_bundle.h"
     24 #include "lib/jxl/jpeg/dec_jpeg_serialization_state.h"
     25 #include "lib/jxl/jpeg/jpeg_data.h"
     26 
     27 namespace jxl {
     28 namespace jpeg {
     29 
     30 namespace {
     31 
     // Result codes used by the serialization routines in this file (for example
     // DoEncodeScan returns ERROR or NEEDS_MORE_INPUT).
     enum class SerializationStatus {
       NEEDS_MORE_INPUT,
       NEEDS_MORE_OUTPUT,
       ERROR,
       DONE,
     };
     38 
     // Sample precision byte written into the SOF segment.
     constexpr int kJpegPrecision = 8;

     // Capacity in bytes of each output chunk allocated by the JpegBitWriter.
     constexpr size_t kJpegBitWriterChunkSize = 16384;
     43 
     44 // Returns non-zero if and only if x has a zero byte, i.e. one of
     45 // x & 0xff, x & 0xff00, ..., x & 0xff00000000000000 is zero.
     46 JXL_INLINE uint64_t HasZeroByte(uint64_t x) {
     47   return (x - 0x0101010101010101ULL) & ~x & 0x8080808080808080ULL;
     48 }
     49 
     50 void JpegBitWriterInit(JpegBitWriter* bw,
     51                        std::deque<OutputChunk>* output_queue) {
     52   bw->output = output_queue;
     53   bw->chunk = OutputChunk(kJpegBitWriterChunkSize);
     54   bw->pos = 0;
     55   bw->put_buffer = 0;
     56   bw->put_bits = 64;
     57   bw->healthy = true;
     58   bw->data = bw->chunk.buffer->data();
     59 }
     60 
     61 JXL_NOINLINE void SwapBuffer(JpegBitWriter* bw) {
     62   bw->chunk.len = bw->pos;
     63   bw->output->emplace_back(std::move(bw->chunk));
     64   bw->chunk = OutputChunk(kJpegBitWriterChunkSize);
     65   bw->data = bw->chunk.buffer->data();
     66   bw->pos = 0;
     67 }
     68 
     69 JXL_INLINE void Reserve(JpegBitWriter* bw, size_t n_bytes) {
     70   if (JXL_UNLIKELY((bw->pos + n_bytes) > kJpegBitWriterChunkSize)) {
     71     SwapBuffer(bw);
     72   }
     73 }
     74 
     75 /**
     76  * Writes the given byte to the output, writes an extra zero if byte is 0xFF.
     77  *
     78  * This method is "careless" - caller must make sure that there is enough
     79  * space in the output buffer. Emits up to 2 bytes to buffer.
     80  */
     81 JXL_INLINE void EmitByte(JpegBitWriter* bw, int byte) {
     82   bw->data[bw->pos] = byte;
     83   bw->data[bw->pos + 1] = 0;
     84   bw->pos += (byte != 0xFF ? 1 : 2);
     85 }
     86 
     // Slow path of WriteBits(): called once bw->put_bits has gone negative, i.e.
     // the pending |bits| do not fit entirely into the 64-bit put_buffer.
     // Tops up put_buffer with the upper part of |bits|, writes out all 8 bytes
     // (adding a zero byte after each 0xFF) and restarts put_buffer with the
     // leftover lower part of |bits|.
     // |nbits| is not read here; WriteBits() has already subtracted it from
     // put_bits. No Reserve() is done: up to 16 bytes are written at
     // bw->data + bw->pos.
     87 JXL_INLINE void DischargeBitBuffer(JpegBitWriter* bw, int nbits,
     88                                    uint64_t bits) {
     89   // At this point we are ready to emit the put_buffer to the output.
     90   // The JPEG format requires that after every 0xff byte in the entropy
     91   // coded section, there is a zero byte, therefore we first check if any of
     92   // the 8 bytes of put_buffer is 0xFF.
          // put_bits is negative here, so -put_bits is the number of low bits of
          // |bits| that do not fit; shift them out and keep only the part that fits.
     93   bw->put_buffer |= (bits >> -bw->put_bits);
     94   if (JXL_UNLIKELY(HasZeroByte(~bw->put_buffer))) {
     95     // We have a 0xFF byte somewhere, examine each byte and append a zero
     96     // byte if necessary.
     97     EmitByte(bw, (bw->put_buffer >> 56) & 0xFF);
     98     EmitByte(bw, (bw->put_buffer >> 48) & 0xFF);
     99     EmitByte(bw, (bw->put_buffer >> 40) & 0xFF);
    100     EmitByte(bw, (bw->put_buffer >> 32) & 0xFF);
    101     EmitByte(bw, (bw->put_buffer >> 24) & 0xFF);
    102     EmitByte(bw, (bw->put_buffer >> 16) & 0xFF);
    103     EmitByte(bw, (bw->put_buffer >> 8) & 0xFF);
    104     EmitByte(bw, (bw->put_buffer) & 0xFF);
    105   } else {
    106     // We don't have any 0xFF bytes, output all 8 bytes without checking.
    107     StoreBE64(bw->put_buffer, bw->data + bw->pos);
    108     bw->pos += 8;
    109   }
    110 
          // After adding 64, put_bits is the free-bit count of the new buffer; the
          // left shift discards the bits already written and places the leftover
          // bits at the top of the buffer.
    111   bw->put_bits += 64;
    112   bw->put_buffer = bits << bw->put_bits;
    113 }
    114 
    115 JXL_INLINE void WriteBits(JpegBitWriter* bw, int nbits, uint64_t bits) {
    116   JXL_DASSERT(nbits > 0);
    117   bw->put_bits -= nbits;
    118   if (JXL_UNLIKELY(bw->put_bits < 0)) {
    119     if (JXL_UNLIKELY(nbits > 64)) {
    120       bw->put_bits += nbits;
    121       bw->healthy = false;
    122     } else {
    123       DischargeBitBuffer(bw, nbits, bits);
    124     }
    125   } else {
    126     bw->put_buffer |= (bits << bw->put_bits);
    127   }
    128 }
    129 
    130 void EmitMarker(JpegBitWriter* bw, int marker) {
    131   Reserve(bw, 2);
    132   JXL_DASSERT(marker != 0xFF);
    133   bw->data[bw->pos++] = 0xFF;
    134   bw->data[bw->pos++] = marker;
    135 }
    136 
    // Writes out everything held in the bit accumulator, up to the next byte
    // boundary, and leaves the accumulator empty.
    // The unused low bits of the final partial byte are filled with padding: all
    // ones when *pad_bits is null, otherwise one bit per byte read from
    // [*pad_bits, pad_bits_end) (each such byte must be 0 or 1, enforced by
    // JXL_ASSERT), after which *pad_bits is advanced past the bytes consumed.
    // Returns false if the pad-bit source runs out; in that case nothing has been
    // written and *pad_bits is unchanged.
    137 bool JumpToByteBoundary(JpegBitWriter* bw, const uint8_t** pad_bits,
    138                         const uint8_t* pad_bits_end) {
          // put_bits counts the free bits of put_buffer, so its low 3 bits give the
          // number of unused bits in the current partial byte.
    139   size_t n_bits = bw->put_bits & 7u;
    140   uint8_t pad_pattern;
    141   if (*pad_bits == nullptr) {
    142     pad_pattern = (1u << n_bits) - 1;
    143   } else {
    144     pad_pattern = 0;
    145     const uint8_t* src = *pad_bits;
    146     // TODO(eustas): bitwise reading looks insanely ineffective!
    147     while (n_bits--) {
    148       pad_pattern <<= 1;
    149       if (src >= pad_bits_end) return false;
    150       uint8_t bit = *src;
    151       src++;
    152       JXL_ASSERT(bit <= 1);
    153       pad_pattern |= bit;
    154     }
    155     *pad_bits = src;
    156   }
    157 
          // At most 8 buffered bytes are emitted below, each possibly followed by a
          // stuffed zero byte.
    158   Reserve(bw, 16);
    159 
          // Emit every complete byte held in the accumulator, top byte first.
    160   while (bw->put_bits <= 56) {
    161     int c = (bw->put_buffer >> 56) & 0xFF;
    162     EmitByte(bw, c);
    163     bw->put_buffer <<= 8;
    164     bw->put_bits += 8;
    165   }
          // A partial byte remains: keep its data bits and fill the rest with the
          // padding pattern.
    166   if (bw->put_bits < 64) {
    167     int pad_mask = 0xFFu >> (64 - bw->put_bits);
    168     int c = ((bw->put_buffer >> 56) & ~pad_mask) | pad_pattern;
    169     EmitByte(bw, c);
    170   }
    171   bw->put_buffer = 0;
    172   bw->put_bits = 64;
    173 
    174   return true;
    175 }
    176 
    177 void JpegBitWriterFinish(JpegBitWriter* bw) {
    178   if (bw->pos == 0) return;
    179   bw->chunk.len = bw->pos;
    180   bw->output->emplace_back(std::move(bw->chunk));
    181   bw->chunk = OutputChunk(nullptr, 0);
    182   bw->data = nullptr;
    183   bw->pos = 0;
    184 }
    185 
    186 void DCTCodingStateInit(DCTCodingState* s) {
    187   s->eob_run_ = 0;
    188   s->cur_ac_huff_ = nullptr;
    189   s->refinement_bits_.clear();
    190   s->refinement_bits_.reserve(64);
    191 }
    192 
    193 JXL_INLINE void WriteSymbol(int symbol, HuffmanCodeTable* table,
    194                             JpegBitWriter* bw) {
    195   WriteBits(bw, table->depth[symbol], table->code[symbol]);
    196 }
    197 
    198 JXL_INLINE void WriteSymbolBits(int symbol, HuffmanCodeTable* table,
    199                                 JpegBitWriter* bw, int nbits, uint64_t bits) {
    200   WriteBits(bw, nbits + table->depth[symbol],
    201             bits | (table->code[symbol] << nbits));
    202 }
    203 
    204 // Emit all buffered data to the bit stream using the given Huffman code and
    205 // bit writer.
    206 JXL_INLINE void Flush(DCTCodingState* s, JpegBitWriter* bw) {
    207   if (s->eob_run_ > 0) {
    208     Reserve(bw, 16);
    209     int nbits = FloorLog2Nonzero<uint32_t>(s->eob_run_);
    210     int symbol = nbits << 4u;
    211     WriteSymbol(symbol, s->cur_ac_huff_, bw);
    212     if (nbits > 0) {
    213       WriteBits(bw, nbits, s->eob_run_ & ((1 << nbits) - 1));
    214     }
    215     s->eob_run_ = 0;
    216   }
    217   const size_t kStride = 124;  // (515 - 16) / 2 / 2
    218   size_t num_words = s->refinement_bits_count_ >> 4;
    219   size_t i = 0;
    220   while (i < num_words) {
    221     size_t limit = std::min(i + kStride, num_words);
    222     Reserve(bw, 512);
    223     for (; i < limit; ++i) {
    224       WriteBits(bw, 16, s->refinement_bits_[i]);
    225     }
    226   }
    227   Reserve(bw, 16);
    228   size_t tail = s->refinement_bits_count_ & 0xF;
    229   if (tail) {
    230     WriteBits(bw, tail, s->refinement_bits_.back());
    231   }
    232   s->refinement_bits_.clear();
    233   s->refinement_bits_count_ = 0;
    234 }
    235 
    236 // Buffer some more data at the end-of-band (the last non-zero or newly
    237 // non-zero coefficient within the [Ss, Se] spectral band).
        //
        // Extends the pending end-of-band run by one block and appends
        // |new_bits_count| refinement bits (one bit per element of |new_bits_array|)
        // to s->refinement_bits_, which stores them packed into 16-bit words.
        // The bits are first packed into one 64-bit word, so |new_bits_count| must
        // not exceed 64. |new_bits_array| is not read when |new_bits_count| is 0.
        // Once the run reaches 0x7FFF blocks everything buffered is flushed to |bw|.
    238 JXL_INLINE void BufferEndOfBand(DCTCodingState* s, HuffmanCodeTable* ac_huff,
    239                                 const int* new_bits_array,
    240                                 size_t new_bits_count, JpegBitWriter* bw) {
          // A new run starts: remember the table Flush() will use to code it.
    241   if (s->eob_run_ == 0) {
    242     s->cur_ac_huff_ = ac_huff;
    243   }
    244   ++s->eob_run_;
    245   if (new_bits_count) {
            // Pack the bits into one word, first array element in the highest
            // position.
    246     uint64_t new_bits = 0;
    247     for (size_t i = 0; i < new_bits_count; ++i) {
    248       new_bits = (new_bits << 1) | new_bits_array[i];
    249     }
            // A non-zero tail means the last stored 16-bit word is only partly
            // filled.
    250     size_t tail = s->refinement_bits_count_ & 0xF;
    251     if (tail) {  // First stuff the tail item
    252       size_t stuff_bits_count = std::min(16 - tail, new_bits_count);
    253       uint16_t stuff_bits = new_bits >> (new_bits_count - stuff_bits_count);
    254       stuff_bits &= ((1u << stuff_bits_count) - 1);
    255       s->refinement_bits_.back() =
    256           (s->refinement_bits_.back() << stuff_bits_count) | stuff_bits;
    257       new_bits_count -= stuff_bits_count;
    258       s->refinement_bits_count_ += stuff_bits_count;
    259     }
            // Store complete 16-bit words, taking the highest remaining bits first.
    260     while (new_bits_count >= 16) {
    261       s->refinement_bits_.push_back(new_bits >> (new_bits_count - 16));
    262       new_bits_count -= 16;
    263       s->refinement_bits_count_ += 16;
    264     }
            // Fewer than 16 bits are left: they start a new, partly filled word.
    265     if (new_bits_count) {
    266       s->refinement_bits_.push_back(new_bits & ((1u << new_bits_count) - 1));
    267       s->refinement_bits_count_ += new_bits_count;
    268     }
    269   }
    270 
          // Flush() codes the run with symbol (FloorLog2(run) << 4). A run longer
          // than 0x7FFF would need category 15, i.e. symbol 0xF0, which this file
          // writes for 16-zero runs, so the run is written out now.
    271   if (s->eob_run_ == 0x7FFF) {
    272     Flush(s, bw);
    273   }
    274 }
    275 
    // Fills table->depth[] and table->code[] (indexed by symbol value) from the
    // per-length symbol counts and the symbol list of |huff|.
    // Codes are assigned in order of increasing length, with consecutive code
    // values within one length. The last listed symbol is used as a terminator
    // and receives no code.
    // Returns false if the counts describe more than kJpegHuffmanAlphabetSize + 1
    // symbols. When all counts for lengths 1..kJpegHuffmanMaxBitLength are zero,
    // returns true without touching |table|. huff.counts[0] is not read.
    276 bool BuildHuffmanCodeTable(const JPEGHuffmanCode& huff,
    277                            HuffmanCodeTable* table) {
    278   int huff_code[kJpegHuffmanAlphabetSize];
    279   // +1 for a sentinel element.
    280   uint32_t huff_size[kJpegHuffmanAlphabetSize + 1];
          // Expand the per-length counts into one code length per symbol.
    281   int p = 0;
    282   for (size_t l = 1; l <= kJpegHuffmanMaxBitLength; ++l) {
    283     int i = huff.counts[l];
    284     if (p + i > kJpegHuffmanAlphabetSize + 1) {
    285       return false;
    286     }
    287     while (i--) huff_size[p++] = l;
    288   }
    289 
    290   if (p == 0) {
    291     return true;
    292   }
    293 
    294   // Reuse sentinel element.
          // The length of the last symbol is overwritten with 0, which terminates
          // both loops below.
    295   int last_p = p - 1;
    296   huff_size[last_p] = 0;
    297 
          // Assign code values: count upwards within one length, shift left by one
          // bit when moving on to the next length.
    298   int code = 0;
    299   uint32_t si = huff_size[0];
    300   p = 0;
    301   while (huff_size[p]) {
    302     while ((huff_size[p]) == si) {
    303       huff_code[p++] = code;
    304       code++;
    305     }
    306     code <<= 1;
    307     si++;
    308   }
          // Publish length and code under each symbol's value; the terminator entry
          // at last_p is skipped.
    309   for (p = 0; p < last_p; p++) {
    310     int i = huff.values[p];
    311     table->depth[i] = huff_size[p];
    312     table->code[i] = huff_code[p];
    313   }
    314   return true;
    315 }
    316 
    317 bool EncodeSOI(SerializationState* state) {
    318   state->output_queue.push_back(OutputChunk({0xFF, 0xD8}));
    319   return true;
    320 }
    321 
    322 bool EncodeEOI(const JPEGData& jpg, SerializationState* state) {
    323   state->output_queue.push_back(OutputChunk({0xFF, 0xD9}));
    324   state->output_queue.emplace_back(jpg.tail_data);
    325   return true;
    326 }
    327 
    328 bool EncodeSOF(const JPEGData& jpg, uint8_t marker, SerializationState* state) {
    329   if (marker <= 0xC2) state->is_progressive = (marker == 0xC2);
    330 
    331   const size_t n_comps = jpg.components.size();
    332   const size_t marker_len = 8 + 3 * n_comps;
    333   state->output_queue.emplace_back(marker_len + 2);
    334   uint8_t* data = state->output_queue.back().buffer->data();
    335   size_t pos = 0;
    336   data[pos++] = 0xFF;
    337   data[pos++] = marker;
    338   data[pos++] = marker_len >> 8u;
    339   data[pos++] = marker_len & 0xFFu;
    340   data[pos++] = kJpegPrecision;
    341   data[pos++] = jpg.height >> 8u;
    342   data[pos++] = jpg.height & 0xFFu;
    343   data[pos++] = jpg.width >> 8u;
    344   data[pos++] = jpg.width & 0xFFu;
    345   data[pos++] = n_comps;
    346   for (size_t i = 0; i < n_comps; ++i) {
    347     data[pos++] = jpg.components[i].id;
    348     data[pos++] = ((jpg.components[i].h_samp_factor << 4u) |
    349                    (jpg.components[i].v_samp_factor));
    350     const size_t quant_idx = jpg.components[i].quant_idx;
    351     if (quant_idx >= jpg.quant.size()) return false;
    352     data[pos++] = jpg.quant[quant_idx].index;
    353   }
    354   return true;
    355 }
    356 
    357 bool EncodeSOS(const JPEGData& jpg, const JPEGScanInfo& scan_info,
    358                SerializationState* state) {
    359   const size_t n_scans = scan_info.num_components;
    360   const size_t marker_len = 6 + 2 * n_scans;
    361   state->output_queue.emplace_back(marker_len + 2);
    362   uint8_t* data = state->output_queue.back().buffer->data();
    363   size_t pos = 0;
    364   data[pos++] = 0xFF;
    365   data[pos++] = 0xDA;
    366   data[pos++] = marker_len >> 8u;
    367   data[pos++] = marker_len & 0xFFu;
    368   data[pos++] = n_scans;
    369   for (size_t i = 0; i < n_scans; ++i) {
    370     const JPEGComponentScanInfo& si = scan_info.components[i];
    371     if (si.comp_idx >= jpg.components.size()) return false;
    372     data[pos++] = jpg.components[si.comp_idx].id;
    373     data[pos++] = (si.dc_tbl_idx << 4u) + si.ac_tbl_idx;
    374   }
    375   data[pos++] = scan_info.Ss;
    376   data[pos++] = scan_info.Se;
    377   data[pos++] = ((scan_info.Ah << 4u) | (scan_info.Al));
    378   return true;
    379 }
    380 
    381 bool EncodeDHT(const JPEGData& jpg, SerializationState* state) {
    382   const std::vector<JPEGHuffmanCode>& huffman_code = jpg.huffman_code;
    383 
    384   size_t marker_len = 2;
    385   for (size_t i = state->dht_index; i < huffman_code.size(); ++i) {
    386     const JPEGHuffmanCode& huff = huffman_code[i];
    387     marker_len += kJpegHuffmanMaxBitLength;
    388     for (size_t j = 0; j < huff.counts.size(); ++j) {
    389       marker_len += huff.counts[j];
    390     }
    391     if (huff.is_last) break;
    392   }
    393   state->output_queue.emplace_back(marker_len + 2);
    394   uint8_t* data = state->output_queue.back().buffer->data();
    395   size_t pos = 0;
    396   data[pos++] = 0xFF;
    397   data[pos++] = 0xC4;
    398   data[pos++] = marker_len >> 8u;
    399   data[pos++] = marker_len & 0xFFu;
    400   while (true) {
    401     const size_t huffman_code_index = state->dht_index++;
    402     if (huffman_code_index >= huffman_code.size()) {
    403       return false;
    404     }
    405     const JPEGHuffmanCode& huff = huffman_code[huffman_code_index];
    406     size_t index = huff.slot_id;
    407     HuffmanCodeTable* huff_table;
    408     if (index & 0x10) {
    409       index -= 0x10;
    410       huff_table = &state->ac_huff_table[index];
    411     } else {
    412       huff_table = &state->dc_huff_table[index];
    413     }
    414     // TODO(eustas): cache
    415     huff_table->InitDepths(127);
    416     if (!BuildHuffmanCodeTable(huff, huff_table)) {
    417       return false;
    418     }
    419     huff_table->initialized = true;
    420     size_t total_count = 0;
    421     size_t max_length = 0;
    422     for (size_t i = 0; i < huff.counts.size(); ++i) {
    423       if (huff.counts[i] != 0) {
    424         max_length = i;
    425       }
    426       total_count += huff.counts[i];
    427     }
    428     --total_count;
    429     data[pos++] = huff.slot_id;
    430     for (size_t i = 1; i <= kJpegHuffmanMaxBitLength; ++i) {
    431       data[pos++] = (i == max_length ? huff.counts[i] - 1 : huff.counts[i]);
    432     }
    433     for (size_t i = 0; i < total_count; ++i) {
    434       data[pos++] = huff.values[i];
    435     }
    436     if (huff.is_last) break;
    437   }
    438   return true;
    439 }
    440 
    441 bool EncodeDQT(const JPEGData& jpg, SerializationState* state) {
    442   int marker_len = 2;
    443   for (size_t i = state->dqt_index; i < jpg.quant.size(); ++i) {
    444     const JPEGQuantTable& table = jpg.quant[i];
    445     marker_len += 1 + (table.precision ? 2 : 1) * kDCTBlockSize;
    446     if (table.is_last) break;
    447   }
    448   state->output_queue.emplace_back(marker_len + 2);
    449   uint8_t* data = state->output_queue.back().buffer->data();
    450   size_t pos = 0;
    451   data[pos++] = 0xFF;
    452   data[pos++] = 0xDB;
    453   data[pos++] = marker_len >> 8u;
    454   data[pos++] = marker_len & 0xFFu;
    455   while (true) {
    456     const size_t idx = state->dqt_index++;
    457     if (idx >= jpg.quant.size()) {
    458       return false;  // corrupt input
    459     }
    460     const JPEGQuantTable& table = jpg.quant[idx];
    461     data[pos++] = (table.precision << 4u) + table.index;
    462     for (size_t i = 0; i < kDCTBlockSize; ++i) {
    463       int val_idx = kJPEGNaturalOrder[i];
    464       int val = table.values[val_idx];
    465       if (table.precision) {
    466         data[pos++] = val >> 8u;
    467       }
    468       data[pos++] = val & 0xFFu;
    469     }
    470     if (table.is_last) break;
    471   }
    472   return true;
    473 }
    474 
    475 bool EncodeDRI(const JPEGData& jpg, SerializationState* state) {
    476   state->seen_dri_marker = true;
    477   OutputChunk dri_marker = {0xFF,
    478                             0xDD,
    479                             0,
    480                             4,
    481                             static_cast<uint8_t>(jpg.restart_interval >> 8),
    482                             static_cast<uint8_t>(jpg.restart_interval & 0xFF)};
    483   state->output_queue.push_back(std::move(dri_marker));
    484   return true;
    485 }
    486 
    487 bool EncodeRestart(uint8_t marker, SerializationState* state) {
    488   state->output_queue.push_back(OutputChunk({0xFF, marker}));
    489   return true;
    490 }
    491 
    492 bool EncodeAPP(const JPEGData& jpg, uint8_t marker, SerializationState* state) {
    493   // TODO(eustas): check that marker corresponds to payload?
    494   (void)marker;
    495 
    496   size_t app_index = state->app_index++;
    497   if (app_index >= jpg.app_data.size()) return false;
    498   state->output_queue.push_back(OutputChunk({0xFF}));
    499   state->output_queue.emplace_back(jpg.app_data[app_index]);
    500   return true;
    501 }
    502 
    503 bool EncodeCOM(const JPEGData& jpg, SerializationState* state) {
    504   size_t com_index = state->com_index++;
    505   if (com_index >= jpg.com_data.size()) return false;
    506   state->output_queue.push_back(OutputChunk({0xFF}));
    507   state->output_queue.emplace_back(jpg.com_data[com_index]);
    508   return true;
    509 }
    510 
    511 bool EncodeInterMarkerData(const JPEGData& jpg, SerializationState* state) {
    512   size_t index = state->data_index++;
    513   if (index >= jpg.inter_marker_data.size()) return false;
    514   state->output_queue.emplace_back(jpg.inter_marker_data[index]);
    515   return true;
    516 }
    517 
    // Writes one 64-coefficient block of a sequential scan: the difference of
    // coeffs[0] against *last_dc_coeff coded with |dc_huff|, then the remaining
    // coefficients in kJPEGNaturalOrder as (zero run, size) symbols plus value
    // bits coded with |ac_huff|.
    // *last_dc_coeff is set to coeffs[0].
    // |num_zero_runs| extra 0xF0 symbols are written after the last non-zero
    // coefficient; symbol 0 is written afterwards if trailing zeros remain.
    // Returns false when the OR of all AC magnitudes is negative (presumably an
    // AC coefficient whose magnitude does not fit into coeff_t - confirm); the
    // bits have been written regardless.
    // No Reserve() is done here; the caller has to provide the output space.
    518 bool EncodeDCTBlockSequential(const coeff_t* coeffs, HuffmanCodeTable* dc_huff,
    519                               HuffmanCodeTable* ac_huff, int num_zero_runs,
    520                               coeff_t* last_dc_coeff, JpegBitWriter* bw) {
    521   coeff_t temp2;
    522   coeff_t temp;
    523   coeff_t litmus = 0;
    524   temp2 = coeffs[0];
    525   temp = temp2 - *last_dc_coeff;
    526   *last_dc_coeff = temp2;
          // Branch-free sign handling: temp2 becomes the sign mask (0 or -1), temp
          // becomes diff for diff >= 0 and diff - 1 for diff < 0 (its low bits are
          // what gets written), and finally temp2 becomes |diff|.
    527   temp2 = temp >> (8 * sizeof(coeff_t) - 1);
    528   temp += temp2;
    529   temp2 ^= temp;
    530 
    531   int dc_nbits = (temp2 == 0) ? 0 : (FloorLog2Nonzero<uint32_t>(temp2) + 1);
    532   WriteSymbol(dc_nbits, dc_huff, bw);
    533 #if JXL_FALSE
    534   // If the input is corrupt, this could be triggered. Checking is
    535   // costly though, so it makes more sense to avoid this branch.
    536   // (producing a corrupt JPEG when the input is corrupt, instead
    537   // of catching it and returning error)
    538   if (dc_nbits >= 12) return false;
    539 #endif
    540   if (dc_nbits) {
    541     WriteBits(bw, dc_nbits, temp & ((1u << dc_nbits) - 1));
    542   }
          // r counts the zero coefficients seen since the last non-zero one.
    543   int16_t r = 0;
    544 
    545   for (size_t i = 1; i < 64; i++) {
    546     temp = coeffs[kJPEGNaturalOrder[i]];
    547     if (temp == 0) {
    548       r++;
    549     } else {
            // Same sign handling as for the DC difference above.
    550       temp2 = temp >> (8 * sizeof(coeff_t) - 1);
    551       temp += temp2;
    552       temp2 ^= temp;
            // Each 0xF0 symbol stands for 16 zeros. r is at most 62 here, so three
            // checks are enough to bring it below 16.
    553       if (JXL_UNLIKELY(r > 15)) {
    554         WriteSymbol(0xf0, ac_huff, bw);
    555         r -= 16;
    556         if (r > 15) {
    557           WriteSymbol(0xf0, ac_huff, bw);
    558           r -= 16;
    559         }
    560         if (r > 15) {
    561           WriteSymbol(0xf0, ac_huff, bw);
    562           r -= 16;
    563         }
    564       }
    565       litmus |= temp2;
    566       int ac_nbits =
    567           FloorLog2Nonzero<uint32_t>(static_cast<uint16_t>(temp2)) + 1;
            // Symbol: zero run in the upper nibble, bit count in the lower nibble.
    568       int symbol = (r << 4u) + ac_nbits;
    569       WriteSymbolBits(symbol, ac_huff, bw, ac_nbits,
    570                       temp & ((1 << ac_nbits) - 1));
    571       r = 0;
    572     }
    573   }
    574 
          // Extra 0xF0 symbols requested by the caller; each one takes 16 off the
          // count of trailing zeros.
    575   for (int i = 0; i < num_zero_runs; ++i) {
    576     WriteSymbol(0xf0, ac_huff, bw);
    577     r -= 16;
    578   }
          // Zeros still left at the end of the block: write symbol 0.
    579   if (r > 0) {
    580     WriteSymbol(0, ac_huff, bw);
    581   }
    582   return (litmus >= 0);
    583 }
    584 
    // Writes one block of a progressive scan covering the coefficients Ss..Se (in
    // kJPEGNaturalOrder), with every value shifted right by |Al|.
    // When Ss is 0 the DC value (coeffs[0] >> Al) is coded first as a difference
    // against *last_dc_coeff, which is then set to the shifted DC value.
    // Non-zero AC values are written as (zero run, size) symbols plus value bits.
    // A block ending in zeros extends the end-of-band run kept in |coding_state|;
    // that run is flushed immediately when Ss was 0 on entry.
    // |num_zero_runs| extra 0xF0 symbols are written after the last non-zero
    // coefficient.
    // Returns false when a value is still negative after negation.
    // No Reserve() is done here for the symbols written directly.
    585 bool EncodeDCTBlockProgressive(const coeff_t* coeffs, HuffmanCodeTable* dc_huff,
    586                                HuffmanCodeTable* ac_huff, int Ss, int Se,
    587                                int Al, int num_zero_runs,
    588                                DCTCodingState* coding_state,
    589                                coeff_t* last_dc_coeff, JpegBitWriter* bw) {
          // Evaluated before Ss is incremented below: a scan that includes the DC
          // coefficient does not keep an end-of-band run across blocks.
    590   bool eob_run_allowed = Ss > 0;
    591   coeff_t temp2;
    592   coeff_t temp;
    593   if (Ss == 0) {
    594     temp2 = coeffs[0] >> Al;
    595     temp = temp2 - *last_dc_coeff;
    596     *last_dc_coeff = temp2;
    597     temp2 = temp;
            // temp becomes |diff| (used for the bit count); for negative values
            // temp2 becomes diff - 1, whose low bits are the ones written.
    598     if (temp < 0) {
    599       temp = -temp;
    600       if (temp < 0) return false;
    601       temp2--;
    602     }
    603     int nbits = (temp == 0) ? 0 : (FloorLog2Nonzero<uint32_t>(temp) + 1);
    604     WriteSymbol(nbits, dc_huff, bw);
    605     if (nbits) {
    606       WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1));
    607     }
    608     ++Ss;
    609   }
    610   if (Ss > Se) {
    611     return true;
    612   }
          // r counts the zero (after shifting) coefficients since the last non-zero
          // one.
    613   int r = 0;
    614   for (int k = Ss; k <= Se; ++k) {
    615     temp = coeffs[kJPEGNaturalOrder[k]];
    616     if (temp == 0) {
    617       r++;
    618       continue;
    619     }
            // temp becomes the shifted magnitude; temp2 holds the bits to write
            // (the complement of the magnitude for negative values).
    620     if (temp < 0) {
    621       temp = -temp;
    622       if (temp < 0) return false;
    623       temp >>= Al;
    624       temp2 = ~temp;
    625     } else {
    626       temp >>= Al;
    627       temp2 = temp;
    628     }
            // The value became zero through the shift: treat it like a zero.
    629     if (temp == 0) {
    630       r++;
    631       continue;
    632     }
            // A pending end-of-band run has to be written before this block's
            // symbols.
    633     Flush(coding_state, bw);
    634     while (r > 15) {
    635       WriteSymbol(0xf0, ac_huff, bw);
    636       r -= 16;
    637     }
    638     int nbits = FloorLog2Nonzero<uint32_t>(temp) + 1;
    639     int symbol = (r << 4u) + nbits;
    640     WriteSymbol(symbol, ac_huff, bw);
    641     WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1));
    642     r = 0;
    643   }
    644   if (num_zero_runs > 0) {
    645     Flush(coding_state, bw);
    646     for (int i = 0; i < num_zero_runs; ++i) {
    647       WriteSymbol(0xf0, ac_huff, bw);
    648       r -= 16;
    649     }
    650   }
          // Zeros left at the end of the band: extend the end-of-band run (no
          // refinement bits are attached in this kind of scan).
    651   if (r > 0) {
    652     BufferEndOfBand(coding_state, ac_huff, nullptr, 0, bw);
    653     if (!eob_run_allowed) {
    654       Flush(coding_state, bw);
    655     }
    656   }
    657   return true;
    658 }
    659 
    // Writes one block of a refinement scan covering the coefficients Ss..Se (in
    // kJPEGNaturalOrder) at bit position |Al|.
    // When Ss is 0, bit |Al| of coeffs[0] is written first.
    // For the AC part each coefficient is classified by |coeff| >> Al:
    //   0  - counted as part of the current zero run,
    //   1  - written as symbol (run << 4) + 1 followed by a sign bit,
    //   >1 - contributes one correction bit (the lowest bit of the shifted
    //        magnitude), which is written after the next symbol.
    // Zeros and correction bits left over at the end of the band are handed to
    // BufferEndOfBand(); the buffered state is flushed immediately when Ss was 0
    // on entry.
    // Always returns true.
    660 bool EncodeRefinementBits(const coeff_t* coeffs, HuffmanCodeTable* ac_huff,
    661                           int Ss, int Se, int Al, DCTCodingState* coding_state,
    662                           JpegBitWriter* bw) {
          // Evaluated before Ss is incremented below.
    663   bool eob_run_allowed = Ss > 0;
    664   if (Ss == 0) {
    665     // Emit next bit of DC component.
    666     WriteBits(bw, 1, (coeffs[0] >> Al) & 1);
    667     ++Ss;
    668   }
    669   if (Ss > Se) {
    670     return true;
    671   }
          // First pass: shifted magnitudes, and the position of the last
          // coefficient whose shifted magnitude is exactly 1 (eob stays 0 if there
          // is none).
    672   int abs_values[kDCTBlockSize];
    673   int eob = 0;
    674   for (int k = Ss; k <= Se; k++) {
    675     const coeff_t abs_val = std::abs(coeffs[kJPEGNaturalOrder[k]]);
    676     abs_values[k] = abs_val >> Al;
    677     if (abs_values[k] == 1) {
    678       eob = k;
    679     }
    680   }
          // Second pass: r is the current zero run; refinement_bits collects the
          // correction bits seen since the last symbol was written.
    681   int r = 0;
    682   int refinement_bits[kDCTBlockSize];
    683   size_t refinement_bits_count = 0;
    684   for (int k = Ss; k <= Se; k++) {
    685     if (abs_values[k] == 0) {
    686       r++;
    687       continue;
    688     }
            // Runs of 16 zeros are written as 0xF0 symbols, but only up to the
            // position eob; each is followed by the correction bits collected so
            // far.
    689     while (r > 15 && k <= eob) {
    690       Flush(coding_state, bw);
    691       WriteSymbol(0xf0, ac_huff, bw);
    692       r -= 16;
    693       for (size_t i = 0; i < refinement_bits_count; ++i) {
    694         WriteBits(bw, 1, refinement_bits[i]);
    695       }
    696       refinement_bits_count = 0;
    697     }
    698     if (abs_values[k] > 1) {
    699       refinement_bits[refinement_bits_count++] = abs_values[k] & 1u;
    700       continue;
    701     }
            // Shifted magnitude is exactly 1: write the run symbol, the sign bit
            // (0 for negative, 1 otherwise) and the collected correction bits.
    702     Flush(coding_state, bw);
    703     int symbol = (r << 4u) + 1;
    704     int new_non_zero_bit = (coeffs[kJPEGNaturalOrder[k]] < 0) ? 0 : 1;
    705     WriteSymbol(symbol, ac_huff, bw);
    706     WriteBits(bw, 1, new_non_zero_bit);
    707     for (size_t i = 0; i < refinement_bits_count; ++i) {
    708       WriteBits(bw, 1, refinement_bits[i]);
    709     }
    710     refinement_bits_count = 0;
    711     r = 0;
    712   }
          // Anything left over becomes part of the end-of-band run.
    713   if (r > 0 || refinement_bits_count) {
    714     BufferEndOfBand(coding_state, ac_huff, refinement_bits,
    715                     refinement_bits_count, bw);
    716     if (!eob_run_allowed) {
    717       Flush(coding_state, bw);
    718     }
    719   }
    720   return true;
    721 }
    722 
    723 template <int kMode>
    724 SerializationStatus JXL_NOINLINE DoEncodeScan(const JPEGData& jpg,
    725                                               SerializationState* state) {
    726   const JPEGScanInfo& scan_info = jpg.scan_info[state->scan_index];
    727   EncodeScanState& ss = state->scan_state;
    728 
    729   const int restart_interval =
    730       state->seen_dri_marker ? jpg.restart_interval : 0;
    731 
    732   const auto get_next_extra_zero_run_index = [&ss, &scan_info]() -> int {
    733     if (ss.extra_zero_runs_pos < scan_info.extra_zero_runs.size()) {
    734       return scan_info.extra_zero_runs[ss.extra_zero_runs_pos].block_idx;
    735     } else {
    736       return -1;
    737     }
    738   };
    739 
    740   const auto get_next_reset_point = [&ss, &scan_info]() -> int {
    741     if (ss.next_reset_point_pos < scan_info.reset_points.size()) {
    742       return scan_info.reset_points[ss.next_reset_point_pos++];
    743     } else {
    744       return -1;
    745     }
    746   };
    747 
    748   if (ss.stage == EncodeScanState::HEAD) {
    749     if (!EncodeSOS(jpg, scan_info, state)) return SerializationStatus::ERROR;
    750     JpegBitWriterInit(&ss.bw, &state->output_queue);
    751     DCTCodingStateInit(&ss.coding_state);
    752     ss.restarts_to_go = restart_interval;
    753     ss.next_restart_marker = 0;
    754     ss.block_scan_index = 0;
    755     ss.extra_zero_runs_pos = 0;
    756     ss.next_extra_zero_run_index = get_next_extra_zero_run_index();
    757     ss.next_reset_point_pos = 0;
    758     ss.next_reset_point = get_next_reset_point();
    759     ss.mcu_y = 0;
    760     memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff));
    761     ss.stage = EncodeScanState::BODY;
    762   }
    763   JpegBitWriter* bw = &ss.bw;
    764   DCTCodingState* coding_state = &ss.coding_state;
    765 
    766   JXL_DASSERT(ss.stage == EncodeScanState::BODY);
    767 
    768   // "Non-interleaved" means color data comes in separate scans, in other words
    769   // each scan can contain only one color component.
    770   const bool is_interleaved = (scan_info.num_components > 1);
    771   int MCUs_per_row = 0;
    772   int MCU_rows = 0;
    773   jpg.CalculateMcuSize(scan_info, &MCUs_per_row, &MCU_rows);
    774   const bool is_progressive = state->is_progressive;
    775   const int Al = is_progressive ? scan_info.Al : 0;
    776   const int Ss = is_progressive ? scan_info.Ss : 0;
    777   const int Se = is_progressive ? scan_info.Se : 63;
    778 
    779   // DC-only is defined by [0..0] spectral range.
    780   const bool want_ac = ((Ss != 0) || (Se != 0));
    781   const bool want_dc = (Ss == 0);
    782   // TODO(user): support streaming decoding again.
    783   const bool complete_ac = true;
    784   const bool has_ac = true;
    785   if (want_ac && !has_ac) return SerializationStatus::NEEDS_MORE_INPUT;
    786 
    787   // |has_ac| implies |complete_dc| but not vice versa; for the sake of
    788   // simplicity we pretend they are equal, because they are separated by just a
    789   // few bytes of input.
    790   const bool complete_dc = has_ac;
    791   const bool complete = want_ac ? complete_ac : complete_dc;
    792   // When "incomplete" |ac_dc| tracks information about current ("incomplete")
    793   // band parsing progress.
    794 
    795   // FIXME: Is this always complete?
    796   // const int last_mcu_y =
    797   //     complete ? MCU_rows : parsing_state.internal->ac_dc.next_mcu_y *
    798   //     v_group;
    799   (void)complete;
    800   const int last_mcu_y = complete ? MCU_rows : 0;
    801 
    802   for (; ss.mcu_y < last_mcu_y; ++ss.mcu_y) {
    803     for (int mcu_x = 0; mcu_x < MCUs_per_row; ++mcu_x) {
    804       // Possibly emit a restart marker.
    805       if (restart_interval > 0 && ss.restarts_to_go == 0) {
    806         Flush(coding_state, bw);
    807         if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) {
    808           return SerializationStatus::ERROR;
    809         }
    810         EmitMarker(bw, 0xD0 + ss.next_restart_marker);
    811         ss.next_restart_marker += 1;
    812         ss.next_restart_marker &= 0x7;
    813         ss.restarts_to_go = restart_interval;
    814         memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff));
    815       }
    816 
    817       // Encode one MCU
    818       for (size_t i = 0; i < scan_info.num_components; ++i) {
    819         const JPEGComponentScanInfo& si = scan_info.components[i];
    820         const JPEGComponent& c = jpg.components[si.comp_idx];
    821         size_t dc_tbl_idx = si.dc_tbl_idx;
    822         size_t ac_tbl_idx = si.ac_tbl_idx;
    823         HuffmanCodeTable* dc_huff = &state->dc_huff_table[dc_tbl_idx];
    824         HuffmanCodeTable* ac_huff = &state->ac_huff_table[ac_tbl_idx];
    825         if (want_dc && !dc_huff->initialized) {
    826           return SerializationStatus::ERROR;
    827         }
    828         if (want_ac && !ac_huff->initialized) {
    829           return SerializationStatus::ERROR;
    830         }
    831         int n_blocks_y = is_interleaved ? c.v_samp_factor : 1;
    832         int n_blocks_x = is_interleaved ? c.h_samp_factor : 1;
    833         for (int iy = 0; iy < n_blocks_y; ++iy) {
    834           for (int ix = 0; ix < n_blocks_x; ++ix) {
    835             int block_y = ss.mcu_y * n_blocks_y + iy;
    836             int block_x = mcu_x * n_blocks_x + ix;
    837             int block_idx = block_y * c.width_in_blocks + block_x;
    838             if (ss.block_scan_index == ss.next_reset_point) {
    839               Flush(coding_state, bw);
    840               ss.next_reset_point = get_next_reset_point();
    841             }
    842             int num_zero_runs = 0;
    843             if (ss.block_scan_index == ss.next_extra_zero_run_index) {
    844               num_zero_runs = scan_info.extra_zero_runs[ss.extra_zero_runs_pos]
    845                                   .num_extra_zero_runs;
    846               ++ss.extra_zero_runs_pos;
    847               ss.next_extra_zero_run_index = get_next_extra_zero_run_index();
    848             }
    849             const coeff_t* coeffs = &c.coeffs[block_idx << 6];
    850             bool ok;
    851             // compressed size per block cannot be more than 512 bytes
    852             Reserve(bw, 512);
    853             if (kMode == 0) {
    854               ok = EncodeDCTBlockSequential(coeffs, dc_huff, ac_huff,
    855                                             num_zero_runs,
    856                                             ss.last_dc_coeff + si.comp_idx, bw);
    857             } else if (kMode == 1) {
    858               ok = EncodeDCTBlockProgressive(
    859                   coeffs, dc_huff, ac_huff, Ss, Se, Al, num_zero_runs,
    860                   coding_state, ss.last_dc_coeff + si.comp_idx, bw);
    861             } else {
    862               ok = EncodeRefinementBits(coeffs, ac_huff, Ss, Se, Al,
    863                                         coding_state, bw);
    864             }
    865             if (!ok) return SerializationStatus::ERROR;
    866             ++ss.block_scan_index;
    867           }
    868         }
    869       }
    870       --ss.restarts_to_go;
    871     }
    872   }
    873   if (ss.mcu_y < MCU_rows) {
    874     if (!bw->healthy) return SerializationStatus::ERROR;
    875     return SerializationStatus::NEEDS_MORE_INPUT;
    876   }
    877   Flush(coding_state, bw);
    878   if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) {
    879     return SerializationStatus::ERROR;
    880   }
    881   JpegBitWriterFinish(bw);
    882   ss.stage = EncodeScanState::HEAD;
    883   state->scan_index++;
    884   if (!bw->healthy) return SerializationStatus::ERROR;
    885 
    886   return SerializationStatus::DONE;
    887 }
    888 
    889 SerializationStatus JXL_INLINE EncodeScan(const JPEGData& jpg,
    890                                           SerializationState* state) {
    891   const JPEGScanInfo& scan_info = jpg.scan_info[state->scan_index];
    892   const bool is_progressive = state->is_progressive;
    893   const int Al = is_progressive ? scan_info.Al : 0;
    894   const int Ah = is_progressive ? scan_info.Ah : 0;
    895   const int Ss = is_progressive ? scan_info.Ss : 0;
    896   const int Se = is_progressive ? scan_info.Se : 63;
    897   const bool need_sequential =
    898       !is_progressive || (Ah == 0 && Al == 0 && Ss == 0 && Se == 63);
    899   if (need_sequential) {
    900     return DoEncodeScan<0>(jpg, state);
    901   } else if (Ah == 0) {
    902     return DoEncodeScan<1>(jpg, state);
    903   } else {
    904     return DoEncodeScan<2>(jpg, state);
    905   }
    906 }
    907 
    908 SerializationStatus SerializeSection(uint8_t marker, SerializationState* state,
    909                                      const JPEGData& jpg) {
    910   const auto to_status = [](bool result) {
    911     return result ? SerializationStatus::DONE : SerializationStatus::ERROR;
    912   };
    913   // TODO(eustas): add and use marker enum
    914   switch (marker) {
    915     case 0xC0:
    916     case 0xC1:
    917     case 0xC2:
    918     case 0xC9:
    919     case 0xCA:
    920       return to_status(EncodeSOF(jpg, marker, state));
    921 
    922     case 0xC4:
    923       return to_status(EncodeDHT(jpg, state));
    924 
    925     case 0xD0:
    926     case 0xD1:
    927     case 0xD2:
    928     case 0xD3:
    929     case 0xD4:
    930     case 0xD5:
    931     case 0xD6:
    932     case 0xD7:
    933       return to_status(EncodeRestart(marker, state));
    934 
    935     case 0xD9:
    936       return to_status(EncodeEOI(jpg, state));
    937 
    938     case 0xDA:
    939       return EncodeScan(jpg, state);
    940 
    941     case 0xDB:
    942       return to_status(EncodeDQT(jpg, state));
    943 
    944     case 0xDD:
    945       return to_status(EncodeDRI(jpg, state));
    946 
    947     case 0xE0:
    948     case 0xE1:
    949     case 0xE2:
    950     case 0xE3:
    951     case 0xE4:
    952     case 0xE5:
    953     case 0xE6:
    954     case 0xE7:
    955     case 0xE8:
    956     case 0xE9:
    957     case 0xEA:
    958     case 0xEB:
    959     case 0xEC:
    960     case 0xED:
    961     case 0xEE:
    962     case 0xEF:
    963       return to_status(EncodeAPP(jpg, marker, state));
    964 
    965     case 0xFE:
    966       return to_status(EncodeCOM(jpg, state));
    967 
    968     case 0xFF:
    969       return to_status(EncodeInterMarkerData(jpg, state));
    970 
    971     default:
    972       return SerializationStatus::ERROR;
    973   }
    974 }
    975 
    976 // TODO(veluca): add streaming support again.
    977 Status WriteJpegInternal(const JPEGData& jpg, const JPEGOutput& out,
    978                          SerializationState* ss) {
    979   const auto maybe_push_output = [&]() -> Status {
    980     if (ss->stage != SerializationState::STAGE_ERROR) {
    981       while (!ss->output_queue.empty()) {
    982         auto& chunk = ss->output_queue.front();
    983         size_t num_written = out(chunk.next, chunk.len);
    984         if (num_written == 0 && chunk.len > 0) {
    985           return StatusMessage(Status(StatusCode::kNotEnoughBytes),
    986                                "Failed to write output");
    987         }
    988         chunk.len -= num_written;
    989         if (chunk.len == 0) {
    990           ss->output_queue.pop_front();
    991         }
    992       }
    993     }
    994     return true;
    995   };
    996 
    997   while (true) {
    998     switch (ss->stage) {
    999       case SerializationState::STAGE_INIT: {
   1000         // Valid Brunsli requires, at least, 0xD9 marker.
   1001         // This might happen on corrupted stream, or on unconditioned JPEGData.
   1002         // TODO(eustas): check D9 in the only one and is the last one.
   1003         if (jpg.marker_order.empty()) {
   1004           ss->stage = SerializationState::STAGE_ERROR;
   1005           break;
   1006         }
   1007         ss->dc_huff_table.resize(kMaxHuffmanTables);
   1008         ss->ac_huff_table.resize(kMaxHuffmanTables);
   1009         if (jpg.has_zero_padding_bit) {
   1010           ss->pad_bits = jpg.padding_bits.data();
   1011           ss->pad_bits_end = ss->pad_bits + jpg.padding_bits.size();
   1012         }
   1013 
   1014         EncodeSOI(ss);
   1015         JXL_QUIET_RETURN_IF_ERROR(maybe_push_output());
   1016         ss->stage = SerializationState::STAGE_SERIALIZE_SECTION;
   1017         break;
   1018       }
   1019 
   1020       case SerializationState::STAGE_SERIALIZE_SECTION: {
   1021         if (ss->section_index >= jpg.marker_order.size()) {
   1022           ss->stage = SerializationState::STAGE_DONE;
   1023           break;
   1024         }
   1025         uint8_t marker = jpg.marker_order[ss->section_index];
   1026         SerializationStatus status = SerializeSection(marker, ss, jpg);
   1027         if (status == SerializationStatus::ERROR) {
   1028           JXL_WARNING("Failed to encode marker 0x%.2x", marker);
   1029           ss->stage = SerializationState::STAGE_ERROR;
   1030           break;
   1031         }
   1032         JXL_QUIET_RETURN_IF_ERROR(maybe_push_output());
   1033         if (status == SerializationStatus::NEEDS_MORE_INPUT) {
   1034           return JXL_FAILURE("Incomplete serialization data");
   1035         } else if (status != SerializationStatus::DONE) {
   1036           JXL_DASSERT(false);
   1037           ss->stage = SerializationState::STAGE_ERROR;
   1038           break;
   1039         }
   1040         ++ss->section_index;
   1041         break;
   1042       }
   1043 
   1044       case SerializationState::STAGE_DONE:
   1045         JXL_ASSERT(ss->output_queue.empty());
   1046         if (ss->pad_bits != nullptr && ss->pad_bits != ss->pad_bits_end) {
   1047           return JXL_FAILURE("Invalid number of padding bits.");
   1048         }
   1049         return true;
   1050 
   1051       case SerializationState::STAGE_ERROR:
   1052         return JXL_FAILURE("JPEG serialization error");
   1053     }
   1054   }
   1055 }
   1056 
   1057 }  // namespace
   1058 
   1059 Status WriteJpeg(const JPEGData& jpg, const JPEGOutput& out) {
   1060   auto ss = jxl::make_unique<SerializationState>();
   1061   return WriteJpegInternal(jpg, out, ss.get());
   1062 }
   1063 
   1064 }  // namespace jpeg
   1065 }  // namespace jxl