compress_helpers.cpp (9287B)
1 // SPDX-FileCopyrightText: 2019-2024 Connor McLaughlin <stenzek@gmail.com> 2 // SPDX-License-Identifier: (GPL-3.0 OR PolyForm-Strict-1.0.0) 3 4 #include "compress_helpers.h" 5 6 #include "common/assert.h" 7 #include "common/error.h" 8 #include "common/file_system.h" 9 #include "common/path.h" 10 #include "common/string_util.h" 11 12 #include <zstd.h> 13 #include <zstd_errors.h> 14 15 // TODO: Use streaming API to avoid mallocing the whole input buffer. But one read() call is probably still faster.. 16 17 namespace CompressHelpers { 18 static std::optional<CompressType> GetCompressType(const std::string_view path, Error* error); 19 20 template<typename T> 21 static bool DecompressHelper(OptionalByteBuffer& ret, CompressType type, T data, 22 std::optional<size_t> decompressed_size, Error* error); 23 24 template<typename T> 25 static bool CompressHelper(OptionalByteBuffer& ret, CompressType type, T data, int clevel, Error* error); 26 } // namespace CompressHelpers 27 28 std::optional<CompressHelpers::CompressType> CompressHelpers::GetCompressType(const std::string_view path, Error* error) 29 { 30 const std::string_view extension = Path::GetExtension(path); 31 if (StringUtil::EqualNoCase(extension, "zst")) 32 return CompressType::Zstandard; 33 34 return CompressType::Uncompressed; 35 } 36 37 template<typename T> 38 bool CompressHelpers::DecompressHelper(CompressHelpers::OptionalByteBuffer& ret, CompressType type, T data, 39 std::optional<size_t> decompressed_size, Error* error) 40 { 41 if (data.size() == 0) [[unlikely]] 42 { 43 Error::SetStringView(error, "Buffer is empty."); 44 return false; 45 } 46 47 switch (type) 48 { 49 case CompressType::Uncompressed: 50 { 51 ret = ByteBuffer(std::move(data)); 52 return true; 53 } 54 55 case CompressType::Zstandard: 56 { 57 size_t real_decompressed_size; 58 if (!decompressed_size.has_value()) 59 { 60 const unsigned long long runtime_decompressed_size = ZSTD_getFrameContentSize(data.data(), data.size()); 61 if (runtime_decompressed_size == ZSTD_CONTENTSIZE_UNKNOWN || 62 runtime_decompressed_size == ZSTD_CONTENTSIZE_ERROR || 63 runtime_decompressed_size >= std::numeric_limits<size_t>::max()) [[unlikely]] 64 { 65 Error::SetStringView(error, "Failed to get uncompressed size."); 66 return false; 67 } 68 69 real_decompressed_size = static_cast<size_t>(runtime_decompressed_size); 70 } 71 else 72 { 73 real_decompressed_size = decompressed_size.value(); 74 } 75 76 ret = DynamicHeapArray<u8>(real_decompressed_size); 77 78 const size_t result = ZSTD_decompress(ret->data(), ret->size(), data.data(), data.size()); 79 if (ZSTD_isError(result)) [[unlikely]] 80 { 81 const char* errstr = ZSTD_getErrorString(ZSTD_getErrorCode(result)); 82 Error::SetStringFmt(error, "ZSTD_decompress() failed: {}", errstr ? errstr : "<unknown>"); 83 ret.reset(); 84 return false; 85 } 86 else if (result != real_decompressed_size) [[unlikely]] 87 { 88 Error::SetStringFmt(error, "ZSTD_decompress() only returned {} of {} bytes.", result, real_decompressed_size); 89 ret.reset(); 90 return false; 91 } 92 93 return true; 94 } 95 break; 96 97 DefaultCaseIsUnreachable() 98 } 99 } 100 101 template<typename T> 102 bool CompressHelpers::CompressHelper(OptionalByteBuffer& ret, CompressType type, T data, int clevel, Error* error) 103 { 104 if (data.size() == 0) [[unlikely]] 105 { 106 Error::SetStringView(error, "Buffer is empty."); 107 return false; 108 } 109 110 switch (type) 111 { 112 case CompressType::Uncompressed: 113 { 114 ret = ByteBuffer(std::move(data)); 115 return true; 116 } 117 118 case CompressType::Zstandard: 119 { 120 const size_t compressed_size = ZSTD_compressBound(data.size()); 121 if (compressed_size == 0) [[unlikely]] 122 { 123 Error::SetStringView(error, "ZSTD_compressBound() failed."); 124 return false; 125 } 126 127 ret = ByteBuffer(compressed_size); 128 129 const size_t result = ZSTD_compress(ret->data(), compressed_size, data.data(), data.size(), 130 (clevel < 0) ? 0 : std::clamp(clevel, 1, 22)); 131 if (ZSTD_isError(result)) [[unlikely]] 132 { 133 const char* errstr = ZSTD_getErrorString(ZSTD_getErrorCode(result)); 134 Error::SetStringFmt(error, "ZSTD_compress() failed: {}", errstr ? errstr : "<unknown>"); 135 return false; 136 } 137 138 ret->resize(result); 139 return true; 140 } 141 142 DefaultCaseIsUnreachable() 143 } 144 } 145 146 CompressHelpers::OptionalByteBuffer CompressHelpers::DecompressBuffer(CompressType type, std::span<const u8> data, 147 std::optional<size_t> decompressed_size, 148 Error* error) 149 { 150 CompressHelpers::OptionalByteBuffer ret; 151 DecompressHelper(ret, type, data, decompressed_size, error); 152 return ret; 153 } 154 155 CompressHelpers::OptionalByteBuffer CompressHelpers::DecompressBuffer(CompressType type, OptionalByteBuffer data, 156 std::optional<size_t> decompressed_size, 157 Error* error) 158 { 159 OptionalByteBuffer ret; 160 if (data.has_value()) 161 { 162 DecompressHelper(ret, type, std::move(data.value()), decompressed_size, error); 163 } 164 else 165 { 166 if (error && !error->IsValid()) 167 error->SetStringView("Data buffer is empty."); 168 } 169 170 return ret; 171 } 172 173 CompressHelpers::OptionalByteBuffer CompressHelpers::DecompressFile(std::string_view path, std::span<const u8> data, 174 std::optional<size_t> decompressed_size, 175 Error* error) 176 { 177 OptionalByteBuffer ret; 178 const std::optional<CompressType> type = GetCompressType(path, error); 179 if (type.has_value()) 180 ret = DecompressBuffer(type.value(), data, decompressed_size, error); 181 return ret; 182 } 183 184 CompressHelpers::OptionalByteBuffer CompressHelpers::DecompressFile(std::string_view path, OptionalByteBuffer data, 185 std::optional<size_t> decompressed_size, 186 Error* error) 187 { 188 OptionalByteBuffer ret; 189 const std::optional<CompressType> type = GetCompressType(path, error); 190 if (type.has_value()) 191 ret = DecompressBuffer(type.value(), std::move(data), decompressed_size, error); 192 return ret; 193 } 194 195 CompressHelpers::OptionalByteBuffer 196 CompressHelpers::DecompressFile(const char* path, std::optional<size_t> decompressed_size, Error* error) 197 { 198 OptionalByteBuffer ret; 199 const std::optional<CompressType> type = GetCompressType(path, error); 200 if (type.has_value()) 201 ret = DecompressFile(type.value(), path, decompressed_size, error); 202 return ret; 203 } 204 205 CompressHelpers::OptionalByteBuffer CompressHelpers::DecompressFile(CompressType type, const char* path, 206 std::optional<size_t> decompressed_size, 207 Error* error) 208 { 209 OptionalByteBuffer ret; 210 OptionalByteBuffer data = FileSystem::ReadBinaryFile(path, error); 211 if (data.has_value()) 212 ret = DecompressBuffer(type, std::move(data), decompressed_size, error); 213 return ret; 214 } 215 216 CompressHelpers::OptionalByteBuffer CompressHelpers::CompressToBuffer(CompressType type, std::span<const u8> data, 217 int clevel, Error* error) 218 { 219 OptionalByteBuffer ret; 220 CompressHelper(ret, type, data, clevel, error); 221 return ret; 222 } 223 224 CompressHelpers::OptionalByteBuffer CompressHelpers::CompressToBuffer(CompressType type, const void* data, 225 size_t data_size, int clevel, Error* error) 226 { 227 OptionalByteBuffer ret; 228 CompressHelper(ret, type, std::span<const u8>(static_cast<const u8*>(data), data_size), clevel, error); 229 return ret; 230 } 231 232 CompressHelpers::OptionalByteBuffer CompressHelpers::CompressToBuffer(CompressType type, OptionalByteBuffer data, 233 int clevel, Error* error) 234 { 235 OptionalByteBuffer ret; 236 CompressHelper(ret, type, std::move(data.value()), clevel, error); 237 return ret; 238 } 239 240 bool CompressHelpers::CompressToFile(const char* path, std::span<const u8> data, int clevel, bool atomic_write, 241 Error* error) 242 { 243 const std::optional<CompressType> type = GetCompressType(path, error); 244 if (!type.has_value()) 245 return false; 246 247 return CompressToFile(type.value(), path, data, clevel, atomic_write, error); 248 } 249 250 bool CompressHelpers::CompressToFile(CompressType type, const char* path, std::span<const u8> data, int clevel, 251 bool atomic_write, Error* error) 252 { 253 const OptionalByteBuffer cdata = CompressToBuffer(type, data, clevel, error); 254 if (!cdata.has_value()) 255 return false; 256 257 return atomic_write ? FileSystem::WriteAtomicRenamedFile(path, cdata->data(), cdata->size(), error) : 258 FileSystem::WriteBinaryFile(path, cdata->data(), cdata->size(), error); 259 }