libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

jxl_from_tree.cc (17347B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
#include <errno.h>
#include <jxl/cms.h>
#include <jxl/types.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fstream>
#include <iostream>
#include <istream>
#include <limits>
#include <string>
#include <unordered_map>
     16 
     17 #include "lib/jxl/codec_in_out.h"
     18 #include "lib/jxl/enc_cache.h"
     19 #include "lib/jxl/enc_fields.h"
     20 #include "lib/jxl/enc_frame.h"
     21 #include "lib/jxl/image.h"
     22 #include "lib/jxl/image_metadata.h"
     23 #include "lib/jxl/modular/encoding/enc_debug_tree.h"
     24 #include "lib/jxl/splines.h"
     25 #include "lib/jxl/test_utils.h"  // TODO(eustas): cut this dependency
     26 #include "tools/file_io.h"
     27 
     28 namespace jpegxl {
     29 namespace tools {
     30 
     31 using ::jxl::BitWriter;
     32 using ::jxl::BlendMode;
     33 using ::jxl::CodecInOut;
     34 using ::jxl::CodecMetadata;
     35 using ::jxl::ColorCorrelationMap;
     36 using ::jxl::ColorEncoding;
     37 using ::jxl::ColorTransform;
     38 using ::jxl::CompressParams;
     39 using ::jxl::FrameDimensions;
     40 using ::jxl::FrameInfo;
     41 using ::jxl::Image3F;
     42 using ::jxl::ImageF;
     43 using ::jxl::PaddedBytes;
     44 using ::jxl::PassesEncoderState;
     45 using ::jxl::Predictor;
     46 using ::jxl::PropertyDecisionNode;
     47 using ::jxl::QuantizedSpline;
     48 using ::jxl::Spline;
     49 using ::jxl::Splines;
     50 using ::jxl::Tree;
     51 
     52 namespace {
     53 struct SplineData {
     54   int32_t quantization_adjustment = 1;
     55   std::vector<Spline> splines;
     56 };
     57 
     58 Splines SplinesFromSplineData(const SplineData& spline_data) {
     59   std::vector<QuantizedSpline> quantized_splines;
     60   std::vector<Spline::Point> starting_points;
     61   quantized_splines.reserve(spline_data.splines.size());
     62   starting_points.reserve(spline_data.splines.size());
     63   for (const Spline& spline : spline_data.splines) {
     64     JXL_CHECK(!spline.control_points.empty());
     65     quantized_splines.emplace_back(spline, spline_data.quantization_adjustment,
     66                                    0.0, 1.0);
     67     starting_points.push_back(spline.control_points.front());
     68   }
     69   return Splines(spline_data.quantization_adjustment,
     70                  std::move(quantized_splines), std::move(starting_points));
     71 }
     72 
     73 template <typename F>
     74 bool ParseNode(F& tok, Tree& tree, SplineData& spline_data,
     75                CompressParams& cparams, size_t& W, size_t& H, CodecInOut& io,
     76                JXL_BOOL& have_next, int& x0, int& y0) {
     77   std::unordered_map<std::string, int> property_map = {
     78       {"c", 0},
     79       {"g", 1},
     80       {"y", 2},
     81       {"x", 3},
     82       {"|N|", 4},
     83       {"|W|", 5},
     84       {"N", 6},
     85       {"W", 7},
     86       {"W-WW-NW+NWW", 8},
     87       {"W+N-NW", 9},
     88       {"W-NW", 10},
     89       {"NW-N", 11},
     90       {"N-NE", 12},
     91       {"N-NN", 13},
     92       {"W-WW", 14},
     93       {"WGH", 15},
     94       {"PrevAbs", 16},
     95       {"Prev", 17},
     96       {"PrevAbsErr", 18},
     97       {"PrevErr", 19},
     98       {"PPrevAbs", 20},
     99       {"PPrev", 21},
    100       {"PPrevAbsErr", 22},
    101       {"PPrevErr", 23},
    102       {"Prev1Abs", 16},
    103       {"Prev1", 17},
    104       {"Prev1AbsErr", 18},
    105       {"Prev1Err", 19},
    106   };
    107   for (size_t i = 0; i < 19; i++) {
    108     std::string name_prefix = "Prev" + std::to_string(i + 1);
    109     property_map[name_prefix + "Abs"] = i * 4 + 16;
    110     property_map[name_prefix] = i * 4 + 17;
    111     property_map[name_prefix + "AbsErr"] = i * 4 + 18;
    112     property_map[name_prefix + "Err"] = i * 4 + 19;
    113   }
    114   static const std::unordered_map<std::string, Predictor> predictor_map = {
    115       {"Set", Predictor::Zero},
    116       {"W", Predictor::Left},
    117       {"N", Predictor::Top},
    118       {"AvgW+N", Predictor::Average0},
    119       {"Select", Predictor::Select},
    120       {"Gradient", Predictor::Gradient},
    121       {"Weighted", Predictor::Weighted},
    122       {"NE", Predictor::TopRight},
    123       {"NW", Predictor::TopLeft},
    124       {"WW", Predictor::LeftLeft},
    125       {"AvgW+NW", Predictor::Average1},
    126       {"AvgN+NW", Predictor::Average2},
    127       {"AvgN+NE", Predictor::Average3},
    128       {"AvgAll", Predictor::Average4},
    129   };
    130   auto t = tok();
    131   if (t == "if") {
    132     // Decision node.
    133     int p;
    134     t = tok();
    135     if (!property_map.count(t)) {
    136       fprintf(stderr, "Unexpected property: %s\n", t.c_str());
    137       return false;
    138     }
    139     p = property_map.at(t);
    140     t = tok();
    141     if (t != ">") {
    142       fprintf(stderr, "Expected >, found %s\n", t.c_str());
    143       return false;
    144     }
    145     t = tok();
    146     size_t num = 0;
    147     int split = std::stoi(t, &num);
    148     if (num != t.size()) {
    149       fprintf(stderr, "Invalid splitval: %s\n", t.c_str());
    150       return false;
    151     }
    152     size_t pos = tree.size();
    153     tree.emplace_back(PropertyDecisionNode::Split(p, split, pos + 1));
    154     JXL_RETURN_IF_ERROR(ParseNode(tok, tree, spline_data, cparams, W, H, io,
    155                                   have_next, x0, y0));
    156     tree[pos].rchild = tree.size();
    157   } else if (t == "-") {
    158     // Leaf
    159     t = tok();
    160     Predictor p;
    161     if (!predictor_map.count(t)) {
    162       fprintf(stderr, "Unexpected predictor: %s\n", t.c_str());
    163       return false;
    164     }
    165     p = predictor_map.at(t);
    166     t = tok();
    167     bool subtract = false;
    168     if (t == "-") {
    169       subtract = true;
    170       t = tok();
    171     } else if (t == "+") {
    172       t = tok();
    173     }
    174     size_t num = 0;
    175     int offset = std::stoi(t, &num);
    176     if (num != t.size()) {
    177       fprintf(stderr, "Invalid offset: %s\n", t.c_str());
    178       return false;
    179     }
    180     if (subtract) offset = -offset;
    181     tree.emplace_back(PropertyDecisionNode::Leaf(p, offset));
    182     return true;
    183   } else if (t == "Width") {
    184     t = tok();
    185     size_t num = 0;
    186     W = std::stoul(t, &num);
    187     if (num != t.size()) {
    188       fprintf(stderr, "Invalid width: %s\n", t.c_str());
    189       return false;
    190     }
    191   } else if (t == "Height") {
    192     t = tok();
    193     size_t num = 0;
    194     H = std::stoul(t, &num);
    195     if (num != t.size()) {
    196       fprintf(stderr, "Invalid height: %s\n", t.c_str());
    197       return false;
    198     }
    199   } else if (t == "/*") {
    200     t = tok();
    201     while (t != "*/" && t != "") t = tok();
    202   } else if (t == "Squeeze") {
    203     cparams.responsive = true;
    204   } else if (t == "GroupShift") {
    205     t = tok();
    206     size_t num = 0;
    207     cparams.modular_group_size_shift = std::stoul(t, &num);
    208     if (num != t.size()) {
    209       fprintf(stderr, "Invalid GroupShift: %s\n", t.c_str());
    210       return false;
    211     }
    212   } else if (t == "XYB") {
    213     cparams.color_transform = ColorTransform::kXYB;
    214   } else if (t == "CbYCr") {
    215     cparams.color_transform = ColorTransform::kYCbCr;
    216   } else if (t == "HiddenChannel") {
    217     t = tok();
    218     size_t num = 0;
    219     cparams.move_to_front_from_channel = -1 - std::stoul(t, &num);
    220     if (num != t.size() || num > 16) {
    221       fprintf(stderr, "Invalid HiddenChannel (max 16): %s\n", t.c_str());
    222       return false;
    223     }
    224   } else if (t == "RCT") {
    225     t = tok();
    226     size_t num = 0;
    227     cparams.colorspace = std::stoul(t, &num);
    228     if (num != t.size()) {
    229       fprintf(stderr, "Invalid RCT: %s\n", t.c_str());
    230       return false;
    231     }
    232   } else if (t == "Orientation") {
    233     t = tok();
    234     size_t num = 0;
    235     io.metadata.m.orientation = std::stoul(t, &num);
    236     if (num != t.size()) {
    237       fprintf(stderr, "Invalid Orientation: %s\n", t.c_str());
    238       return false;
    239     }
    240   } else if (t == "Alpha") {
    241     io.metadata.m.SetAlphaBits(io.metadata.m.bit_depth.bits_per_sample);
    242     JXL_ASSIGN_OR_RETURN(ImageF alpha, ImageF::Create(W, H));
    243     io.frames[0].SetAlpha(std::move(alpha));
    244   } else if (t == "Bitdepth") {
    245     t = tok();
    246     size_t num = 0;
    247     io.metadata.m.bit_depth.bits_per_sample = std::stoul(t, &num);
    248     if (num != t.size()) {
    249       fprintf(stderr, "Invalid Bitdepth: %s\n", t.c_str());
    250       return false;
    251     }
    252   } else if (t == "FloatExpBits") {
    253     t = tok();
    254     size_t num = 0;
    255     io.metadata.m.bit_depth.floating_point_sample = true;
    256     io.metadata.m.bit_depth.exponent_bits_per_sample = std::stoul(t, &num);
    257     if (num != t.size()) {
    258       fprintf(stderr, "Invalid FloatExpBits: %s\n", t.c_str());
    259       return false;
    260     }
    261   } else if (t == "FramePos") {
    262     t = tok();
    263     size_t num = 0;
    264     x0 = std::stoi(t, &num);
    265     if (num != t.size()) {
    266       fprintf(stderr, "Invalid FramePos x0: %s\n", t.c_str());
    267       return false;
    268     }
    269     t = tok();
    270     y0 = std::stoi(t, &num);
    271     if (num != t.size()) {
    272       fprintf(stderr, "Invalid FramePos y0: %s\n", t.c_str());
    273       return false;
    274     }
    275   } else if (t == "NotLast") {
    276     have_next = JXL_TRUE;
    277   } else if (t == "Upsample") {
    278     t = tok();
    279     size_t num = 0;
    280     cparams.resampling = std::stoul(t, &num);
    281     if (num != t.size() ||
    282         (cparams.resampling != 1 && cparams.resampling != 2 &&
    283          cparams.resampling != 4 && cparams.resampling != 8)) {
    284       fprintf(stderr, "Invalid Upsample: %s\n", t.c_str());
    285       return false;
    286     }
    287   } else if (t == "Upsample_EC") {
    288     t = tok();
    289     size_t num = 0;
    290     cparams.ec_resampling = std::stoul(t, &num);
    291     if (num != t.size() ||
    292         (cparams.ec_resampling != 1 && cparams.ec_resampling != 2 &&
    293          cparams.ec_resampling != 4 && cparams.ec_resampling != 8)) {
    294       fprintf(stderr, "Invalid Upsample_EC: %s\n", t.c_str());
    295       return false;
    296     }
    297   } else if (t == "Animation") {
    298     io.metadata.m.have_animation = true;
    299     io.metadata.m.animation.tps_numerator = 1000;
    300     io.metadata.m.animation.tps_denominator = 1;
    301     io.frames[0].duration = 100;
    302   } else if (t == "AnimationFPS") {
    303     t = tok();
    304     size_t num = 0;
    305     io.metadata.m.animation.tps_numerator = std::stoul(t, &num);
    306     if (num != t.size()) {
    307       fprintf(stderr, "Invalid numerator: %s\n", t.c_str());
    308       return false;
    309     }
    310     t = tok();
    311     num = 0;
    312     io.metadata.m.animation.tps_denominator = std::stoul(t, &num);
    313     if (num != t.size()) {
    314       fprintf(stderr, "Invalid denominator: %s\n", t.c_str());
    315       return false;
    316     }
    317   } else if (t == "Duration") {
    318     t = tok();
    319     size_t num = 0;
    320     io.frames[0].duration = std::stoul(t, &num);
    321     if (num != t.size()) {
    322       fprintf(stderr, "Invalid Duration: %s\n", t.c_str());
    323       return false;
    324     }
    325   } else if (t == "BlendMode") {
    326     t = tok();
    327     if (t == "kAdd") {
    328       io.frames[0].blendmode = BlendMode::kAdd;
    329     } else if (t == "kReplace") {
    330       io.frames[0].blendmode = BlendMode::kReplace;
    331     } else if (t == "kBlend") {
    332       io.frames[0].blendmode = BlendMode::kBlend;
    333     } else if (t == "kAlphaWeightedAdd") {
    334       io.frames[0].blendmode = BlendMode::kAlphaWeightedAdd;
    335     } else if (t == "kMul") {
    336       io.frames[0].blendmode = BlendMode::kMul;
    337     } else {
    338       fprintf(stderr, "Invalid BlendMode: %s\n", t.c_str());
    339       return false;
    340     }
    341   } else if (t == "SplineQuantizationAdjustment") {
    342     t = tok();
    343     size_t num = 0;
    344     spline_data.quantization_adjustment = std::stoul(t, &num);
    345     if (num != t.size()) {
    346       fprintf(stderr, "Invalid SplineQuantizationAdjustment: %s\n", t.c_str());
    347       return false;
    348     }
    349   } else if (t == "Spline") {
    350     Spline spline;
    351     const auto ParseFloat = [&t, &tok](float& output) {
    352       t = tok();
    353       size_t num = 0;
    354       output = std::stof(t, &num);
    355       if (num != t.size()) {
    356         fprintf(stderr, "Invalid spline data: %s\n", t.c_str());
    357         return false;
    358       }
    359       return true;
    360     };
    361     for (auto& dct : spline.color_dct) {
    362       for (float& coefficient : dct) {
    363         JXL_RETURN_IF_ERROR(ParseFloat(coefficient));
    364       }
    365     }
    366     for (float& coefficient : spline.sigma_dct) {
    367       JXL_RETURN_IF_ERROR(ParseFloat(coefficient));
    368     }
    369 
    370     while (true) {
    371       t = tok();
    372       if (t == "EndSpline") break;
    373       size_t num = 0;
    374       Spline::Point point;
    375       point.x = std::stof(t, &num);
    376       bool ok_x = num == t.size();
    377       auto t_y = tok();
    378       point.y = std::stof(t_y, &num);
    379       if (!ok_x || num != t_y.size()) {
    380         fprintf(stderr, "Invalid spline control point: %s %s\n", t.c_str(),
    381                 t_y.c_str());
    382         return false;
    383       }
    384       spline.control_points.push_back(point);
    385     }
    386 
    387     if (spline.control_points.empty()) {
    388       fprintf(stderr, "Spline with no control point\n");
    389       return false;
    390     }
    391 
    392     spline_data.splines.push_back(std::move(spline));
    393   } else if (t == "Gaborish") {
    394     cparams.gaborish = jxl::Override::kOn;
    395   } else if (t == "DeltaPalette") {
    396     cparams.lossy_palette = true;
    397     cparams.palette_colors = 0;
    398   } else if (t == "EPF") {
    399     t = tok();
    400     size_t num = 0;
    401     cparams.epf = std::stoul(t, &num);
    402     if (num != t.size() || cparams.epf > 3) {
    403       fprintf(stderr, "Invalid EPF: %s\n", t.c_str());
    404       return false;
    405     }
    406   } else if (t == "Noise") {
    407     cparams.manual_noise.resize(8);
    408     for (size_t i = 0; i < 8; i++) {
    409       t = tok();
    410       size_t num = 0;
    411       cparams.manual_noise[i] = std::stof(t, &num);
    412       if (num != t.size()) {
    413         fprintf(stderr, "Invalid noise entry: %s\n", t.c_str());
    414         return false;
    415       }
    416     }
    417   } else if (t == "XYBFactors") {
    418     cparams.manual_xyb_factors.resize(3);
    419     for (size_t i = 0; i < 3; i++) {
    420       t = tok();
    421       size_t num = 0;
    422       cparams.manual_xyb_factors[i] = std::stof(t, &num);
    423       if (num != t.size()) {
    424         fprintf(stderr, "Invalid XYB factor: %s\n", t.c_str());
    425         return false;
    426       }
    427     }
    428   } else {
    429     fprintf(stderr, "Unexpected node type: %s\n", t.c_str());
    430     return false;
    431   }
    432   JXL_RETURN_IF_ERROR(
    433       ParseNode(tok, tree, spline_data, cparams, W, H, io, have_next, x0, y0));
    434   return true;
    435 }
    436 }  // namespace
    437 
    438 int JxlFromTree(const char* in, const char* out, const char* tree_out) {
    439   Tree tree;
    440   SplineData spline_data;
    441   CompressParams cparams = {};
    442   size_t width = 1024;
    443   size_t height = 1024;
    444   int x0 = 0;
    445   int y0 = 0;
    446   cparams.SetLossless();
    447   cparams.responsive = JXL_FALSE;
    448   cparams.resampling = 1;
    449   cparams.ec_resampling = 1;
    450   cparams.modular_group_size_shift = 3;
    451   cparams.colorspace = 0;
    452   CodecInOut io;
    453   int have_next = JXL_FALSE;
    454 
    455   std::istream* f = &std::cin;
    456   std::ifstream file;
    457 
    458   if (strcmp(in, "-") > 0) {
    459     file.open(in, std::ifstream::in);
    460     f = &file;
    461   }
    462 
    463   auto tok = [&f]() {
    464     std::string out;
    465     *f >> out;
    466     return out;
    467   };
    468   if (!ParseNode(tok, tree, spline_data, cparams, width, height, io, have_next,
    469                  x0, y0)) {
    470     return 1;
    471   }
    472 
    473   if (tree_out) {
    474     PrintTree(tree, tree_out);
    475   }
    476   JXL_ASSIGN_OR_RETURN(
    477       Image3F image,
    478       Image3F::Create(width * cparams.resampling, height * cparams.resampling));
    479   io.SetFromImage(std::move(image), ColorEncoding::SRGB());
    480   io.SetSize((width + x0) * cparams.resampling,
    481              (height + y0) * cparams.resampling);
    482   io.metadata.m.color_encoding.DecideIfWantICC(*JxlGetDefaultCms());
    483   cparams.options.zero_tokens = true;
    484   cparams.palette_colors = 0;
    485   cparams.channel_colors_pre_transform_percent = 0;
    486   cparams.channel_colors_percent = 0;
    487   cparams.patches = jxl::Override::kOff;
    488   cparams.already_downsampled = true;
    489   cparams.custom_fixed_tree = tree;
    490   cparams.custom_splines = SplinesFromSplineData(spline_data);
    491   PaddedBytes compressed;
    492 
    493   io.CheckMetadata();
    494   BitWriter writer;
    495 
    496   std::unique_ptr<CodecMetadata> metadata = jxl::make_unique<CodecMetadata>();
    497   *metadata = io.metadata;
    498   JXL_RETURN_IF_ERROR(metadata->size.Set(io.xsize(), io.ysize()));
    499 
    500   metadata->m.xyb_encoded = (cparams.color_transform == ColorTransform::kXYB);
    501   metadata->m.modular_16_bit_buffer_sufficient = false;
    502 
    503   if (cparams.move_to_front_from_channel < -1) {
    504     size_t nch = -1 - cparams.move_to_front_from_channel;
    505     cparams.move_to_front_from_channel = 3 + metadata->m.num_extra_channels;
    506     metadata->m.num_extra_channels += nch;
    507     for (size_t _ = 0; _ < nch; _++) {
    508       metadata->m.extra_channel_info.emplace_back();
    509       auto& eci = metadata->m.extra_channel_info.back();
    510       eci.type = jxl::ExtraChannel::kOptional;
    511       JXL_ASSIGN_OR_DIE(ImageF ch, ImageF::Create(io.xsize(), io.ysize()));
    512       io.frames[0].extra_channels().emplace_back(std::move(ch));
    513     }
    514   }
    515 
    516   JXL_RETURN_IF_ERROR(WriteCodestreamHeaders(metadata.get(), &writer, nullptr));
    517   writer.ZeroPadToByte();
    518 
    519   while (true) {
    520     FrameInfo info;
    521     info.is_last = !FROM_JXL_BOOL(have_next);
    522     if (!info.is_last) info.save_as_reference = 1;
    523 
    524     io.frames[0].origin.x0 = x0;
    525     io.frames[0].origin.y0 = y0;
    526     info.clamp = false;
    527 
    528     JXL_RETURN_IF_ERROR(jxl::EncodeFrame(cparams, info, metadata.get(),
    529                                          io.frames[0], *JxlGetDefaultCms(),
    530                                          nullptr, &writer, nullptr));
    531     if (!have_next) break;
    532     tree.clear();
    533     spline_data.splines.clear();
    534     have_next = JXL_FALSE;
    535     cparams.manual_noise.clear();
    536     if (!ParseNode(tok, tree, spline_data, cparams, width, height, io,
    537                    have_next, x0, y0)) {
    538       return 1;
    539     }
    540     cparams.custom_fixed_tree = tree;
    541     JXL_ASSIGN_OR_RETURN(Image3F image, Image3F::Create(width, height));
    542     io.SetFromImage(std::move(image), ColorEncoding::SRGB());
    543     io.frames[0].blend = true;
    544   }
    545 
    546   compressed = std::move(writer).TakeBytes();
    547 
    548   if (!WriteFile(out, compressed)) {
    549     fprintf(stderr, "Failed to write to \"%s\"\n", out);
    550     return 1;
    551   }
    552 
    553   return 0;
    554 }
    555 }  // namespace tools
    556 }  // namespace jpegxl
    557 
    558 int main(int argc, char** argv) {
    559   if ((argc != 3 && argc != 4) ||
    560       ((strcmp(argv[1], "-") > 0) && !strcmp(argv[1], argv[2]))) {
    561     fprintf(stderr, "Usage: %s tree_in.txt out.jxl [tree_drawing]\n", argv[0]);
    562     return 1;
    563   }
    564   return jpegxl::tools::JxlFromTree(argv[1], argv[2],
    565                                     argc < 4 ? nullptr : argv[3]);
    566 }