string_pool.cpp (4359B)
1 #include "scraps/format/string_pool.hpp" 2 3 #include "scraps/capnp_helper.hpp" 4 5 #include <libshit/assert.hpp> 6 #include <libshit/char_utils.hpp> 7 #include <libshit/doctest.hpp> 8 #include <libshit/except.hpp> 9 10 #include <capnp/layout.h> 11 #include <capnp/list.h> 12 #include <capnp/message.h> 13 #include <kj/common.h> 14 15 #include <cstring> 16 #include <utility> 17 18 #define LIBSHIT_LOG_NAME "string_pool" 19 #include <libshit/logger_helper.hpp> 20 21 // IWYU pragma: no_forward_declare capnp::List 22 23 namespace Scraps::Format 24 { 25 using namespace Libshit::NonowningStringLiterals; 26 TEST_SUITE_BEGIN("Scraps::Format::StringPool"); 27 28 StringPoolBuilder::StringPoolBuilder() : pos{0} 29 { InternCopy(""_ns); } 30 31 static void CheckStr(Libshit::StringView str, std::uint32_t pos) 32 { 33 auto size = str.size(); 34 if (size >= 0xfffffff0 || pos > 0xfffffff0 - size) 35 LIBSHIT_THROW(StringPoolError, "String pool overflow", 36 "Pool size", pos, "String length", size); 37 38 if (str.find_first_of('\0') != Libshit::StringView::npos) 39 LIBSHIT_THROW(StringPoolError, "String has null bytes"); 40 } 41 42 template <typename T> 43 std::uint32_t StringPoolBuilder::InternGen(T&& str) 44 { 45 CheckStr(str, pos); 46 47 auto it = strings.lower_bound(str); 48 if (it == strings.end() || it->first != str) 49 { 50 it = strings.emplace_hint(it, std::forward<T>(str), pos); 51 DBG(3) << "Interned " << pos << ' ' << Libshit::Quoted(it->first) 52 << std::endl; 53 order[pos] = &it->first; 54 pos += it->first.size() + 1; 55 } 56 else 57 DBG(4) << "Dup " << it->second << ' ' << Libshit::Quoted(it->first) 58 << std::endl; 59 60 return it->second; 61 } 62 63 template std::uint32_t StringPoolBuilder::InternGen(std::string&&); 64 template std::uint32_t StringPoolBuilder::InternGen(Libshit::StringView&); 65 66 capnp::Orphan<capnp::List<std::uint64_t>> 67 StringPoolBuilder::ToCapnp(capnp::Orphanage orphanage) const 68 { 69 DBG(0) << "String pool size = " << pos << std::endl; 70 auto orphan = orphanage.newOrphan<capnp::List<std::uint64_t>>( 71 GetMinCapnpSize()); 72 ToCapnp(orphan.get(), 0); 73 return orphan; 74 } 75 76 void StringPoolBuilder::ToCapnp( 77 capnp::List<std::uint64_t>::Builder lst, std::uint32_t from) const 78 { 79 LIBSHIT_ASSERT(lst.size() >= GetMinCapnpSize()); 80 char* p = reinterpret_cast<char*>( 81 PrivateGetter::GetBuilder(lst).getLocation()); 82 83 for (auto it = order.lower_bound(from); it != order.end(); ++it) 84 memcpy(p + it->first, it->second->c_str(), it->second->size() + 1); 85 } 86 87 TEST_CASE("Builder") 88 { 89 StringPoolBuilder bld; 90 // 1+4+3 == 8 bytes, no padding 91 CHECK(bld.InternString("abc") == 1); 92 CHECK(bld.InternCopy("de") == 5); 93 94 capnp::MallocMessageBuilder cbld; 95 auto orphan = bld.ToCapnp(cbld.getOrphanage()); 96 REQUIRE(orphan.get().size() == 1); 97 CHECK(orphan.get()[0] == 0x006564'00636261'00); 98 99 // strings deduplicated 100 CHECK(bld.InternString("de") == 5); 101 CHECK(bld.InternCopy("abc") == 1); 102 CHECK(bld.InternString("") == 0); 103 104 orphan = bld.ToCapnp(cbld.getOrphanage()); 105 REQUIRE(orphan.get().size() == 1); 106 CHECK(orphan.get()[0] == 0x006564'00636261'00); 107 108 // 8 more bytes -> max padding 109 CHECK(bld.InternCopy("12345678") == 8); 110 orphan = bld.ToCapnp(cbld.getOrphanage()); 111 REQUIRE(orphan.get().size() == 3); 112 CHECK(orphan.get()[0] == 0x006564'00636261'00); 113 CHECK(orphan.get()[1] == 0x3837363534333231); 114 CHECK(orphan.get()[2] == 0); 115 } 116 117 StringPoolReader::StringPoolReader(capnp::List<std::uint64_t>::Reader lst) 118 { 119 auto ptr = PrivateGetter::GetReader(lst).asRawBytes(); 120 if (ptr.size() == 0 || ptr.back() != '\0') 121 LIBSHIT_THROW(Libshit::DecodeError, "Invalid string pool"); 122 buf = { ptr.begin(), ptr.size()-1 }; 123 } 124 125 TEST_CASE("Reader") 126 { 127 capnp::MallocMessageBuilder bld; 128 auto lst = bld.initRoot<capnp::List<std::uint64_t>>(0); 129 CHECK_THROWS(StringPoolReader{lst}); 130 131 lst = bld.initRoot<capnp::List<std::uint64_t>>(1); 132 lst.set(0, 0x01000000'00000000); 133 CHECK_THROWS(StringPoolReader{lst}); 134 135 lst.set(0, 0x006564'00636261'00); 136 StringPoolReader rd{lst}; 137 CHECK(rd.Get(0) == ""_ns); 138 CHECK(rd.Get(1) == "abc"_ns); 139 CHECK(rd.Get(2) == "bc"_ns); 140 CHECK(rd.Get(5) == "de"_ns); 141 CHECK_THROWS(rd.Get(8)); 142 CHECK_THROWS(rd.Get(0x80000001)); 143 } 144 145 TEST_SUITE_END(); 146 }