fields.cc (21029B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/fields.h" 7 8 #include <algorithm> 9 #include <cinttypes> 10 #include <cmath> 11 #include <cstddef> 12 #include <hwy/base.h> 13 14 #include "lib/jxl/base/bits.h" 15 #include "lib/jxl/base/printf_macros.h" 16 17 namespace jxl { 18 19 namespace { 20 21 using ::jxl::fields_internal::VisitorBase; 22 23 struct InitVisitor : public VisitorBase { 24 Status Bits(const size_t /*unused*/, const uint32_t default_value, 25 uint32_t* JXL_RESTRICT value) override { 26 *value = default_value; 27 return true; 28 } 29 30 Status U32(const U32Enc /*unused*/, const uint32_t default_value, 31 uint32_t* JXL_RESTRICT value) override { 32 *value = default_value; 33 return true; 34 } 35 36 Status U64(const uint64_t default_value, 37 uint64_t* JXL_RESTRICT value) override { 38 *value = default_value; 39 return true; 40 } 41 42 Status Bool(bool default_value, bool* JXL_RESTRICT value) override { 43 *value = default_value; 44 return true; 45 } 46 47 Status F16(const float default_value, float* JXL_RESTRICT value) override { 48 *value = default_value; 49 return true; 50 } 51 52 // Always visit conditional fields to ensure they are initialized. 53 Status Conditional(bool /*condition*/) override { return true; } 54 55 Status AllDefault(const Fields& /*fields*/, 56 bool* JXL_RESTRICT all_default) override { 57 // Just initialize this field and don't skip initializing others. 58 JXL_RETURN_IF_ERROR(Bool(true, all_default)); 59 return false; 60 } 61 62 Status VisitNested(Fields* /*fields*/) override { 63 // Avoid re-initializing nested bundles (their ctors already called 64 // Bundle::Init for their fields). 65 return true; 66 } 67 }; 68 69 // Similar to InitVisitor, but also initializes nested fields. 70 struct SetDefaultVisitor : public VisitorBase { 71 Status Bits(const size_t /*unused*/, const uint32_t default_value, 72 uint32_t* JXL_RESTRICT value) override { 73 *value = default_value; 74 return true; 75 } 76 77 Status U32(const U32Enc /*unused*/, const uint32_t default_value, 78 uint32_t* JXL_RESTRICT value) override { 79 *value = default_value; 80 return true; 81 } 82 83 Status U64(const uint64_t default_value, 84 uint64_t* JXL_RESTRICT value) override { 85 *value = default_value; 86 return true; 87 } 88 89 Status Bool(bool default_value, bool* JXL_RESTRICT value) override { 90 *value = default_value; 91 return true; 92 } 93 94 Status F16(const float default_value, float* JXL_RESTRICT value) override { 95 *value = default_value; 96 return true; 97 } 98 99 // Always visit conditional fields to ensure they are initialized. 100 Status Conditional(bool /*condition*/) override { return true; } 101 102 Status AllDefault(const Fields& /*fields*/, 103 bool* JXL_RESTRICT all_default) override { 104 // Just initialize this field and don't skip initializing others. 105 JXL_RETURN_IF_ERROR(Bool(true, all_default)); 106 return false; 107 } 108 }; 109 110 class AllDefaultVisitor : public VisitorBase { 111 public: 112 explicit AllDefaultVisitor() = default; 113 114 Status Bits(const size_t bits, const uint32_t default_value, 115 uint32_t* JXL_RESTRICT value) override { 116 all_default_ &= *value == default_value; 117 return true; 118 } 119 120 Status U32(const U32Enc /*unused*/, const uint32_t default_value, 121 uint32_t* JXL_RESTRICT value) override { 122 all_default_ &= *value == default_value; 123 return true; 124 } 125 126 Status U64(const uint64_t default_value, 127 uint64_t* JXL_RESTRICT value) override { 128 all_default_ &= *value == default_value; 129 return true; 130 } 131 132 Status F16(const float default_value, float* JXL_RESTRICT value) override { 133 all_default_ &= std::abs(*value - default_value) < 1E-6f; 134 return true; 135 } 136 137 Status AllDefault(const Fields& /*fields*/, 138 bool* JXL_RESTRICT /*all_default*/) override { 139 // Visit all fields so we can compute the actual all_default_ value. 140 return false; 141 } 142 143 bool AllDefault() const { return all_default_; } 144 145 private: 146 bool all_default_ = true; 147 }; 148 149 class ReadVisitor : public VisitorBase { 150 public: 151 explicit ReadVisitor(BitReader* reader) : reader_(reader) {} 152 153 Status Bits(const size_t bits, const uint32_t /*default_value*/, 154 uint32_t* JXL_RESTRICT value) override { 155 *value = BitsCoder::Read(bits, reader_); 156 if (!reader_->AllReadsWithinBounds()) { 157 return JXL_STATUS(StatusCode::kNotEnoughBytes, 158 "Not enough bytes for header"); 159 } 160 return true; 161 } 162 163 Status U32(const U32Enc dist, const uint32_t /*default_value*/, 164 uint32_t* JXL_RESTRICT value) override { 165 *value = U32Coder::Read(dist, reader_); 166 if (!reader_->AllReadsWithinBounds()) { 167 return JXL_STATUS(StatusCode::kNotEnoughBytes, 168 "Not enough bytes for header"); 169 } 170 return true; 171 } 172 173 Status U64(const uint64_t /*default_value*/, 174 uint64_t* JXL_RESTRICT value) override { 175 *value = U64Coder::Read(reader_); 176 if (!reader_->AllReadsWithinBounds()) { 177 return JXL_STATUS(StatusCode::kNotEnoughBytes, 178 "Not enough bytes for header"); 179 } 180 return true; 181 } 182 183 Status F16(const float /*default_value*/, 184 float* JXL_RESTRICT value) override { 185 ok_ &= F16Coder::Read(reader_, value); 186 if (!reader_->AllReadsWithinBounds()) { 187 return JXL_STATUS(StatusCode::kNotEnoughBytes, 188 "Not enough bytes for header"); 189 } 190 return true; 191 } 192 193 void SetDefault(Fields* fields) override { Bundle::SetDefault(fields); } 194 195 bool IsReading() const override { return true; } 196 197 // This never fails because visitors are expected to keep reading until 198 // EndExtensions, see comment there. 199 Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { 200 JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions)); 201 if (*extensions == 0) return true; 202 203 // For each nonzero bit, i.e. extension that is present: 204 for (uint64_t remaining_extensions = *extensions; remaining_extensions != 0; 205 remaining_extensions &= remaining_extensions - 1) { 206 const size_t idx_extension = 207 Num0BitsBelowLS1Bit_Nonzero(remaining_extensions); 208 // Read additional U64 (one per extension) indicating the number of bits 209 // (allows skipping individual extensions). 210 JXL_RETURN_IF_ERROR(U64(0, &extension_bits_[idx_extension])); 211 if (!SafeAdd(total_extension_bits_, extension_bits_[idx_extension], 212 total_extension_bits_)) { 213 return JXL_FAILURE("Extension bits overflowed, invalid codestream"); 214 } 215 } 216 // Used by EndExtensions to skip past any _remaining_ extensions. 217 pos_after_ext_size_ = reader_->TotalBitsConsumed(); 218 JXL_ASSERT(pos_after_ext_size_ != 0); 219 return true; 220 } 221 222 Status EndExtensions() override { 223 JXL_QUIET_RETURN_IF_ERROR(VisitorBase::EndExtensions()); 224 // Happens if extensions == 0: don't read size, done. 225 if (pos_after_ext_size_ == 0) return true; 226 227 // Not enough bytes as set by BeginExtensions or earlier. Do not return 228 // this as a JXL_FAILURE or false (which can also propagate to error 229 // through e.g. JXL_RETURN_IF_ERROR), since this may be used while 230 // silently checking whether there are enough bytes. If this case must be 231 // treated as an error, reader_>Close() will do this, just like is already 232 // done for non-extension fields. 233 if (!enough_bytes_) return true; 234 235 // Skip new fields this (old?) decoder didn't know about, if any. 236 const size_t bits_read = reader_->TotalBitsConsumed(); 237 uint64_t end; 238 if (!SafeAdd(pos_after_ext_size_, total_extension_bits_, end)) { 239 return JXL_FAILURE("Invalid extension size, caused overflow"); 240 } 241 if (bits_read > end) { 242 return JXL_FAILURE("Read more extension bits than budgeted"); 243 } 244 const size_t remaining_bits = end - bits_read; 245 if (remaining_bits != 0) { 246 JXL_WARNING("Skipping %" PRIuS "-bit extension(s)", remaining_bits); 247 reader_->SkipBits(remaining_bits); 248 if (!reader_->AllReadsWithinBounds()) { 249 return JXL_STATUS(StatusCode::kNotEnoughBytes, 250 "Not enough bytes for header"); 251 } 252 } 253 return true; 254 } 255 256 Status OK() const { return ok_; } 257 258 private: 259 // Whether any error other than not enough bytes occurred. 260 bool ok_ = true; 261 262 // Whether there are enough input bytes to read from. 263 bool enough_bytes_ = true; 264 BitReader* const reader_; 265 // May be 0 even if the corresponding extension is present. 266 uint64_t extension_bits_[Bundle::kMaxExtensions] = {0}; 267 uint64_t total_extension_bits_ = 0; 268 size_t pos_after_ext_size_ = 0; // 0 iff extensions == 0. 269 270 friend Status jxl::CheckHasEnoughBits(Visitor* /* visitor */, 271 size_t /* bits */); 272 }; 273 274 class MaxBitsVisitor : public VisitorBase { 275 public: 276 Status Bits(const size_t bits, const uint32_t /*default_value*/, 277 uint32_t* JXL_RESTRICT /*value*/) override { 278 max_bits_ += BitsCoder::MaxEncodedBits(bits); 279 return true; 280 } 281 282 Status U32(const U32Enc enc, const uint32_t /*default_value*/, 283 uint32_t* JXL_RESTRICT /*value*/) override { 284 max_bits_ += U32Coder::MaxEncodedBits(enc); 285 return true; 286 } 287 288 Status U64(const uint64_t /*default_value*/, 289 uint64_t* JXL_RESTRICT /*value*/) override { 290 max_bits_ += U64Coder::MaxEncodedBits(); 291 return true; 292 } 293 294 Status F16(const float /*default_value*/, 295 float* JXL_RESTRICT /*value*/) override { 296 max_bits_ += F16Coder::MaxEncodedBits(); 297 return true; 298 } 299 300 Status AllDefault(const Fields& /*fields*/, 301 bool* JXL_RESTRICT all_default) override { 302 JXL_RETURN_IF_ERROR(Bool(true, all_default)); 303 return false; // For max bits, assume nothing is default 304 } 305 306 // Always visit conditional fields to get a (loose) upper bound. 307 Status Conditional(bool /*condition*/) override { return true; } 308 309 Status BeginExtensions(uint64_t* JXL_RESTRICT /*extensions*/) override { 310 // Skip - extensions are not included in "MaxBits" because their length 311 // is potentially unbounded. 312 return true; 313 } 314 315 Status EndExtensions() override { return true; } 316 317 size_t MaxBits() const { return max_bits_; } 318 319 private: 320 size_t max_bits_ = 0; 321 }; 322 323 class CanEncodeVisitor : public VisitorBase { 324 public: 325 explicit CanEncodeVisitor() = default; 326 327 Status Bits(const size_t bits, const uint32_t /*default_value*/, 328 uint32_t* JXL_RESTRICT value) override { 329 size_t encoded_bits = 0; 330 ok_ &= BitsCoder::CanEncode(bits, *value, &encoded_bits); 331 encoded_bits_ += encoded_bits; 332 return true; 333 } 334 335 Status U32(const U32Enc enc, const uint32_t /*default_value*/, 336 uint32_t* JXL_RESTRICT value) override { 337 size_t encoded_bits = 0; 338 ok_ &= U32Coder::CanEncode(enc, *value, &encoded_bits); 339 encoded_bits_ += encoded_bits; 340 return true; 341 } 342 343 Status U64(const uint64_t /*default_value*/, 344 uint64_t* JXL_RESTRICT value) override { 345 size_t encoded_bits = 0; 346 ok_ &= U64Coder::CanEncode(*value, &encoded_bits); 347 encoded_bits_ += encoded_bits; 348 return true; 349 } 350 351 Status F16(const float /*default_value*/, 352 float* JXL_RESTRICT value) override { 353 size_t encoded_bits = 0; 354 ok_ &= F16Coder::CanEncode(*value, &encoded_bits); 355 encoded_bits_ += encoded_bits; 356 return true; 357 } 358 359 Status AllDefault(const Fields& fields, 360 bool* JXL_RESTRICT all_default) override { 361 *all_default = Bundle::AllDefault(fields); 362 JXL_RETURN_IF_ERROR(Bool(true, all_default)); 363 return *all_default; 364 } 365 366 Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { 367 JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions)); 368 extensions_ = *extensions; 369 if (*extensions != 0) { 370 JXL_ASSERT(pos_after_ext_ == 0); 371 pos_after_ext_ = encoded_bits_; 372 JXL_ASSERT(pos_after_ext_ != 0); // visited "extensions" 373 } 374 return true; 375 } 376 // EndExtensions = default. 377 378 Status GetSizes(size_t* JXL_RESTRICT extension_bits, 379 size_t* JXL_RESTRICT total_bits) { 380 JXL_RETURN_IF_ERROR(ok_); 381 *extension_bits = 0; 382 *total_bits = encoded_bits_; 383 // Only if extension field was nonzero will we encode their sizes. 384 if (pos_after_ext_ != 0) { 385 JXL_ASSERT(encoded_bits_ >= pos_after_ext_); 386 *extension_bits = encoded_bits_ - pos_after_ext_; 387 // Also need to encode *extension_bits and bill it to *total_bits. 388 size_t encoded_bits = 0; 389 ok_ &= U64Coder::CanEncode(*extension_bits, &encoded_bits); 390 *total_bits += encoded_bits; 391 392 // TODO(janwas): support encoding individual extension sizes. We 393 // currently ascribe all bits to the first and send zeros for the 394 // others. 395 for (size_t i = 1; i < hwy::PopCount(extensions_); ++i) { 396 encoded_bits = 0; 397 ok_ &= U64Coder::CanEncode(0, &encoded_bits); 398 *total_bits += encoded_bits; 399 } 400 } 401 return true; 402 } 403 404 private: 405 bool ok_ = true; 406 size_t encoded_bits_ = 0; 407 uint64_t extensions_ = 0; 408 // Snapshot of encoded_bits_ after visiting the extension field, but NOT 409 // including the hidden extension sizes. 410 uint64_t pos_after_ext_ = 0; 411 }; 412 } // namespace 413 414 void Bundle::Init(Fields* fields) { 415 InitVisitor visitor; 416 if (!visitor.Visit(fields)) { 417 JXL_UNREACHABLE("Init should never fail"); 418 } 419 } 420 void Bundle::SetDefault(Fields* fields) { 421 SetDefaultVisitor visitor; 422 if (!visitor.Visit(fields)) { 423 JXL_UNREACHABLE("SetDefault should never fail"); 424 } 425 } 426 bool Bundle::AllDefault(const Fields& fields) { 427 AllDefaultVisitor visitor; 428 if (!visitor.VisitConst(fields)) { 429 JXL_UNREACHABLE("AllDefault should never fail"); 430 } 431 return visitor.AllDefault(); 432 } 433 size_t Bundle::MaxBits(const Fields& fields) { 434 MaxBitsVisitor visitor; 435 #if JXL_ENABLE_ASSERT 436 Status ret = 437 #else 438 (void) 439 #endif // JXL_ENABLE_ASSERT 440 visitor.VisitConst(fields); 441 JXL_ASSERT(ret); 442 return visitor.MaxBits(); 443 } 444 Status Bundle::CanEncode(const Fields& fields, size_t* extension_bits, 445 size_t* total_bits) { 446 CanEncodeVisitor visitor; 447 JXL_QUIET_RETURN_IF_ERROR(visitor.VisitConst(fields)); 448 JXL_QUIET_RETURN_IF_ERROR(visitor.GetSizes(extension_bits, total_bits)); 449 return true; 450 } 451 Status Bundle::Read(BitReader* reader, Fields* fields) { 452 ReadVisitor visitor(reader); 453 JXL_RETURN_IF_ERROR(visitor.Visit(fields)); 454 return visitor.OK(); 455 } 456 bool Bundle::CanRead(BitReader* reader, Fields* fields) { 457 ReadVisitor visitor(reader); 458 Status status = visitor.Visit(fields); 459 // We are only checking here whether there are enough bytes. We still return 460 // true for other errors because it means there are enough bytes to determine 461 // there's an error. Use Read() to determine which error it is. 462 return status.code() != StatusCode::kNotEnoughBytes; 463 } 464 465 size_t BitsCoder::MaxEncodedBits(const size_t bits) { return bits; } 466 467 Status BitsCoder::CanEncode(const size_t bits, const uint32_t value, 468 size_t* JXL_RESTRICT encoded_bits) { 469 *encoded_bits = bits; 470 if (value >= (1ULL << bits)) { 471 return JXL_FAILURE("Value %u too large for %" PRIu64 " bits", value, 472 static_cast<uint64_t>(bits)); 473 } 474 return true; 475 } 476 477 uint32_t BitsCoder::Read(const size_t bits, BitReader* JXL_RESTRICT reader) { 478 return reader->ReadBits(bits); 479 } 480 481 size_t U32Coder::MaxEncodedBits(const U32Enc enc) { 482 size_t extra_bits = 0; 483 for (uint32_t selector = 0; selector < 4; ++selector) { 484 const U32Distr d = enc.GetDistr(selector); 485 if (d.IsDirect()) { 486 continue; 487 } else { 488 extra_bits = std::max<size_t>(extra_bits, d.ExtraBits()); 489 } 490 } 491 return 2 + extra_bits; 492 } 493 494 Status U32Coder::CanEncode(const U32Enc enc, const uint32_t value, 495 size_t* JXL_RESTRICT encoded_bits) { 496 uint32_t selector; 497 size_t total_bits; 498 const Status ok = ChooseSelector(enc, value, &selector, &total_bits); 499 *encoded_bits = ok ? total_bits : 0; 500 return ok; 501 } 502 503 uint32_t U32Coder::Read(const U32Enc enc, BitReader* JXL_RESTRICT reader) { 504 const uint32_t selector = reader->ReadFixedBits<2>(); 505 const U32Distr d = enc.GetDistr(selector); 506 if (d.IsDirect()) { 507 return d.Direct(); 508 } else { 509 return reader->ReadBits(d.ExtraBits()) + d.Offset(); 510 } 511 } 512 513 Status U32Coder::ChooseSelector(const U32Enc enc, const uint32_t value, 514 uint32_t* JXL_RESTRICT selector, 515 size_t* JXL_RESTRICT total_bits) { 516 #if JXL_ENABLE_ASSERT 517 const size_t bits_required = 32 - Num0BitsAboveMS1Bit(value); 518 #endif // JXL_ENABLE_ASSERT 519 JXL_ASSERT(bits_required <= 32); 520 521 *selector = 0; 522 *total_bits = 0; 523 524 // It is difficult to verify whether Dist32Byte are sorted, so check all 525 // selectors and keep the one with the fewest total_bits. 526 *total_bits = 64; // more than any valid encoding 527 for (uint32_t s = 0; s < 4; ++s) { 528 const U32Distr d = enc.GetDistr(s); 529 if (d.IsDirect()) { 530 if (d.Direct() == value) { 531 *selector = s; 532 *total_bits = 2; 533 return true; // Done, direct is always the best possible. 534 } 535 continue; 536 } 537 const size_t extra_bits = d.ExtraBits(); 538 const uint32_t offset = d.Offset(); 539 if (value < offset || value >= offset + (1ULL << extra_bits)) continue; 540 541 // Better than prior encoding, remember it: 542 if (2 + extra_bits < *total_bits) { 543 *selector = s; 544 *total_bits = 2 + extra_bits; 545 } 546 } 547 548 if (*total_bits == 64) { 549 return JXL_FAILURE("No feasible selector for %u", value); 550 } 551 552 return true; 553 } 554 555 uint64_t U64Coder::Read(BitReader* JXL_RESTRICT reader) { 556 uint64_t selector = reader->ReadFixedBits<2>(); 557 if (selector == 0) { 558 return 0; 559 } 560 if (selector == 1) { 561 return 1 + reader->ReadFixedBits<4>(); 562 } 563 if (selector == 2) { 564 return 17 + reader->ReadFixedBits<8>(); 565 } 566 567 // selector 3, varint, groups have first 12, then 8, and last 4 bits. 568 uint64_t result = reader->ReadFixedBits<12>(); 569 570 uint64_t shift = 12; 571 while (reader->ReadFixedBits<1>()) { 572 if (shift == 60) { 573 result |= static_cast<uint64_t>(reader->ReadFixedBits<4>()) << shift; 574 break; 575 } 576 result |= static_cast<uint64_t>(reader->ReadFixedBits<8>()) << shift; 577 shift += 8; 578 } 579 580 return result; 581 } 582 583 // Can always encode, but useful because it also returns bit size. 584 Status U64Coder::CanEncode(uint64_t value, size_t* JXL_RESTRICT encoded_bits) { 585 if (value == 0) { 586 *encoded_bits = 2; // 2 selector bits 587 } else if (value <= 16) { 588 *encoded_bits = 2 + 4; // 2 selector bits + 4 payload bits 589 } else if (value <= 272) { 590 *encoded_bits = 2 + 8; // 2 selector bits + 8 payload bits 591 } else { 592 *encoded_bits = 2 + 12; // 2 selector bits + 12 payload bits 593 value >>= 12; 594 int shift = 12; 595 while (value > 0 && shift < 60) { 596 *encoded_bits += 1 + 8; // 1 continuation bit + 8 payload bits 597 value >>= 8; 598 shift += 8; 599 } 600 if (value > 0) { 601 // This only could happen if shift == N - 4. 602 *encoded_bits += 1 + 4; // 1 continuation bit + 4 payload bits 603 } else { 604 *encoded_bits += 1; // 1 stop bit 605 } 606 } 607 608 return true; 609 } 610 611 Status F16Coder::Read(BitReader* JXL_RESTRICT reader, 612 float* JXL_RESTRICT value) { 613 const uint32_t bits16 = reader->ReadFixedBits<16>(); 614 const uint32_t sign = bits16 >> 15; 615 const uint32_t biased_exp = (bits16 >> 10) & 0x1F; 616 const uint32_t mantissa = bits16 & 0x3FF; 617 618 if (JXL_UNLIKELY(biased_exp == 31)) { 619 return JXL_FAILURE("F16 infinity or NaN are not supported"); 620 } 621 622 // Subnormal or zero 623 if (JXL_UNLIKELY(biased_exp == 0)) { 624 *value = (1.0f / 16384) * (mantissa * (1.0f / 1024)); 625 if (sign) *value = -*value; 626 return true; 627 } 628 629 // Normalized: convert the representation directly (faster than ldexp/tables). 630 const uint32_t biased_exp32 = biased_exp + (127 - 15); 631 const uint32_t mantissa32 = mantissa << (23 - 10); 632 const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; 633 memcpy(value, &bits32, sizeof(bits32)); 634 return true; 635 } 636 637 Status F16Coder::CanEncode(float value, size_t* JXL_RESTRICT encoded_bits) { 638 *encoded_bits = MaxEncodedBits(); 639 if (std::isnan(value) || std::isinf(value)) { 640 return JXL_FAILURE("Should not attempt to store NaN and infinity"); 641 } 642 return std::abs(value) <= 65504.0f; 643 } 644 645 Status CheckHasEnoughBits(Visitor* visitor, size_t bits) { 646 if (!visitor->IsReading()) return false; 647 ReadVisitor* rv = static_cast<ReadVisitor*>(visitor); 648 size_t have_bits = rv->reader_->TotalBytes() * kBitsPerByte; 649 size_t want_bits = bits + rv->reader_->TotalBitsConsumed(); 650 if (have_bits < want_bits) { 651 return JXL_STATUS(StatusCode::kNotEnoughBytes, 652 "Not enough bytes for header"); 653 } 654 return true; 655 } 656 657 } // namespace jxl