squeeze.cc (18212B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/modular/transform/squeeze.h" 7 8 #include <stdlib.h> 9 10 #include "lib/jxl/base/common.h" 11 #include "lib/jxl/base/data_parallel.h" 12 #include "lib/jxl/base/printf_macros.h" 13 #include "lib/jxl/modular/modular_image.h" 14 #include "lib/jxl/modular/transform/transform.h" 15 #undef HWY_TARGET_INCLUDE 16 #define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/squeeze.cc" 17 #include <hwy/foreach_target.h> 18 #include <hwy/highway.h> 19 20 #include "lib/jxl/simd_util-inl.h" 21 22 HWY_BEFORE_NAMESPACE(); 23 namespace jxl { 24 namespace HWY_NAMESPACE { 25 26 // These templates are not found via ADL. 27 using hwy::HWY_NAMESPACE::Abs; 28 using hwy::HWY_NAMESPACE::Add; 29 using hwy::HWY_NAMESPACE::And; 30 using hwy::HWY_NAMESPACE::Gt; 31 using hwy::HWY_NAMESPACE::IfThenElse; 32 using hwy::HWY_NAMESPACE::IfThenZeroElse; 33 using hwy::HWY_NAMESPACE::Lt; 34 using hwy::HWY_NAMESPACE::MulEven; 35 using hwy::HWY_NAMESPACE::Ne; 36 using hwy::HWY_NAMESPACE::Neg; 37 using hwy::HWY_NAMESPACE::OddEven; 38 using hwy::HWY_NAMESPACE::RebindToUnsigned; 39 using hwy::HWY_NAMESPACE::ShiftLeft; 40 using hwy::HWY_NAMESPACE::ShiftRight; 41 using hwy::HWY_NAMESPACE::Sub; 42 using hwy::HWY_NAMESPACE::Xor; 43 44 #if HWY_TARGET != HWY_SCALAR 45 46 JXL_INLINE void FastUnsqueeze(const pixel_type *JXL_RESTRICT p_residual, 47 const pixel_type *JXL_RESTRICT p_avg, 48 const pixel_type *JXL_RESTRICT p_navg, 49 const pixel_type *p_pout, 50 pixel_type *JXL_RESTRICT p_out, 51 pixel_type *p_nout) { 52 const HWY_CAPPED(pixel_type, 8) d; 53 const RebindToUnsigned<decltype(d)> du; 54 const size_t N = Lanes(d); 55 auto onethird = Set(d, 0x55555556); 56 for (size_t x = 0; x < 8; x += N) { 57 auto avg = Load(d, p_avg + x); 58 auto next_avg = Load(d, p_navg + x); 59 auto top = Load(d, p_pout + x); 60 // Equivalent to SmoothTendency(top,avg,next_avg), but without branches 61 auto Ba = Sub(top, avg); 62 auto an = Sub(avg, next_avg); 63 auto nonmono = Xor(Ba, an); 64 auto absBa = Abs(Ba); 65 auto absan = Abs(an); 66 auto absBn = Abs(Sub(top, next_avg)); 67 // Compute a3 = absBa / 3 68 auto a3e = BitCast(d, ShiftRight<32>(MulEven(absBa, onethird))); 69 auto a3oi = MulEven(Reverse(d, absBa), onethird); 70 auto a3o = BitCast( 71 d, Reverse(hwy::HWY_NAMESPACE::Repartition<pixel_type_w, decltype(d)>(), 72 a3oi)); 73 auto a3 = OddEven(a3o, a3e); 74 a3 = Add(a3, Add(absBn, Set(d, 2))); 75 auto absdiff = ShiftRight<2>(a3); 76 auto skipdiff = Ne(Ba, Zero(d)); 77 skipdiff = And(skipdiff, Ne(an, Zero(d))); 78 skipdiff = And(skipdiff, Lt(nonmono, Zero(d))); 79 auto absBa2 = Add(ShiftLeft<1>(absBa), And(absdiff, Set(d, 1))); 80 absdiff = IfThenElse(Gt(absdiff, absBa2), 81 Add(ShiftLeft<1>(absBa), Set(d, 1)), absdiff); 82 auto absan2 = ShiftLeft<1>(absan); 83 absdiff = IfThenElse(Gt(Add(absdiff, And(absdiff, Set(d, 1))), absan2), 84 absan2, absdiff); 85 auto diff1 = IfThenElse(Lt(top, next_avg), Neg(absdiff), absdiff); 86 auto tendency = IfThenZeroElse(skipdiff, diff1); 87 88 auto diff_minus_tendency = Load(d, p_residual + x); 89 auto diff = Add(diff_minus_tendency, tendency); 90 auto out = 91 Add(avg, ShiftRight<1>( 92 Add(diff, BitCast(d, ShiftRight<31>(BitCast(du, diff)))))); 93 Store(out, d, p_out + x); 94 Store(Sub(out, diff), d, p_nout + x); 95 } 96 } 97 98 #endif 99 100 Status InvHSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { 101 JXL_ASSERT(c < input.channel.size()); 102 JXL_ASSERT(rc < input.channel.size()); 103 Channel &chin = input.channel[c]; 104 const Channel &chin_residual = input.channel[rc]; 105 // These must be valid since we ran MetaApply already. 106 JXL_ASSERT(chin.w == DivCeil(chin.w + chin_residual.w, 2)); 107 JXL_ASSERT(chin.h == chin_residual.h); 108 109 if (chin_residual.w == 0) { 110 // Short-circuit: output channel has same dimensions as input. 111 input.channel[c].hshift--; 112 return true; 113 } 114 115 // Note: chin.w >= chin_residual.w and at most 1 different. 116 JXL_ASSIGN_OR_RETURN(Channel chout, 117 Channel::Create(chin.w + chin_residual.w, chin.h, 118 chin.hshift - 1, chin.vshift)); 119 JXL_DEBUG_V(4, 120 "Undoing horizontal squeeze of channel %i using residuals in " 121 "channel %i (going from width %" PRIuS " to %" PRIuS ")", 122 c, rc, chin.w, chout.w); 123 124 if (chin_residual.h == 0) { 125 // Short-circuit: channel with no pixels. 126 input.channel[c] = std::move(chout); 127 return true; 128 } 129 auto unsqueeze_row = [&](size_t y, size_t x0) { 130 const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y); 131 const pixel_type *JXL_RESTRICT p_avg = chin.Row(y); 132 pixel_type *JXL_RESTRICT p_out = chout.Row(y); 133 for (size_t x = x0; x < chin_residual.w; x++) { 134 pixel_type_w diff_minus_tendency = p_residual[x]; 135 pixel_type_w avg = p_avg[x]; 136 pixel_type_w next_avg = (x + 1 < chin.w ? p_avg[x + 1] : avg); 137 pixel_type_w left = (x ? p_out[(x << 1) - 1] : avg); 138 pixel_type_w tendency = SmoothTendency(left, avg, next_avg); 139 pixel_type_w diff = diff_minus_tendency + tendency; 140 pixel_type_w A = avg + (diff / 2); 141 p_out[(x << 1)] = A; 142 pixel_type_w B = A - diff; 143 p_out[(x << 1) + 1] = B; 144 } 145 if (chout.w & 1) p_out[chout.w - 1] = p_avg[chin.w - 1]; 146 }; 147 148 // somewhat complicated trickery just to be able to SIMD this. 149 // Horizontal unsqueeze has horizontal data dependencies, so we do 150 // 8 rows at a time and treat it as a vertical unsqueeze of a 151 // transposed 8x8 block (or 9x8 for one input). 152 static constexpr const size_t kRowsPerThread = 8; 153 const auto unsqueeze_span = [&](const uint32_t task, size_t /* thread */) { 154 const size_t y0 = task * kRowsPerThread; 155 const size_t rows = std::min(kRowsPerThread, chin.h - y0); 156 size_t x = 0; 157 158 #if HWY_TARGET != HWY_SCALAR 159 intptr_t onerow_in = chin.plane.PixelsPerRow(); 160 intptr_t onerow_inr = chin_residual.plane.PixelsPerRow(); 161 intptr_t onerow_out = chout.plane.PixelsPerRow(); 162 const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y0); 163 const pixel_type *JXL_RESTRICT p_avg = chin.Row(y0); 164 pixel_type *JXL_RESTRICT p_out = chout.Row(y0); 165 HWY_ALIGN pixel_type b_p_avg[9 * kRowsPerThread]; 166 HWY_ALIGN pixel_type b_p_residual[8 * kRowsPerThread]; 167 HWY_ALIGN pixel_type b_p_out_even[8 * kRowsPerThread]; 168 HWY_ALIGN pixel_type b_p_out_odd[8 * kRowsPerThread]; 169 HWY_ALIGN pixel_type b_p_out_evenT[8 * kRowsPerThread]; 170 HWY_ALIGN pixel_type b_p_out_oddT[8 * kRowsPerThread]; 171 const HWY_CAPPED(pixel_type, 8) d; 172 const size_t N = Lanes(d); 173 if (chin_residual.w > 16 && rows == kRowsPerThread) { 174 for (; x < chin_residual.w - 9; x += 8) { 175 Transpose8x8Block(p_residual + x, b_p_residual, onerow_inr); 176 Transpose8x8Block(p_avg + x, b_p_avg, onerow_in); 177 for (size_t y = 0; y < kRowsPerThread; y++) { 178 b_p_avg[8 * 8 + y] = p_avg[x + 8 + onerow_in * y]; 179 } 180 for (size_t i = 0; i < 8; i++) { 181 FastUnsqueeze( 182 b_p_residual + 8 * i, b_p_avg + 8 * i, b_p_avg + 8 * (i + 1), 183 (x + i ? b_p_out_odd + 8 * ((x + i - 1) & 7) : b_p_avg + 8 * i), 184 b_p_out_even + 8 * i, b_p_out_odd + 8 * i); 185 } 186 187 Transpose8x8Block(b_p_out_even, b_p_out_evenT, 8); 188 Transpose8x8Block(b_p_out_odd, b_p_out_oddT, 8); 189 for (size_t y = 0; y < kRowsPerThread; y++) { 190 for (size_t i = 0; i < kRowsPerThread; i += N) { 191 auto even = Load(d, b_p_out_evenT + 8 * y + i); 192 auto odd = Load(d, b_p_out_oddT + 8 * y + i); 193 StoreInterleaved(d, even, odd, 194 p_out + ((x + i) << 1) + onerow_out * y); 195 } 196 } 197 } 198 } 199 #endif 200 for (size_t y = 0; y < rows; y++) { 201 unsqueeze_row(y0 + y, x); 202 } 203 }; 204 JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.h, kRowsPerThread), 205 ThreadPool::NoInit, unsqueeze_span, 206 "InvHorizontalSqueeze")); 207 input.channel[c] = std::move(chout); 208 return true; 209 } 210 211 Status InvVSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { 212 JXL_ASSERT(c < input.channel.size()); 213 JXL_ASSERT(rc < input.channel.size()); 214 const Channel &chin = input.channel[c]; 215 const Channel &chin_residual = input.channel[rc]; 216 // These must be valid since we ran MetaApply already. 217 JXL_ASSERT(chin.h == DivCeil(chin.h + chin_residual.h, 2)); 218 JXL_ASSERT(chin.w == chin_residual.w); 219 220 if (chin_residual.h == 0) { 221 // Short-circuit: output channel has same dimensions as input. 222 input.channel[c].vshift--; 223 return true; 224 } 225 226 // Note: chin.h >= chin_residual.h and at most 1 different. 227 JXL_ASSIGN_OR_RETURN(Channel chout, 228 Channel::Create(chin.w, chin.h + chin_residual.h, 229 chin.hshift, chin.vshift - 1)); 230 JXL_DEBUG_V( 231 4, 232 "Undoing vertical squeeze of channel %i using residuals in channel " 233 "%i (going from height %" PRIuS " to %" PRIuS ")", 234 c, rc, chin.h, chout.h); 235 236 if (chin_residual.w == 0) { 237 // Short-circuit: channel with no pixels. 238 input.channel[c] = std::move(chout); 239 return true; 240 } 241 242 static constexpr const int kColsPerThread = 64; 243 const auto unsqueeze_slice = [&](const uint32_t task, size_t /* thread */) { 244 const size_t x0 = task * kColsPerThread; 245 const size_t x1 = 246 std::min(static_cast<size_t>(task + 1) * kColsPerThread, chin.w); 247 const size_t w = x1 - x0; 248 // We only iterate up to std::min(chin_residual.h, chin.h) which is 249 // always chin_residual.h. 250 for (size_t y = 0; y < chin_residual.h; y++) { 251 const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y) + x0; 252 const pixel_type *JXL_RESTRICT p_avg = chin.Row(y) + x0; 253 const pixel_type *JXL_RESTRICT p_navg = 254 chin.Row(y + 1 < chin.h ? y + 1 : y) + x0; 255 pixel_type *JXL_RESTRICT p_out = chout.Row(y << 1) + x0; 256 pixel_type *JXL_RESTRICT p_nout = chout.Row((y << 1) + 1) + x0; 257 const pixel_type *p_pout = y > 0 ? chout.Row((y << 1) - 1) + x0 : p_avg; 258 size_t x = 0; 259 #if HWY_TARGET != HWY_SCALAR 260 for (; x + 7 < w; x += 8) { 261 FastUnsqueeze(p_residual + x, p_avg + x, p_navg + x, p_pout + x, 262 p_out + x, p_nout + x); 263 } 264 #endif 265 for (; x < w; x++) { 266 pixel_type_w avg = p_avg[x]; 267 pixel_type_w next_avg = p_navg[x]; 268 pixel_type_w top = p_pout[x]; 269 pixel_type_w tendency = SmoothTendency(top, avg, next_avg); 270 pixel_type_w diff_minus_tendency = p_residual[x]; 271 pixel_type_w diff = diff_minus_tendency + tendency; 272 pixel_type_w out = avg + (diff / 2); 273 p_out[x] = out; 274 // If the chin_residual.h == chin.h, the output has an even number 275 // of rows so the next line is fine. Otherwise, this loop won't 276 // write to the last output row which is handled separately. 277 p_nout[x] = out - diff; 278 } 279 } 280 }; 281 JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.w, kColsPerThread), 282 ThreadPool::NoInit, unsqueeze_slice, 283 "InvVertSqueeze")); 284 285 if (chout.h & 1) { 286 size_t y = chin.h - 1; 287 const pixel_type *p_avg = chin.Row(y); 288 pixel_type *p_out = chout.Row(y << 1); 289 for (size_t x = 0; x < chin.w; x++) { 290 p_out[x] = p_avg[x]; 291 } 292 } 293 input.channel[c] = std::move(chout); 294 return true; 295 } 296 297 Status InvSqueeze(Image &input, const std::vector<SqueezeParams> ¶meters, 298 ThreadPool *pool) { 299 for (int i = parameters.size() - 1; i >= 0; i--) { 300 JXL_RETURN_IF_ERROR( 301 CheckMetaSqueezeParams(parameters[i], input.channel.size())); 302 bool horizontal = parameters[i].horizontal; 303 bool in_place = parameters[i].in_place; 304 uint32_t beginc = parameters[i].begin_c; 305 uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1; 306 uint32_t offset; 307 if (in_place) { 308 offset = endc + 1; 309 } else { 310 offset = input.channel.size() + beginc - endc - 1; 311 } 312 if (beginc < input.nb_meta_channels) { 313 // This is checked in MetaSqueeze. 314 JXL_ASSERT(input.nb_meta_channels > parameters[i].num_c); 315 input.nb_meta_channels -= parameters[i].num_c; 316 } 317 318 for (uint32_t c = beginc; c <= endc; c++) { 319 uint32_t rc = offset + c - beginc; 320 // MetaApply should imply that `rc` is within range, otherwise there's a 321 // programming bug. 322 JXL_ASSERT(rc < input.channel.size()); 323 if ((input.channel[c].w < input.channel[rc].w) || 324 (input.channel[c].h < input.channel[rc].h)) { 325 return JXL_FAILURE("Corrupted squeeze transform"); 326 } 327 if (horizontal) { 328 JXL_RETURN_IF_ERROR(InvHSqueeze(input, c, rc, pool)); 329 } else { 330 JXL_RETURN_IF_ERROR(InvVSqueeze(input, c, rc, pool)); 331 } 332 } 333 input.channel.erase(input.channel.begin() + offset, 334 input.channel.begin() + offset + (endc - beginc + 1)); 335 } 336 return true; 337 } 338 339 } // namespace HWY_NAMESPACE 340 } // namespace jxl 341 HWY_AFTER_NAMESPACE(); 342 343 #if HWY_ONCE 344 345 namespace jxl { 346 347 HWY_EXPORT(InvSqueeze); 348 Status InvSqueeze(Image &input, const std::vector<SqueezeParams> ¶meters, 349 ThreadPool *pool) { 350 return HWY_DYNAMIC_DISPATCH(InvSqueeze)(input, parameters, pool); 351 } 352 353 void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters, 354 const Image &image) { 355 int nb_channels = image.channel.size() - image.nb_meta_channels; 356 357 parameters->clear(); 358 size_t w = image.channel[image.nb_meta_channels].w; 359 size_t h = image.channel[image.nb_meta_channels].h; 360 JXL_DEBUG_V( 361 7, "Default squeeze parameters for %" PRIuS "x%" PRIuS " image: ", w, h); 362 363 // do horizontal first on wide images; vertical first on tall images 364 bool wide = (w > h); 365 366 if (nb_channels > 2 && image.channel[image.nb_meta_channels + 1].w == w && 367 image.channel[image.nb_meta_channels + 1].h == h) { 368 // assume channels 1 and 2 are chroma, and can be squeezed first for 4:2:0 369 // previews 370 JXL_DEBUG_V(7, "(4:2:0 chroma), %" PRIuS "x%" PRIuS " image", w, h); 371 SqueezeParams params; 372 // horizontal chroma squeeze 373 params.horizontal = true; 374 params.in_place = false; 375 params.begin_c = image.nb_meta_channels + 1; 376 params.num_c = 2; 377 parameters->push_back(params); 378 params.horizontal = false; 379 // vertical chroma squeeze 380 parameters->push_back(params); 381 } 382 SqueezeParams params; 383 params.begin_c = image.nb_meta_channels; 384 params.num_c = nb_channels; 385 params.in_place = true; 386 387 if (!wide) { 388 if (h > JXL_MAX_FIRST_PREVIEW_SIZE) { 389 params.horizontal = false; 390 parameters->push_back(params); 391 h = (h + 1) / 2; 392 JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); 393 } 394 } 395 while (w > JXL_MAX_FIRST_PREVIEW_SIZE || h > JXL_MAX_FIRST_PREVIEW_SIZE) { 396 if (w > JXL_MAX_FIRST_PREVIEW_SIZE) { 397 params.horizontal = true; 398 parameters->push_back(params); 399 w = (w + 1) / 2; 400 JXL_DEBUG_V(7, "Horizontal (%" PRIuS "x%" PRIuS "), ", w, h); 401 } 402 if (h > JXL_MAX_FIRST_PREVIEW_SIZE) { 403 params.horizontal = false; 404 parameters->push_back(params); 405 h = (h + 1) / 2; 406 JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); 407 } 408 } 409 JXL_DEBUG_V(7, "that's it"); 410 } 411 412 Status CheckMetaSqueezeParams(const SqueezeParams ¶meter, 413 int num_channels) { 414 int c1 = parameter.begin_c; 415 int c2 = parameter.begin_c + parameter.num_c - 1; 416 if (c1 < 0 || c1 >= num_channels || c2 < 0 || c2 >= num_channels || c2 < c1) { 417 return JXL_FAILURE("Invalid channel range"); 418 } 419 return true; 420 } 421 422 Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters) { 423 if (parameters->empty()) { 424 DefaultSqueezeParameters(parameters, image); 425 } 426 427 for (size_t i = 0; i < parameters->size(); i++) { 428 JXL_RETURN_IF_ERROR( 429 CheckMetaSqueezeParams((*parameters)[i], image.channel.size())); 430 bool horizontal = (*parameters)[i].horizontal; 431 bool in_place = (*parameters)[i].in_place; 432 uint32_t beginc = (*parameters)[i].begin_c; 433 uint32_t endc = (*parameters)[i].begin_c + (*parameters)[i].num_c - 1; 434 435 uint32_t offset; 436 if (beginc < image.nb_meta_channels) { 437 if (endc >= image.nb_meta_channels) { 438 return JXL_FAILURE("Invalid squeeze: mix of meta and nonmeta channels"); 439 } 440 if (!in_place) { 441 return JXL_FAILURE( 442 "Invalid squeeze: meta channels require in-place residuals"); 443 } 444 image.nb_meta_channels += (*parameters)[i].num_c; 445 } 446 if (in_place) { 447 offset = endc + 1; 448 } else { 449 offset = image.channel.size(); 450 } 451 for (uint32_t c = beginc; c <= endc; c++) { 452 if (image.channel[c].hshift > 30 || image.channel[c].vshift > 30) { 453 return JXL_FAILURE("Too many squeezes: shift > 30"); 454 } 455 size_t w = image.channel[c].w; 456 size_t h = image.channel[c].h; 457 if (w == 0 || h == 0) return JXL_FAILURE("Squeezing empty channel"); 458 if (horizontal) { 459 image.channel[c].w = (w + 1) / 2; 460 if (image.channel[c].hshift >= 0) image.channel[c].hshift++; 461 w = w - (w + 1) / 2; 462 } else { 463 image.channel[c].h = (h + 1) / 2; 464 if (image.channel[c].vshift >= 0) image.channel[c].vshift++; 465 h = h - (h + 1) / 2; 466 } 467 JXL_RETURN_IF_ERROR(image.channel[c].shrink()); 468 JXL_ASSIGN_OR_RETURN(Channel placeholder, Channel::Create(w, h)); 469 placeholder.hshift = image.channel[c].hshift; 470 placeholder.vshift = image.channel[c].vshift; 471 472 image.channel.insert(image.channel.begin() + offset + (c - beginc), 473 std::move(placeholder)); 474 JXL_DEBUG_V(8, "MetaSqueeze applied, current image: %s", 475 image.DebugString().c_str()); 476 } 477 } 478 return true; 479 } 480 481 } // namespace jxl 482 483 #endif