matrix_ops.h (2638B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #ifndef LIB_JXL_MATRIX_OPS_H_ 7 #define LIB_JXL_MATRIX_OPS_H_ 8 9 // 3x3 matrix operations. 10 11 #include <cmath> // abs 12 #include <cstddef> 13 14 #include "lib/jxl/base/status.h" 15 16 namespace jxl { 17 18 // Computes C = A * B, where A, B, C are 3x3 matrices. 19 template <typename T> 20 void Mul3x3Matrix(const T* a, const T* b, T* c) { 21 alignas(16) T temp[3]; // For transposed column 22 for (size_t x = 0; x < 3; x++) { 23 for (size_t z = 0; z < 3; z++) { 24 temp[z] = b[z * 3 + x]; 25 } 26 for (size_t y = 0; y < 3; y++) { 27 double e = 0; 28 for (size_t z = 0; z < 3; z++) { 29 e += a[y * 3 + z] * temp[z]; 30 } 31 c[y * 3 + x] = e; 32 } 33 } 34 } 35 36 // Computes C = A * B, where A is 3x3 matrix and B is vector. 37 template <typename T> 38 void Mul3x3Vector(const T* a, const T* b, T* c) { 39 for (size_t y = 0; y < 3; y++) { 40 double e = 0; 41 for (size_t x = 0; x < 3; x++) { 42 e += a[y * 3 + x] * b[x]; 43 } 44 c[y] = e; 45 } 46 } 47 48 // Inverts a 3x3 matrix in place. 49 template <typename T> 50 Status Inv3x3Matrix(T* matrix) { 51 // Intermediate computation is done in double precision. 52 double temp[9]; 53 temp[0] = static_cast<double>(matrix[4]) * matrix[8] - 54 static_cast<double>(matrix[5]) * matrix[7]; 55 temp[1] = static_cast<double>(matrix[2]) * matrix[7] - 56 static_cast<double>(matrix[1]) * matrix[8]; 57 temp[2] = static_cast<double>(matrix[1]) * matrix[5] - 58 static_cast<double>(matrix[2]) * matrix[4]; 59 temp[3] = static_cast<double>(matrix[5]) * matrix[6] - 60 static_cast<double>(matrix[3]) * matrix[8]; 61 temp[4] = static_cast<double>(matrix[0]) * matrix[8] - 62 static_cast<double>(matrix[2]) * matrix[6]; 63 temp[5] = static_cast<double>(matrix[2]) * matrix[3] - 64 static_cast<double>(matrix[0]) * matrix[5]; 65 temp[6] = static_cast<double>(matrix[3]) * matrix[7] - 66 static_cast<double>(matrix[4]) * matrix[6]; 67 temp[7] = static_cast<double>(matrix[1]) * matrix[6] - 68 static_cast<double>(matrix[0]) * matrix[7]; 69 temp[8] = static_cast<double>(matrix[0]) * matrix[4] - 70 static_cast<double>(matrix[1]) * matrix[3]; 71 double det = matrix[0] * temp[0] + matrix[1] * temp[3] + matrix[2] * temp[6]; 72 if (std::abs(det) < 1e-10) { 73 return JXL_FAILURE("Matrix determinant is too close to 0"); 74 } 75 double idet = 1.0 / det; 76 for (size_t i = 0; i < 9; i++) { 77 matrix[i] = temp[i] * idet; 78 } 79 return true; 80 } 81 82 } // namespace jxl 83 84 #endif // LIB_JXL_MATRIX_OPS_H_