// // Copyright (c) 2019-2024 Ruben Perez Hidalgo (rubenperez038 at gmail dot com) // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #ifndef BOOST_MYSQL_IMPL_INTERNAL_PROTOCOL_DESERIALIZATION_HPP #define BOOST_MYSQL_IMPL_INTERNAL_PROTOCOL_DESERIALIZATION_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace boost { namespace mysql { namespace detail { // OK packets (views because strings are non-owning) inline error_code deserialize_ok_packet(span msg, ok_view& output); // for testing // Error packets (exposed for testing) struct err_view { std::uint16_t error_code; string_view error_message; }; BOOST_ATTRIBUTE_NODISCARD inline error_code deserialize_error_packet( span message, err_view& pack, bool has_sql_state = true ); BOOST_ATTRIBUTE_NODISCARD inline error_code process_error_packet( span message, db_flavor flavor, diagnostics& diag, bool has_sql_state = true ); // Deserializes a response that may be an OK or an error packet. // Applicable for commands like ping and reset connection. // If the response is an OK packet, sets backslash_escapes according to the // OK packet's server status flags BOOST_ATTRIBUTE_NODISCARD inline error_code deserialize_ok_response( span message, db_flavor flavor, diagnostics& diag, bool& backslash_escapes ); // Column definition BOOST_ATTRIBUTE_NODISCARD inline error_code deserialize_column_definition( span input, coldef_view& output ); // Prepare statement response struct prepare_stmt_response { std::uint32_t id; std::uint16_t num_columns; std::uint16_t num_params; }; BOOST_ATTRIBUTE_NODISCARD inline error_code deserialize_prepare_stmt_response_impl( span message, prepare_stmt_response& output ); // exposed for testing, doesn't take header into account BOOST_ATTRIBUTE_NODISCARD inline error_code deserialize_prepare_stmt_response( span message, db_flavor flavor, prepare_stmt_response& output, diagnostics& diag ); // Execution messages struct execute_response { enum class type_t { num_fields, ok_packet, error } type; union data_t { std::size_t num_fields; ok_view ok_pack; error_code err; data_t(size_t v) noexcept : num_fields(v) {} data_t(const ok_view& v) noexcept : ok_pack(v) {} data_t(error_code v) noexcept : err(v) {} } data; execute_response(std::size_t v) noexcept : type(type_t::num_fields), data(v) {} execute_response(const ok_view& v) noexcept : type(type_t::ok_packet), data(v) {} execute_response(error_code v) noexcept : type(type_t::error), data(v) {} }; inline execute_response deserialize_execute_response( span msg, db_flavor flavor, diagnostics& diag ); struct row_message { enum class type_t { row, ok_packet, error } type; union data_t { span row; ok_view ok_pack; error_code err; data_t(span row) noexcept : row(row) {} data_t(const ok_view& ok_pack) noexcept : ok_pack(ok_pack) {} data_t(error_code err) noexcept : err(err) {} } data; row_message(span row) noexcept : type(type_t::row), data(row) {} row_message(const ok_view& ok_pack) noexcept : type(type_t::ok_packet), data(ok_pack) {} row_message(error_code v) noexcept : type(type_t::error), data(v) {} }; inline row_message deserialize_row_message(span msg, db_flavor flavor, diagnostics& diag); inline error_code deserialize_row( resultset_encoding encoding, span message, metadata_collection_view meta, span output // Should point to meta.size() field_view objects ); // Server hello struct server_hello { using auth_buffer_type = static_buffer<8 + 0xff>; db_flavor server; auth_buffer_type auth_plugin_data; capabilities server_capabilities{}; string_view auth_plugin_name; }; BOOST_ATTRIBUTE_NODISCARD inline error_code deserialize_server_hello_impl( span msg, server_hello& output ); // exposed for testing, doesn't take message header into account BOOST_ATTRIBUTE_NODISCARD inline error_code deserialize_server_hello( span msg, server_hello& output, diagnostics& diag ); // Auth switch struct auth_switch { string_view plugin_name; span auth_data; }; BOOST_ATTRIBUTE_NODISCARD inline error_code deserialize_auth_switch( span msg, auth_switch& output ); // exposed for testing // Handshake server response struct handhake_server_response { struct ok_follows_t { }; enum class type_t { ok, error, ok_follows, auth_switch, auth_more_data } type; union data_t { ok_view ok; error_code err; ok_follows_t ok_follows; auth_switch auth_sw; span more_data; data_t(const ok_view& ok) noexcept : ok(ok) {} data_t(error_code err) noexcept : err(err) {} data_t(ok_follows_t) noexcept : ok_follows({}) {} data_t(auth_switch msg) noexcept : auth_sw(msg) {} data_t(span more_data) noexcept : more_data(more_data) {} } data; handhake_server_response(const ok_view& ok) noexcept : type(type_t::ok), data(ok) {} handhake_server_response(error_code err) noexcept : type(type_t::error), data(err) {} handhake_server_response(ok_follows_t) noexcept : type(type_t::ok_follows), data(ok_follows_t{}) {} handhake_server_response(auth_switch auth_switch) noexcept : type(type_t::auth_switch), data(auth_switch) { } handhake_server_response(span more_data) noexcept : type(type_t::auth_more_data), data(more_data) { } }; inline handhake_server_response deserialize_handshake_server_response( span buff, db_flavor flavor, diagnostics& diag ); } // namespace detail } // namespace mysql } // namespace boost // // Implementations // namespace boost { namespace mysql { namespace detail { // Constants BOOST_INLINE_CONSTEXPR std::uint8_t error_packet_header = 0xff; BOOST_INLINE_CONSTEXPR std::uint8_t ok_packet_header = 0x00; } // namespace detail } // namespace mysql } // namespace boost // // Deserialization // // OK packets boost::mysql::error_code boost::mysql::detail::deserialize_ok_packet( span msg, ok_view& output ) { struct ok_packet { // header: int<1> header 0x00 or 0xFE the OK packet header int_lenenc affected_rows; int_lenenc last_insert_id; int2 status_flags; // server_status_flags int2 warnings; // CLIENT_SESSION_TRACK: not implemented string_lenenc info; } pack{}; deserialization_context ctx(msg); auto err = ctx.deserialize(pack.affected_rows, pack.last_insert_id, pack.status_flags, pack.warnings); if (err != deserialize_errc::ok) return to_error_code(err); if (ctx.enough_size(1)) // message is optional, may be omitted { err = pack.info.deserialize(ctx); if (err != deserialize_errc::ok) return to_error_code(err); } output = { pack.affected_rows.value, pack.last_insert_id.value, pack.status_flags.value, pack.warnings.value, pack.info.value, }; return ctx.check_extra_bytes(); } // Error packets boost::mysql::error_code boost::mysql::detail::deserialize_error_packet( span msg, err_view& output, bool has_sql_state ) { struct err_packet { // int1 header 0xFF ERR packet header int2 error_code; // if capabilities & CLIENT_PROTOCOL_41 { (modeled here as has_sql_state) string_fixed<1> sql_state_marker; string_fixed<5> sql_state; // } string_eof error_message; } pack{}; deserialization_context ctx(msg); auto err = has_sql_state ? ctx.deserialize( pack.error_code, pack.sql_state_marker, pack.sql_state, pack.error_message ) : ctx.deserialize(pack.error_code, pack.error_message); if (err != deserialize_errc::ok) return to_error_code(err); output = err_view{ pack.error_code.value, pack.error_message.value, }; return ctx.check_extra_bytes(); } boost::mysql::error_code boost::mysql::detail::process_error_packet( span msg, db_flavor flavor, diagnostics& diag, bool has_sql_state ) { err_view error_packet{}; auto err = deserialize_error_packet(msg, error_packet, has_sql_state); if (err) return err; // Error message access::get_impl(diag).assign_server(error_packet.error_message); // Error code if (common_error_to_string(error_packet.error_code)) { // This is an error shared between MySQL and MariaDB, represented as a common_server_errc. // get_common_error_message will check that the code has a common_server_errc representation // (the common error range has "holes" because of removed error codes) return static_cast(error_packet.error_code); } else { // This is a MySQL or MariaDB specific code. There is no fixed list of error codes, // as they both keep adding more codes, so no validation happens. const auto& cat = flavor == db_flavor::mysql ? get_mysql_server_category() : get_mariadb_server_category(); return error_code(error_packet.error_code, cat); } } // Column definition boost::mysql::error_code boost::mysql::detail::deserialize_column_definition( span input, coldef_view& output ) { deserialization_context ctx(input); struct column_definition_packet { string_lenenc catalog; // always "def" string_lenenc schema; // database string_lenenc table; // virtual table string_lenenc org_table; // physical table string_lenenc name; // virtual column name string_lenenc org_name; // physical column name string_lenenc fixed_fields; } pack{}; // pack.fixed_fields itself is a structure like this. // The proto allows for extensibility here - adding fields just increasing fixed_fields.length struct fixed_fields_pack { int2 character_set; // collation id, somehow named character_set in the protocol docs int4 column_length; // maximum length of the field int1 type; // type of the column as defined in enum_field_types - this is a protocol_field_type int2 flags; // Flags as defined in Column Definition Flags int1 decimals; // max shown decimal digits. 0x00 for int/static strings; 0x1f for // dynamic strings, double, float } fixed_fields{}; // Deserialize the main structure auto err = ctx.deserialize( pack.catalog, pack.schema, pack.table, pack.org_table, pack.name, pack.org_name, pack.fixed_fields ); if (err != deserialize_errc::ok) return to_error_code(err); // Deserialize the fixed_fields structure. // Intentionally not checking for extra bytes here, since there may be unknown fields that should just get // ignored deserialization_context subctx(to_span(pack.fixed_fields.value)); err = subctx.deserialize( fixed_fields.character_set, fixed_fields.column_length, fixed_fields.type, fixed_fields.flags, fixed_fields.decimals ); if (err != deserialize_errc::ok) return to_error_code(err); // Compose output output = coldef_view{ pack.schema.value, pack.table.value, pack.org_table.value, pack.name.value, pack.org_name.value, fixed_fields.character_set.value, fixed_fields.column_length.value, compute_column_type( static_cast(fixed_fields.type.value), fixed_fields.flags.value, fixed_fields.character_set.value ), fixed_fields.flags.value, fixed_fields.decimals.value, }; return ctx.check_extra_bytes(); } boost::mysql::error_code boost::mysql::detail::deserialize_ok_response( span message, db_flavor flavor, diagnostics& diag, bool& backslash_escapes ) { // Header int1 header{}; deserialization_context ctx(message); auto err = to_error_code(header.deserialize(ctx)); if (err) return err; if (header.value == ok_packet_header) { // Verify that the ok_packet is correct ok_view ok{}; err = deserialize_ok_packet(ctx.to_span(), ok); if (err) return err; backslash_escapes = ok.backslash_escapes(); return error_code(); } else if (header.value == error_packet_header) { // Theoretically, the server can answer with an error packet, too return process_error_packet(ctx.to_span(), flavor, diag); } else { // Invalid message return client_errc::protocol_value_error; } } boost::mysql::error_code boost::mysql::detail::deserialize_prepare_stmt_response_impl( span message, prepare_stmt_response& output ) { struct com_stmt_prepare_ok_packet { // std::uint8_t status: must be 0 int4 statement_id; int2 num_columns; int2 num_params; int1 reserved_1; // must be 0 int2 warning_count; // int1 metadata_follows when CLIENT_OPTIONAL_RESULTSET_METADATA: not implemented } pack{}; deserialization_context ctx(message); auto err = ctx.deserialize( pack.statement_id, pack.num_columns, pack.num_params, pack.reserved_1, pack.warning_count ); if (err != deserialize_errc::ok) return to_error_code(err); output = prepare_stmt_response{ pack.statement_id.value, pack.num_columns.value, pack.num_params.value, }; return ctx.check_extra_bytes(); } boost::mysql::error_code boost::mysql::detail::deserialize_prepare_stmt_response( span message, db_flavor flavor, prepare_stmt_response& output, diagnostics& diag ) { deserialization_context ctx(message); int1 msg_type{}; auto err = to_error_code(msg_type.deserialize(ctx)); if (err) return err; if (msg_type.value == error_packet_header) { return process_error_packet(ctx.to_span(), flavor, diag); } else if (msg_type.value != 0) { return client_errc::protocol_value_error; } else { return deserialize_prepare_stmt_response_impl(ctx.to_span(), output); } } // execute response boost::mysql::detail::execute_response boost::mysql::detail::deserialize_execute_response( span msg, db_flavor flavor, diagnostics& diag ) { // Response may be: ok_packet, err_packet, local infile request (not implemented) // If it is none of this, then the message type itself is the beginning of // a length-encoded int containing the field count deserialization_context ctx(msg); int1 msg_type{}; auto err = to_error_code(msg_type.deserialize(ctx)); if (err) return err; if (msg_type.value == ok_packet_header) { ok_view ok{}; err = deserialize_ok_packet(ctx.to_span(), ok); if (err) return err; return ok; } else if (msg_type.value == error_packet_header) { return process_error_packet(ctx.to_span(), flavor, diag); } else { // Resultset with metadata. First packet is an int_lenenc with // the number of field definitions to expect. Message type is part // of this packet, so we must rewind the context ctx.rewind(1); int_lenenc num_fields{}; err = to_error_code(num_fields.deserialize(ctx)); if (err) return err; err = ctx.check_extra_bytes(); if (err) return err; // We should have at least one field. // The max number of fields is some value around 1024. For simplicity/extensibility, // we accept anything less than 0xffff if (num_fields.value == 0 || num_fields.value > 0xffffu) { return make_error_code(client_errc::protocol_value_error); } return static_cast(num_fields.value); } } boost::mysql::detail::row_message boost::mysql::detail::deserialize_row_message( span msg, db_flavor flavor, diagnostics& diag ) { constexpr std::uint8_t eof_packet_header = 0xfe; // Message type: row, error or eof? int1 msg_type{}; deserialization_context ctx(msg); auto deser_errc = msg_type.deserialize(ctx); if (deser_errc != deserialize_errc::ok) { return to_error_code(deser_errc); } if (msg_type.value == eof_packet_header) { // end of resultset => this is a ok_packet, not a row ok_view ok{}; auto err = deserialize_ok_packet(ctx.to_span(), ok); if (err) return err; return ok; } else if (msg_type.value == error_packet_header) { // An error occurred during the generation of the rows return process_error_packet(ctx.to_span(), flavor, diag); } else { // An actual row ctx.rewind(1); // keep the 'message type' byte, as it is part of the actual message return span(ctx.first(), ctx.size()); } } // Deserialize row namespace boost { namespace mysql { namespace detail { inline bool is_next_field_null(const deserialization_context& ctx) { if (!ctx.enough_size(1)) return false; return *ctx.first() == 0xfb; } inline error_code deserialize_text_row( deserialization_context& ctx, metadata_collection_view meta, field_view* output ) { for (std::vector::size_type i = 0; i < meta.size(); ++i) { if (is_next_field_null(ctx)) { ctx.advance(1); output[i] = field_view(nullptr); } else { string_lenenc value_str; auto err = value_str.deserialize(ctx); if (err != deserialize_errc::ok) return to_error_code(err); err = deserialize_text_field(value_str.value, meta[i], output[i]); if (err != deserialize_errc::ok) return to_error_code(err); } } return ctx.check_extra_bytes(); } inline error_code deserialize_binary_row( deserialization_context& ctx, metadata_collection_view meta, field_view* output ) { // Skip packet header (it is not part of the message in the binary // protocol but it is in the text protocol, so we include it for homogeneity) if (!ctx.enough_size(1)) return client_errc::incomplete_message; ctx.advance(1); // Number of fields std::size_t num_fields = meta.size(); // Null bitmap null_bitmap_parser null_bitmap(num_fields); const std::uint8_t* null_bitmap_first = ctx.first(); std::size_t null_bitmap_size = null_bitmap.byte_count(); if (!ctx.enough_size(null_bitmap_size)) return client_errc::incomplete_message; ctx.advance(null_bitmap_size); // Actual values for (std::vector::size_type i = 0; i < num_fields; ++i) { if (null_bitmap.is_null(null_bitmap_first, i)) { output[i] = field_view(nullptr); } else { auto err = deserialize_binary_field(ctx, meta[i], output[i]); if (err != deserialize_errc::ok) return to_error_code(err); } } // Check for remaining bytes return ctx.check_extra_bytes(); } } // namespace detail } // namespace mysql } // namespace boost boost::mysql::error_code boost::mysql::detail::deserialize_row( resultset_encoding encoding, span buff, metadata_collection_view meta, span output ) { BOOST_ASSERT(meta.size() == output.size()); deserialization_context ctx(buff); return encoding == detail::resultset_encoding::text ? deserialize_text_row(ctx, meta, output.data()) : deserialize_binary_row(ctx, meta, output.data()); } // Server hello namespace boost { namespace mysql { namespace detail { inline capabilities compose_capabilities(string_fixed<2> low, string_fixed<2> high) { std::uint32_t res = 0; auto capabilities_begin = reinterpret_cast(&res); memcpy(capabilities_begin, low.value.data(), 2); memcpy(capabilities_begin + 2, high.value.data(), 2); return capabilities(boost::endian::little_to_native(res)); } inline db_flavor parse_db_version(string_view version_string) { return version_string.find("MariaDB") != string_view::npos ? db_flavor::mariadb : db_flavor::mysql; } } // namespace detail } // namespace mysql } // namespace boost boost::mysql::error_code boost::mysql::detail::deserialize_server_hello_impl( span msg, server_hello& output ) { struct server_hello_packet { // int<1> protocol version Always 10 string_null server_version; int4 connection_id; string_fixed<8> auth_plugin_data_part_1; int1 filler; // should be 0 string_fixed<2> capability_flags_low; int1 character_set; // default server a_protocol_character_set, only the lower 8-bits int2 status_flags; // server_status_flags string_fixed<2> capability_flags_high; int1 auth_plugin_data_len; string_fixed<10> reserved; // auth plugin data, 2nd part. This has a weird representation that doesn't fit any defined type string_null auth_plugin_name; } pack{}; deserialization_context ctx(msg); auto err = ctx.deserialize( pack.server_version, pack.connection_id, pack.auth_plugin_data_part_1, pack.filler, pack.capability_flags_low, pack.character_set, pack.status_flags, pack.capability_flags_high ); if (err != deserialize_errc::ok) return to_error_code(err); // Compose capabilities auto cap = compose_capabilities(pack.capability_flags_low, pack.capability_flags_high); // Check minimum server capabilities to deserialize this frame if (!cap.has(CLIENT_PLUGIN_AUTH)) return client_errc::server_unsupported; // Deserialize next fields err = ctx.deserialize(pack.auth_plugin_data_len, pack.reserved); if (err != deserialize_errc::ok) return to_error_code(err); // Auth plugin data, second part auto auth2_length = static_cast((std::max)( static_cast(13u), static_cast(pack.auth_plugin_data_len.value - pack.auth_plugin_data_part_1.value.size()) )); const void* auth2_data = ctx.first(); if (!ctx.enough_size(auth2_length)) return client_errc::incomplete_message; ctx.advance(auth2_length); // Auth plugin name err = pack.auth_plugin_name.deserialize(ctx); if (err != deserialize_errc::ok) return to_error_code(err); // Compose output output.server = parse_db_version(pack.server_version.value); output.server_capabilities = cap; output.auth_plugin_name = pack.auth_plugin_name.value; // Compose auth_plugin_data output.auth_plugin_data.clear(); output.auth_plugin_data.append( pack.auth_plugin_data_part_1.value.data(), pack.auth_plugin_data_part_1.value.size() ); output.auth_plugin_data.append(auth2_data, auth2_length - 1); // discard an extra trailing NULL byte return ctx.check_extra_bytes(); } boost::mysql::error_code boost::mysql::detail::deserialize_server_hello( span msg, server_hello& output, diagnostics& diag ) { constexpr std::uint8_t handshake_protocol_version_9 = 9; constexpr std::uint8_t handshake_protocol_version_10 = 10; deserialization_context ctx(msg); // Message type int1 msg_type{}; auto err = to_error_code(msg_type.deserialize(ctx)); if (err) return err; if (msg_type.value == handshake_protocol_version_9) { return make_error_code(client_errc::server_unsupported); } else if (msg_type.value == error_packet_header) { // We don't know which DB is yet. The server has no knowledge of our capabilities // yet, so it will assume we don't support the 4.1 protocol and send an error // packet without SQL state return process_error_packet(ctx.to_span(), db_flavor::mysql, diag, false); } else if (msg_type.value != handshake_protocol_version_10) { return make_error_code(client_errc::protocol_value_error); } else { return deserialize_server_hello_impl(ctx.to_span(), output); } } // auth_switch BOOST_ATTRIBUTE_NODISCARD boost::mysql::error_code boost::mysql::detail::deserialize_auth_switch( span msg, auth_switch& output ) { struct auth_switch_request_packet { string_null plugin_name; string_eof auth_plugin_data; } pack{}; deserialization_context ctx(msg); auto err = ctx.deserialize(pack.plugin_name, pack.auth_plugin_data); if (err != deserialize_errc::ok) return to_error_code(err); // Discard an additional NULL at the end of auth data string_view auth_data = pack.auth_plugin_data.value; if (!auth_data.empty() && auth_data.back() == 0) { auth_data = auth_data.substr(0, auth_data.size() - 1); } output = { pack.plugin_name.value, to_span(auth_data), }; return ctx.check_extra_bytes(); } boost::mysql::detail::handhake_server_response boost::mysql::detail::deserialize_handshake_server_response( span buff, db_flavor flavor, diagnostics& diag ) { constexpr std::uint8_t auth_switch_request_header = 0xfe; constexpr std::uint8_t auth_more_data_header = 0x01; constexpr string_view fast_auth_complete_challenge = make_string_view("\3"); deserialization_context ctx(buff); int1 msg_type{}; auto err = to_error_code(msg_type.deserialize(ctx)); if (err) return err; if (msg_type.value == ok_packet_header) { ok_view ok{}; err = deserialize_ok_packet(ctx.to_span(), ok); if (err) return err; return ok; } else if (msg_type.value == error_packet_header) { return process_error_packet(ctx.to_span(), flavor, diag); } else if (msg_type.value == auth_switch_request_header) { // We have received an auth switch request. Deserialize it auth_switch auth_sw{}; err = deserialize_auth_switch(ctx.to_span(), auth_sw); if (err) return err; return auth_sw; } else if (msg_type.value == auth_more_data_header) { // We have received an auth more data request. Deserialize it. // Note that string_eof never fails deserialization (by definition) string_eof auth_more_data; auto ec = auth_more_data.deserialize(ctx); BOOST_ASSERT(ec == deserialize_errc::ok); boost::ignore_unused(ec); // If the special value fast_auth_complete_challenge // is received as auth data, it means that the auth is complete // but we must wait for another OK message. We consider this // a special type of message string_view challenge = auth_more_data.value; if (challenge == fast_auth_complete_challenge) { return handhake_server_response::ok_follows_t(); } // Otherwise, just return the normal data return handhake_server_response(to_span(challenge)); } else { // Unknown message type return make_error_code(client_errc::protocol_value_error); } } #endif