// // Copyright (c) 2019-2024 Ruben Perez Hidalgo (rubenperez038 at gmail dot com) // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #ifndef BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP #define BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace boost { namespace mysql { namespace detail { // Asio defines a "string view parameter" to be either const std::string&, // std::experimental::string_view or std::string_view. Casting from the Boost // version doesn't work for std::experimental::string_view #if defined(BOOST_ASIO_HAS_STD_STRING_VIEW) inline std::string_view cast_asio_sv_param(string_view input) noexcept { return input; } #elif defined(BOOST_ASIO_HAS_STD_EXPERIMENTAL_STRING_VIEW) inline std::experimental::string_view cast_asio_sv_param(string_view input) noexcept { return {input.data(), input.size()}; } #else inline std::string cast_asio_sv_param(string_view input) { return input; } #endif // Implements the EngineStream concept (see stream_adaptor) class variant_stream { public: variant_stream(asio::any_io_executor ex, asio::ssl::context* ctx) : ex_(std::move(ex)), ssl_ctx_(ctx) {} bool supports_ssl() const { return true; } void set_endpoint(const void* value) { address_ = static_cast(value); } // Executor using executor_type = asio::any_io_executor; executor_type get_executor() { return ex_; } // SSL void ssl_handshake(error_code& ec) { create_ssl_stream(); ssl_->handshake(asio::ssl::stream_base::client, ec); } template void async_ssl_handshake(CompletionToken&& token) { create_ssl_stream(); ssl_->async_handshake(asio::ssl::stream_base::client, std::forward(token)); } void ssl_shutdown(error_code& ec) { BOOST_ASSERT(ssl_.has_value()); ssl_->shutdown(ec); } template void async_ssl_shutdown(CompletionToken&& token) { BOOST_ASSERT(ssl_.has_value()); ssl_->async_shutdown(std::forward(token)); } // Reading std::size_t read_some(asio::mutable_buffer buff, bool use_ssl, error_code& ec) { if (use_ssl) { BOOST_ASSERT(ssl_.has_value()); return ssl_->read_some(buff, ec); } else if (auto* tcp_sock = variant2::get_if(&sock_)) { return tcp_sock->sock.read_some(buff, ec); } #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS else if (auto* unix_sock = variant2::get_if(&sock_)) { return unix_sock->read_some(buff, ec); } #endif else { BOOST_ASSERT(false); return 0u; } } template void async_read_some(asio::mutable_buffer buff, bool use_ssl, CompletionToken&& token) { if (use_ssl) { BOOST_ASSERT(ssl_.has_value()); ssl_->async_read_some(buff, std::forward(token)); } else if (auto* tcp_sock = variant2::get_if(&sock_)) { tcp_sock->sock.async_read_some(buff, std::forward(token)); } #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS else if (auto* unix_sock = variant2::get_if(&sock_)) { unix_sock->async_read_some(buff, std::forward(token)); } #endif else { BOOST_ASSERT(false); } } // Writing std::size_t write_some(boost::asio::const_buffer buff, bool use_ssl, error_code& ec) { if (use_ssl) { BOOST_ASSERT(ssl_.has_value()); return ssl_->write_some(buff, ec); } else if (auto* tcp_sock = variant2::get_if(&sock_)) { return tcp_sock->sock.write_some(buff, ec); } #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS else if (auto* unix_sock = variant2::get_if(&sock_)) { return unix_sock->write_some(buff, ec); } #endif else { BOOST_ASSERT(false); return 0u; } } template void async_write_some(boost::asio::const_buffer buff, bool use_ssl, CompletionToken&& token) { if (use_ssl) { BOOST_ASSERT(ssl_.has_value()); return ssl_->async_write_some(buff, std::forward(token)); } else if (auto* tcp_sock = variant2::get_if(&sock_)) { return tcp_sock->sock.async_write_some(buff, std::forward(token)); } #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS else if (auto* unix_sock = variant2::get_if(&sock_)) { return unix_sock->async_write_some(buff, std::forward(token)); } #endif else { BOOST_ASSERT(false); } } // Connect and close void connect(error_code& ec) { ec = setup_stream(); if (ec) return; if (address_->type() == address_type::host_and_port) { // Resolve endpoints auto& tcp_sock = variant2::unsafe_get<1>(sock_); auto endpoints = tcp_sock.resolv.resolve( cast_asio_sv_param(address_->hostname()), std::to_string(address_->port()), ec ); if (ec) return; // Connect stream asio::connect(tcp_sock.sock, std::move(endpoints), ec); if (ec) return; // Disable Naggle's algorithm set_tcp_nodelay(); } #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS else { BOOST_ASSERT(address_->type() == address_type::unix_path); // Just connect the stream auto& unix_sock = variant2::unsafe_get<2>(sock_); unix_sock.connect(cast_asio_sv_param(address_->unix_socket_path()), ec); } #endif } template void async_connect(CompletionToken&& token) { asio::async_compose(connect_op(*this), token, ex_); } void close(error_code& ec) { if (auto* tcp_sock = variant2::get_if(&sock_)) { tcp_sock->sock.close(ec); } #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS else if (auto* unix_sock = variant2::get_if(&sock_)) { unix_sock->close(ec); } #endif } // Exposed for testing const asio::ip::tcp::socket& tcp_socket() const { return variant2::get(sock_).sock; } private: struct socket_and_resolver { asio::ip::tcp::socket sock; asio::ip::tcp::resolver resolv; socket_and_resolver(asio::any_io_executor ex) : sock(ex), resolv(std::move(ex)) {} }; #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS using unix_socket = asio::local::stream_protocol::socket; #endif const any_address* address_{}; asio::any_io_executor ex_; variant2::variant< variant2::monostate, socket_and_resolver #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS , unix_socket #endif > sock_; ssl_context_with_default ssl_ctx_; boost::optional> ssl_; error_code setup_stream() { if (address_->type() == address_type::host_and_port) { // Clean up any previous state sock_.emplace(ex_); } else if (address_->type() == address_type::unix_path) { #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS // Clean up any previous state sock_.emplace(ex_); #else return asio::error::operation_not_supported; #endif } return error_code(); } void set_tcp_nodelay() { variant2::unsafe_get<1u>(sock_).sock.set_option(asio::ip::tcp::no_delay(true)); } void create_ssl_stream() { // The stream object must be re-created even if it already exists, since // once used for a connection (anytime after ssl::stream::handshake is called), // it can't be re-used for any subsequent connections BOOST_ASSERT(variant2::holds_alternative(sock_)); ssl_.emplace(variant2::unsafe_get<1>(sock_).sock, ssl_ctx_.get()); } struct connect_op { int resume_point_{0}; variant_stream& this_obj_; error_code stored_ec_; connect_op(variant_stream& this_obj) noexcept : this_obj_(this_obj) {} template void operator()(Self& self, error_code ec = {}, asio::ip::tcp::resolver::results_type endpoints = {}) { if (ec) { self.complete(ec); return; } switch (resume_point_) { case 0: // Setup stream stored_ec_ = this_obj_.setup_stream(); if (stored_ec_) { BOOST_MYSQL_YIELD(resume_point_, 1, asio::post(this_obj_.ex_, std::move(self))) self.complete(stored_ec_); return; } if (this_obj_.address_->type() == address_type::host_and_port) { // Resolve endpoints BOOST_MYSQL_YIELD( resume_point_, 2, variant2::unsafe_get<1>(this_obj_.sock_) .resolv.async_resolve( cast_asio_sv_param(this_obj_.address_->hostname()), std::to_string(this_obj_.address_->port()), std::move(self) ) ) // Connect stream BOOST_MYSQL_YIELD( resume_point_, 3, asio::async_connect( variant2::unsafe_get<1>(this_obj_.sock_).sock, std::move(endpoints), std::move(self) ) ) // The final handler requires a void(error_code, tcp::endpoint signature), // which this function can't implement. See operator() overload below. } #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS else { BOOST_ASSERT(this_obj_.address_->type() == address_type::unix_path); // Just connect the stream BOOST_MYSQL_YIELD( resume_point_, 4, variant2::unsafe_get<2>(this_obj_.sock_) .async_connect( cast_asio_sv_param(this_obj_.address_->unix_socket_path()), std::move(self) ) ) self.complete(error_code()); } #endif } } template void operator()(Self& self, error_code ec, asio::ip::tcp::endpoint) { if (!ec) { // Disable Naggle's algorithm this_obj_.set_tcp_nodelay(); } self.complete(ec); } }; }; } // namespace detail } // namespace mysql } // namespace boost #endif