palette.cc (7367B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/modular/transform/palette.h" 7 8 namespace jxl { 9 10 Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors, 11 uint32_t nb_deltas, Predictor predictor, 12 const weighted::Header &wp_header, ThreadPool *pool) { 13 if (input.nb_meta_channels < 1) { 14 return JXL_FAILURE("Error: Palette transform without palette."); 15 } 16 std::atomic<int> num_errors{0}; 17 int nb = input.channel[0].h; 18 uint32_t c0 = begin_c + 1; 19 if (c0 >= input.channel.size()) { 20 return JXL_FAILURE("Channel is out of range."); 21 } 22 size_t w = input.channel[c0].w; 23 size_t h = input.channel[c0].h; 24 if (nb < 1) return JXL_FAILURE("Corrupted transforms"); 25 for (int i = 1; i < nb; i++) { 26 StatusOr<Channel> channel_or = Channel::Create( 27 w, h, input.channel[c0].hshift, input.channel[c0].vshift); 28 JXL_RETURN_IF_ERROR(channel_or.status()); 29 input.channel.insert(input.channel.begin() + c0 + 1, 30 std::move(channel_or).value()); 31 } 32 const Channel &palette = input.channel[0]; 33 const pixel_type *JXL_RESTRICT p_palette = input.channel[0].Row(0); 34 intptr_t onerow = input.channel[0].plane.PixelsPerRow(); 35 intptr_t onerow_image = input.channel[c0].plane.PixelsPerRow(); 36 const int bit_depth = std::min(input.bitdepth, 24); 37 38 if (w == 0) { 39 // Nothing to do. 40 // Avoid touching "empty" channels with non-zero height. 41 } else if (nb_deltas == 0 && predictor == Predictor::Zero) { 42 if (nb == 1) { 43 JXL_RETURN_IF_ERROR(RunOnPool( 44 pool, 0, h, ThreadPool::NoInit, 45 [&](const uint32_t task, size_t /* thread */) { 46 const size_t y = task; 47 pixel_type *p = input.channel[c0].Row(y); 48 for (size_t x = 0; x < w; x++) { 49 const int index = 50 Clamp1<int>(p[x], 0, static_cast<pixel_type>(palette.w) - 1); 51 p[x] = palette_internal::GetPaletteValue( 52 p_palette, index, /*c=*/0, 53 /*palette_size=*/palette.w, 54 /*onerow=*/onerow, /*bit_depth=*/bit_depth); 55 } 56 }, 57 "UndoChannelPalette")); 58 } else { 59 JXL_RETURN_IF_ERROR(RunOnPool( 60 pool, 0, h, ThreadPool::NoInit, 61 [&](const uint32_t task, size_t /* thread */) { 62 const size_t y = task; 63 std::vector<pixel_type *> p_out(nb); 64 const pixel_type *p_index = input.channel[c0].Row(y); 65 for (int c = 0; c < nb; c++) 66 p_out[c] = input.channel[c0 + c].Row(y); 67 for (size_t x = 0; x < w; x++) { 68 const int index = p_index[x]; 69 for (int c = 0; c < nb; c++) { 70 p_out[c][x] = palette_internal::GetPaletteValue( 71 p_palette, index, /*c=*/c, 72 /*palette_size=*/palette.w, 73 /*onerow=*/onerow, /*bit_depth=*/bit_depth); 74 } 75 } 76 }, 77 "UndoPalette")); 78 } 79 } else { 80 // Parallelized per channel. 81 ImageI indices; 82 ImageI &plane = input.channel[c0].plane; 83 JXL_ASSIGN_OR_RETURN(indices, ImageI::Create(plane.xsize(), plane.ysize())); 84 plane.Swap(indices); 85 if (predictor == Predictor::Weighted) { 86 JXL_RETURN_IF_ERROR(RunOnPool( 87 pool, 0, nb, ThreadPool::NoInit, 88 [&](const uint32_t c, size_t /* thread */) { 89 Channel &channel = input.channel[c0 + c]; 90 weighted::State wp_state(wp_header, channel.w, channel.h); 91 for (size_t y = 0; y < channel.h; y++) { 92 pixel_type *JXL_RESTRICT p = channel.Row(y); 93 const pixel_type *JXL_RESTRICT idx = indices.Row(y); 94 for (size_t x = 0; x < channel.w; x++) { 95 int index = idx[x]; 96 pixel_type_w val = 0; 97 const pixel_type palette_entry = 98 palette_internal::GetPaletteValue( 99 p_palette, index, /*c=*/c, 100 /*palette_size=*/palette.w, /*onerow=*/onerow, 101 /*bit_depth=*/bit_depth); 102 if (index < static_cast<int32_t>(nb_deltas)) { 103 PredictionResult pred = 104 PredictNoTreeWP(channel.w, p + x, onerow_image, x, y, 105 predictor, &wp_state); 106 val = pred.guess + palette_entry; 107 } else { 108 val = palette_entry; 109 } 110 p[x] = val; 111 wp_state.UpdateErrors(p[x], x, y, channel.w); 112 } 113 } 114 }, 115 "UndoDeltaPaletteWP")); 116 } else { 117 JXL_RETURN_IF_ERROR(RunOnPool( 118 pool, 0, nb, ThreadPool::NoInit, 119 [&](const uint32_t c, size_t /* thread */) { 120 Channel &channel = input.channel[c0 + c]; 121 for (size_t y = 0; y < channel.h; y++) { 122 pixel_type *JXL_RESTRICT p = channel.Row(y); 123 const pixel_type *JXL_RESTRICT idx = indices.Row(y); 124 for (size_t x = 0; x < channel.w; x++) { 125 int index = idx[x]; 126 pixel_type_w val = 0; 127 const pixel_type palette_entry = 128 palette_internal::GetPaletteValue( 129 p_palette, index, /*c=*/c, 130 /*palette_size=*/palette.w, 131 /*onerow=*/onerow, /*bit_depth=*/bit_depth); 132 if (index < static_cast<int32_t>(nb_deltas)) { 133 PredictionResult pred = PredictNoTreeNoWP( 134 channel.w, p + x, onerow_image, x, y, predictor); 135 val = pred.guess + palette_entry; 136 } else { 137 val = palette_entry; 138 } 139 p[x] = val; 140 } 141 } 142 }, 143 "UndoDeltaPaletteNoWP")); 144 } 145 } 146 if (c0 >= input.nb_meta_channels) { 147 // Palette was done on normal channels 148 input.nb_meta_channels--; 149 } else { 150 // Palette was done on metachannels 151 JXL_ASSERT(static_cast<int>(input.nb_meta_channels) >= 2 - nb); 152 input.nb_meta_channels -= 2 - nb; 153 JXL_ASSERT(begin_c + nb - 1 < input.nb_meta_channels); 154 } 155 input.channel.erase(input.channel.begin(), input.channel.begin() + 1); 156 return num_errors.load(std::memory_order_relaxed) == 0; 157 } 158 159 Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c, 160 uint32_t nb_colors, uint32_t nb_deltas, bool lossy) { 161 JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c)); 162 163 size_t nb = end_c - begin_c + 1; 164 if (begin_c >= input.nb_meta_channels) { 165 // Palette was done on normal channels 166 input.nb_meta_channels++; 167 } else { 168 // Palette was done on metachannels 169 JXL_ASSERT(end_c < input.nb_meta_channels); 170 // we remove nb-1 metachannels and add one 171 input.nb_meta_channels += 2 - nb; 172 } 173 input.channel.erase(input.channel.begin() + begin_c + 1, 174 input.channel.begin() + end_c + 1); 175 JXL_ASSIGN_OR_RETURN(Channel pch, Channel::Create(nb_colors + nb_deltas, nb)); 176 pch.hshift = -1; 177 pch.vshift = -1; 178 input.channel.insert(input.channel.begin(), std::move(pch)); 179 return true; 180 } 181 182 } // namespace jxl