libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_noise.cc (13387B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_noise.h"
      7 
      8 #include <stdint.h>
      9 #include <stdlib.h>
     10 
     11 #include <algorithm>
     12 #include <numeric>
     13 #include <utility>
     14 
     15 #include "lib/jxl/base/compiler_specific.h"
     16 #include "lib/jxl/chroma_from_luma.h"
     17 #include "lib/jxl/convolve.h"
     18 #include "lib/jxl/enc_aux_out.h"
     19 #include "lib/jxl/enc_optimize.h"
     20 #include "lib/jxl/image_ops.h"
     21 
     22 namespace jxl {
     23 namespace {
     24 
// Optimization variables: one double per noise LUT point
// (NoiseParams::kNumNoisePoints). Used as the parameter/gradient vector of
// LossFunction and OptimizeNoiseParameters.
using OptimizeArray = optimize::Array<double, NoiseParams::kNumNoisePoints>;
     26 
     27 float GetScoreSumsOfAbsoluteDifferences(const Image3F& opsin, const int x,
     28                                         const int y, const int block_size) {
     29   const int small_bl_size_x = 3;
     30   const int small_bl_size_y = 4;
     31   const int kNumSAD =
     32       (block_size - small_bl_size_x) * (block_size - small_bl_size_y);
     33   // block_size x block_size reference pixels
     34   int counter = 0;
     35   const int offset = 2;
     36 
     37   std::vector<float> sad(kNumSAD, 0);
     38   for (int y_bl = 0; y_bl + small_bl_size_y < block_size; ++y_bl) {
     39     for (int x_bl = 0; x_bl + small_bl_size_x < block_size; ++x_bl) {
     40       float sad_sum = 0;
     41       // size of the center patch, we compare all the patches inside window with
     42       // the center one
     43       for (int cy = 0; cy < small_bl_size_y; ++cy) {
     44         for (int cx = 0; cx < small_bl_size_x; ++cx) {
     45           float wnd = 0.5f * (opsin.PlaneRow(1, y + y_bl + cy)[x + x_bl + cx] +
     46                               opsin.PlaneRow(0, y + y_bl + cy)[x + x_bl + cx]);
     47           float center =
     48               0.5f * (opsin.PlaneRow(1, y + offset + cy)[x + offset + cx] +
     49                       opsin.PlaneRow(0, y + offset + cy)[x + offset + cx]);
     50           sad_sum += std::abs(center - wnd);
     51         }
     52       }
     53       sad[counter++] = sad_sum;
     54     }
     55   }
     56   const int kSamples = (kNumSAD) / 2;
     57   // As with ROAD (rank order absolute distance), we keep the smallest half of
     58   // the values in SAD (we use here the more robust patch SAD instead of
     59   // absolute single-pixel differences).
     60   std::sort(sad.begin(), sad.end());
     61   const float total_sad_sum =
     62       std::accumulate(sad.begin(), sad.begin() + kSamples, 0.0f);
     63   return total_sad_sum / kSamples;
     64 }
     65 
     66 class NoiseHistogram {
     67  public:
     68   static constexpr int kBins = 256;
     69 
     70   NoiseHistogram() { std::fill(bins, bins + kBins, 0); }
     71 
     72   void Increment(const float x) { bins[Index(x)] += 1; }
     73   int Get(const float x) const { return bins[Index(x)]; }
     74   int Bin(const size_t bin) const { return bins[bin]; }
     75 
     76   int Mode() const {
     77     size_t max_idx = 0;
     78     for (size_t i = 0; i < kBins; i++) {
     79       if (bins[i] > bins[max_idx]) max_idx = i;
     80     }
     81     return max_idx;
     82   }
     83 
     84   double Quantile(double q01) const {
     85     const int64_t total = std::accumulate(bins, bins + kBins, int64_t{1});
     86     const int64_t target = static_cast<int64_t>(q01 * total);
     87     // Until sum >= target:
     88     int64_t sum = 0;
     89     size_t i = 0;
     90     for (; i < kBins; ++i) {
     91       sum += bins[i];
     92       // Exact match: assume middle of bin i
     93       if (sum == target) {
     94         return i + 0.5;
     95       }
     96       if (sum > target) break;
     97     }
     98 
     99     // Next non-empty bin (in case histogram is sparsely filled)
    100     size_t next = i + 1;
    101     while (next < kBins && bins[next] == 0) {
    102       ++next;
    103     }
    104 
    105     // Linear interpolation according to how far into next we went
    106     const double excess = target - sum;
    107     const double weight_next = bins[Index(next)] / excess;
    108     return ClampX(next * weight_next + i * (1.0 - weight_next));
    109   }
    110 
    111   // Inter-quartile range
    112   double IQR() const { return Quantile(0.75) - Quantile(0.25); }
    113 
    114  private:
    115   template <typename T>
    116   T ClampX(const T x) const {
    117     return std::min(std::max(static_cast<T>(0), x), static_cast<T>(kBins - 1));
    118   }
    119   size_t Index(const float x) const { return ClampX(static_cast<int>(x)); }
    120 
    121   uint32_t bins[kBins];
    122 };
    123 
    124 std::vector<float> GetSADScoresForPatches(const Image3F& opsin,
    125                                           const size_t block_s,
    126                                           const size_t num_bin,
    127                                           NoiseHistogram* sad_histogram) {
    128   std::vector<float> sad_scores(
    129       (opsin.ysize() / block_s) * (opsin.xsize() / block_s), 0.0f);
    130 
    131   int block_index = 0;
    132 
    133   for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
    134     for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) {
    135       float sad_sc = GetScoreSumsOfAbsoluteDifferences(opsin, x, y, block_s);
    136       sad_scores[block_index++] = sad_sc;
    137       sad_histogram->Increment(sad_sc * num_bin);
    138     }
    139   }
    140   return sad_scores;
    141 }
    142 
    143 float GetSADThreshold(const NoiseHistogram& histogram, const int num_bin) {
    144   // Here we assume that the most patches with similar SAD value is a "flat"
    145   // patches. However, some images might contain regular texture part and
    146   // generate second strong peak at the histogram
    147   // TODO(user) handle bimodal and heavy-tailed case
    148   const int mode = histogram.Mode();
    149   return static_cast<float>(mode) / NoiseHistogram::kBins;
    150 }
    151 
    152 // loss = sum asym * (F(x) - nl)^2 + kReg * num_points * sum (w[i] - w[i+1])^2
    153 // where asym = 1 if F(x) < nl, kAsym if F(x) > nl.
    154 struct LossFunction {
    155   explicit LossFunction(std::vector<NoiseLevel> nl0) : nl(std::move(nl0)) {}
    156 
    157   double Compute(const OptimizeArray& w, OptimizeArray* df,
    158                  bool skip_regularization = false) const {
    159     constexpr double kReg = 0.005;
    160     constexpr double kAsym = 1.1;
    161     double loss_function = 0;
    162     for (size_t i = 0; i < w.size(); i++) {
    163       (*df)[i] = 0;
    164     }
    165     for (auto ind : nl) {
    166       std::pair<int, float> pos = IndexAndFrac(ind.intensity);
    167       JXL_DASSERT(pos.first >= 0 && static_cast<size_t>(pos.first) <
    168                                         NoiseParams::kNumNoisePoints - 1);
    169       double low = w[pos.first];
    170       double hi = w[pos.first + 1];
    171       double val = low * (1.0f - pos.second) + hi * pos.second;
    172       double dist = val - ind.noise_level;
    173       if (dist > 0) {
    174         loss_function += kAsym * dist * dist;
    175         (*df)[pos.first] -= kAsym * (1.0f - pos.second) * dist;
    176         (*df)[pos.first + 1] -= kAsym * pos.second * dist;
    177       } else {
    178         loss_function += dist * dist;
    179         (*df)[pos.first] -= (1.0f - pos.second) * dist;
    180         (*df)[pos.first + 1] -= pos.second * dist;
    181       }
    182     }
    183     if (skip_regularization) return loss_function;
    184     for (size_t i = 0; i + 1 < w.size(); i++) {
    185       double diff = w[i] - w[i + 1];
    186       loss_function += kReg * nl.size() * diff * diff;
    187       (*df)[i] -= kReg * diff * nl.size();
    188       (*df)[i + 1] += kReg * diff * nl.size();
    189     }
    190     return loss_function;
    191   }
    192 
    193   std::vector<NoiseLevel> nl;
    194 };
    195 
    196 void OptimizeNoiseParameters(const std::vector<NoiseLevel>& noise_level,
    197                              NoiseParams* noise_params) {
    198   constexpr double kMaxError = 1e-3;
    199   static const double kPrecision = 1e-8;
    200   static const int kMaxIter = 40;
    201 
    202   float avg = 0;
    203   for (const NoiseLevel& nl : noise_level) {
    204     avg += nl.noise_level;
    205   }
    206   avg /= noise_level.size();
    207 
    208   LossFunction loss_function(noise_level);
    209   OptimizeArray parameter_vector;
    210   for (size_t i = 0; i < parameter_vector.size(); i++) {
    211     parameter_vector[i] = avg;
    212   }
    213 
    214   parameter_vector = optimize::OptimizeWithScaledConjugateGradientMethod(
    215       loss_function, parameter_vector, kPrecision, kMaxIter);
    216 
    217   OptimizeArray df = parameter_vector;
    218   float loss = loss_function.Compute(parameter_vector, &df,
    219                                      /*skip_regularization=*/true) /
    220                noise_level.size();
    221 
    222   // Approximation went too badly: escape with no noise at all.
    223   if (loss > kMaxError) {
    224     noise_params->Clear();
    225     return;
    226   }
    227 
    228   for (size_t i = 0; i < parameter_vector.size(); i++) {
    229     noise_params->lut[i] = std::max(parameter_vector[i], 0.0);
    230   }
    231 }
    232 
    233 std::vector<NoiseLevel> GetNoiseLevel(
    234     const Image3F& opsin, const std::vector<float>& texture_strength,
    235     const float threshold, const size_t block_s) {
    236   std::vector<NoiseLevel> noise_level_per_intensity;
    237 
    238   const int filt_size = 1;
    239   static const float kLaplFilter[filt_size * 2 + 1][filt_size * 2 + 1] = {
    240       {-0.25f, -1.0f, -0.25f},
    241       {-1.0f, 5.0f, -1.0f},
    242       {-0.25f, -1.0f, -0.25f},
    243   };
    244 
    245   // The noise model is built based on channel 0.5 * (X+Y) as we notice that it
    246   // is similar to the model 0.5 * (Y-X)
    247   size_t patch_index = 0;
    248 
    249   for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
    250     for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) {
    251       if (texture_strength[patch_index] <= threshold) {
    252         // Calculate mean value
    253         float mean_int = 0;
    254         for (size_t y_bl = 0; y_bl < block_s; ++y_bl) {
    255           for (size_t x_bl = 0; x_bl < block_s; ++x_bl) {
    256             mean_int += 0.5f * (opsin.PlaneRow(1, y + y_bl)[x + x_bl] +
    257                                 opsin.PlaneRow(0, y + y_bl)[x + x_bl]);
    258           }
    259         }
    260         mean_int /= block_s * block_s;
    261 
    262         // Calculate Noise level
    263         float noise_level = 0;
    264         size_t count = 0;
    265         for (size_t y_bl = 0; y_bl < block_s; ++y_bl) {
    266           for (size_t x_bl = 0; x_bl < block_s; ++x_bl) {
    267             float filtered_value = 0;
    268             for (int y_f = -1 * filt_size; y_f <= filt_size; ++y_f) {
    269               if ((static_cast<ssize_t>(y_bl) + y_f) >= 0 &&
    270                   (y_bl + y_f) < block_s) {
    271                 for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) {
    272                   if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 &&
    273                       (x_bl + x_f) < block_s) {
    274                     filtered_value +=
    275                         0.5f *
    276                         (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl + x_f] +
    277                          opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl + x_f]) *
    278                         kLaplFilter[y_f + filt_size][x_f + filt_size];
    279                   } else {
    280                     filtered_value +=
    281                         0.5f *
    282                         (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl - x_f] +
    283                          opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl - x_f]) *
    284                         kLaplFilter[y_f + filt_size][x_f + filt_size];
    285                   }
    286                 }
    287               } else {
    288                 for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) {
    289                   if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 &&
    290                       (x_bl + x_f) < block_s) {
    291                     filtered_value +=
    292                         0.5f *
    293                         (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl + x_f] +
    294                          opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl + x_f]) *
    295                         kLaplFilter[y_f + filt_size][x_f + filt_size];
    296                   } else {
    297                     filtered_value +=
    298                         0.5f *
    299                         (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl - x_f] +
    300                          opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl - x_f]) *
    301                         kLaplFilter[y_f + filt_size][x_f + filt_size];
    302                   }
    303                 }
    304               }
    305             }
    306             noise_level += std::abs(filtered_value);
    307             ++count;
    308           }
    309         }
    310         noise_level /= count;
    311         NoiseLevel nl;
    312         nl.intensity = mean_int;
    313         nl.noise_level = noise_level;
    314         noise_level_per_intensity.push_back(nl);
    315       }
    316       ++patch_index;
    317     }
    318   }
    319   return noise_level_per_intensity;
    320 }
    321 
    322 void EncodeFloatParam(float val, float precision, BitWriter* writer) {
    323   JXL_ASSERT(val >= 0);
    324   const int absval_quant = static_cast<int>(std::lround(val * precision));
    325   JXL_ASSERT(absval_quant < (1 << 10));
    326   writer->Write(10, absval_quant);
    327 }
    328 
    329 }  // namespace
    330 
    331 Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params,
    332                          float quality_coef) {
    333   // The size of a patch in decoder might be different from encoder's patch
    334   // size.
    335   // For encoder: the patch size should be big enough to estimate
    336   //              noise level, but, at the same time, it should be not too big
    337   //              to be able to estimate intensity value of the patch
    338   const size_t block_s = 8;
    339   const size_t kNumBin = 256;
    340   NoiseHistogram sad_histogram;
    341   std::vector<float> sad_scores =
    342       GetSADScoresForPatches(opsin, block_s, kNumBin, &sad_histogram);
    343   float sad_threshold = GetSADThreshold(sad_histogram, kNumBin);
    344   // If threshold is too large, the image has a strong pattern. This pattern
    345   // fools our model and it will add too much noise. Therefore, we do not add
    346   // noise for such images
    347   if (sad_threshold > 0.15f || sad_threshold <= 0.0f) {
    348     noise_params->Clear();
    349     return false;
    350   }
    351   std::vector<NoiseLevel> nl =
    352       GetNoiseLevel(opsin, sad_scores, sad_threshold, block_s);
    353 
    354   OptimizeNoiseParameters(nl, noise_params);
    355   for (float& i : noise_params->lut) {
    356     i *= quality_coef * 1.4;
    357   }
    358   return noise_params->HasAny();
    359 }
    360 
    361 void EncodeNoise(const NoiseParams& noise_params, BitWriter* writer,
    362                  size_t layer, AuxOut* aux_out) {
    363   JXL_ASSERT(noise_params.HasAny());
    364 
    365   BitWriter::Allotment allotment(writer, NoiseParams::kNumNoisePoints * 16);
    366   for (float i : noise_params.lut) {
    367     EncodeFloatParam(i, kNoisePrecision, writer);
    368   }
    369   allotment.ReclaimAndCharge(writer, layer, aux_out);
    370 }
    371 
    372 }  // namespace jxl