// // Copyright (c) 2019-2023 Ruben Perez Hidalgo (rubenperez038 at gmail dot com) // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #ifndef BHO_MYSQL_IMPL_INTERNAL_PROTOCOL_SERIALIZATION_HPP #define BHO_MYSQL_IMPL_INTERNAL_PROTOCOL_SERIALIZATION_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace bho { namespace mysql { namespace detail { // We operate with this enum directly in the deserialization routines for efficiency, then transform it to an // actual error code enum class deserialize_errc { ok = 0, incomplete_message = 1, protocol_value_error, server_unsupported }; inline error_code to_error_code(deserialize_errc v) noexcept { switch (v) { case deserialize_errc::ok: return error_code(); case deserialize_errc::incomplete_message: return error_code(client_errc::incomplete_message); case deserialize_errc::protocol_value_error: return error_code(client_errc::protocol_value_error); case deserialize_errc::server_unsupported: return error_code(client_errc::server_unsupported); default: BHO_ASSERT(false); return error_code(); // avoid warnings } } class serialization_context { std::uint8_t* first_; public: explicit serialization_context(std::uint8_t* first) noexcept : first_(first) {} std::uint8_t* first() const noexcept { return first_; } void advance(std::size_t size) noexcept { first_ += size; } void write(const void* buffer, std::size_t size) noexcept { if (size) { BHO_ASSERT(buffer != nullptr); std::memcpy(first_, buffer, size); advance(size); } } void write(std::uint8_t elm) noexcept { *first_ = elm; ++first_; } }; class deserialization_context { const std::uint8_t* first_; const std::uint8_t* last_; public: deserialization_context(span data) noexcept : deserialization_context(data.data(), data.size()) { } deserialization_context(const std::uint8_t* first, std::size_t size) noexcept : first_(first), last_(first + size){}; const std::uint8_t* first() const noexcept { return first_; } const std::uint8_t* last() const noexcept { return last_; } void advance(std::size_t sz) noexcept { first_ += sz; BHO_ASSERT(last_ >= first_); } void rewind(std::size_t sz) noexcept { first_ -= sz; } std::size_t size() const noexcept { return last_ - first_; } bool empty() const noexcept { return last_ == first_; } bool enough_size(std::size_t required_size) const noexcept { return size() >= required_size; } deserialize_errc copy(void* to, std::size_t sz) noexcept { if (!enough_size(sz)) return deserialize_errc::incomplete_message; memcpy(to, first_, sz); advance(sz); return deserialize_errc::ok; } string_view get_string(std::size_t sz) const noexcept { return string_view(reinterpret_cast(first_), sz); } error_code check_extra_bytes() const noexcept { return empty() ? error_code() : error_code(client_errc::extra_bytes); } span to_span() const noexcept { return span(first_, size()); } }; // integers template ::value>::type> deserialize_errc deserialize(deserialization_context& ctx, T& output) noexcept { constexpr std::size_t sz = sizeof(T); if (!ctx.enough_size(sz)) { return deserialize_errc::incomplete_message; } output = endian::endian_load(ctx.first()); ctx.advance(sz); return deserialize_errc::ok; } template ::value>::type> void serialize(serialization_context& ctx, T input) noexcept { endian::endian_store(ctx.first(), input); ctx.advance(sizeof(T)); } template ::value>::type> constexpr std::size_t get_size(T) noexcept { return sizeof(T); } // int3 inline deserialize_errc deserialize(deserialization_context& ctx, int3& output) noexcept { if (!ctx.enough_size(3)) return deserialize_errc::incomplete_message; output.value = endian::load_little_u24(ctx.first()); ctx.advance(3); return deserialize_errc::ok; } inline void serialize(serialization_context& ctx, int3 input) noexcept { endian::store_little_u24(ctx.first(), input.value); ctx.advance(3); } constexpr std::size_t get_size(int3) noexcept { return 3; } // int_lenenc inline deserialize_errc deserialize(deserialization_context& ctx, int_lenenc& output) noexcept { std::uint8_t first_byte = 0; auto err = deserialize(ctx, first_byte); if (err != deserialize_errc::ok) { return err; } if (first_byte == 0xFC) { std::uint16_t value = 0; err = deserialize(ctx, value); output.value = value; } else if (first_byte == 0xFD) { int3 value{}; err = deserialize(ctx, value); output.value = value.value; } else if (first_byte == 0xFE) { std::uint64_t value = 0; err = deserialize(ctx, value); output.value = value; } else { err = deserialize_errc::ok; output.value = first_byte; } return err; } inline void serialize(serialization_context& ctx, int_lenenc input) noexcept { if (input.value < 251) { serialize(ctx, static_cast(input.value)); } else if (input.value < 0x10000) { ctx.write(0xfc); serialize(ctx, static_cast(input.value)); } else if (input.value < 0x1000000) { ctx.write(0xfd); serialize(ctx, int3{static_cast(input.value)}); } else { ctx.write(0xfe); serialize(ctx, static_cast(input.value)); } } inline std::size_t get_size(int_lenenc input) noexcept { if (input.value < 251) return 1; else if (input.value < 0x10000) return 3; else if (input.value < 0x1000000) return 4; else return 9; } // protocol_field_type inline deserialize_errc deserialize(deserialization_context& ctx, protocol_field_type& output) noexcept { std::underlying_type::type value = 0; auto err = deserialize(ctx, value); output = static_cast(value); return err; } inline void serialize(serialization_context& ctx, protocol_field_type input) noexcept { serialize(ctx, static_cast::type>(input)); } constexpr std::size_t get_size(protocol_field_type) noexcept { return sizeof(protocol_field_type); } // string_fixed template deserialize_errc deserialize(deserialization_context& ctx, string_fixed& output) noexcept { if (!ctx.enough_size(N)) return deserialize_errc::incomplete_message; memcpy(output.value.data(), ctx.first(), N); ctx.advance(N); return deserialize_errc::ok; } template void serialize(serialization_context& ctx, const string_fixed& input) noexcept { ctx.write(input.value.data(), N); } template constexpr std::size_t get_size(const string_fixed&) noexcept { return N; } // string_null inline deserialize_errc deserialize(deserialization_context& ctx, string_null& output) noexcept { auto string_end = std::find(ctx.first(), ctx.last(), 0); if (string_end == ctx.last()) { return deserialize_errc::incomplete_message; } std::size_t length = string_end - ctx.first(); output.value = ctx.get_string(length); ctx.advance(length + 1); // skip the null terminator return deserialize_errc::ok; } inline void serialize(serialization_context& ctx, string_null input) noexcept { ctx.write(input.value.data(), input.value.size()); ctx.write(0); // null terminator } inline std::size_t get_size(string_null input) noexcept { return input.value.size() + 1; } // string_eof inline deserialize_errc deserialize(deserialization_context& ctx, string_eof& output) noexcept { std::size_t size = ctx.size(); output.value = ctx.get_string(size); ctx.advance(size); return deserialize_errc::ok; } inline void serialize(serialization_context& ctx, string_eof input) noexcept { ctx.write(input.value.data(), input.value.size()); } inline std::size_t get_size(string_eof input) noexcept { return input.value.size(); } // string_lenenc inline deserialize_errc deserialize(deserialization_context& ctx, string_lenenc& output) noexcept { int_lenenc length; auto err = deserialize(ctx, length); if (err != deserialize_errc::ok) { return err; } if (length.value > (std::numeric_limits::max)()) { return deserialize_errc::protocol_value_error; } auto len = static_cast(length.value); if (!ctx.enough_size(len)) { return deserialize_errc::incomplete_message; } output.value = ctx.get_string(len); ctx.advance(len); return deserialize_errc::ok; } inline void serialize(serialization_context& ctx, string_lenenc input) noexcept { serialize(ctx, int_lenenc{input.value.size()}); ctx.write(input.value.data(), input.value.size()); } inline std::size_t get_size(string_lenenc input) noexcept { return get_size(int_lenenc{input.value.size()}) + input.value.size(); } // serialize, deserialize, and get size of multiple fields at the same time template deserialize_errc deserialize( deserialization_context& ctx, FirstType& first, SecondType& second, Rest&... tail ) noexcept { deserialize_errc err = deserialize(ctx, first); if (err == deserialize_errc::ok) { err = deserialize(ctx, second, tail...); } return err; } template void serialize( serialization_context& ctx, const FirstType& first, const SecondType& second, const Rest&... rest ) noexcept { serialize(ctx, first); serialize(ctx, second, rest...); } template std::size_t get_size(const FirstType& first, const SecondType& second, const Rest&... rest) noexcept { return get_size(first) + get_size(second, rest...); } // helpers inline string_view to_string(span v) noexcept { return string_view(reinterpret_cast(v.data()), v.size()); } inline span to_span(string_view v) noexcept { return span(reinterpret_cast(v.data()), v.size()); } } // namespace detail } // namespace mysql } // namespace bho #endif