123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411 |
- //
- // Copyright (c) 2019-2024 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
- //
- // Distributed under the Boost Software License, Version 1.0. (See accompanying
- // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
- //
- #ifndef BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
- #define BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
- #include <boost/mysql/any_address.hpp>
- #include <boost/mysql/error_code.hpp>
- #include <boost/mysql/string_view.hpp>
- #include <boost/mysql/detail/config.hpp>
- #include <boost/mysql/detail/connect_params_helpers.hpp>
- #include <boost/mysql/impl/internal/coroutine.hpp>
- #include <boost/mysql/impl/internal/ssl_context_with_default.hpp>
- #include <boost/asio/any_io_executor.hpp>
- #include <boost/asio/compose.hpp>
- #include <boost/asio/connect.hpp>
- #include <boost/asio/error.hpp>
- #include <boost/asio/ip/tcp.hpp>
- #include <boost/asio/local/stream_protocol.hpp>
- #include <boost/asio/post.hpp>
- #include <boost/asio/ssl/context.hpp>
- #include <boost/asio/ssl/stream.hpp>
- #include <boost/optional/optional.hpp>
- #include <boost/variant2/variant.hpp>
- #include <string>
- #include <utility>
- namespace boost {
- namespace mysql {
- namespace detail {
- // Asio defines a "string view parameter" to be either const std::string&,
- // std::experimental::string_view or std::string_view. Casting from the Boost
- // version doesn't work for std::experimental::string_view
- #if defined(BOOST_ASIO_HAS_STD_STRING_VIEW)
- inline std::string_view cast_asio_sv_param(string_view input) noexcept { return input; }
- #elif defined(BOOST_ASIO_HAS_STD_EXPERIMENTAL_STRING_VIEW)
- inline std::experimental::string_view cast_asio_sv_param(string_view input) noexcept
- {
- return {input.data(), input.size()};
- }
- #else
- inline std::string cast_asio_sv_param(string_view input) { return input; }
- #endif
- // Implements the EngineStream concept (see stream_adaptor)
- class variant_stream
- {
- public:
- variant_stream(asio::any_io_executor ex, asio::ssl::context* ctx) : ex_(std::move(ex)), ssl_ctx_(ctx) {}
- bool supports_ssl() const { return true; }
- void set_endpoint(const void* value) { address_ = static_cast<const any_address*>(value); }
- // Executor
- using executor_type = asio::any_io_executor;
- executor_type get_executor() { return ex_; }
- // SSL
- void ssl_handshake(error_code& ec)
- {
- create_ssl_stream();
- ssl_->handshake(asio::ssl::stream_base::client, ec);
- }
- template <class CompletionToken>
- void async_ssl_handshake(CompletionToken&& token)
- {
- create_ssl_stream();
- ssl_->async_handshake(asio::ssl::stream_base::client, std::forward<CompletionToken>(token));
- }
- void ssl_shutdown(error_code& ec)
- {
- BOOST_ASSERT(ssl_.has_value());
- ssl_->shutdown(ec);
- }
- template <class CompletionToken>
- void async_ssl_shutdown(CompletionToken&& token)
- {
- BOOST_ASSERT(ssl_.has_value());
- ssl_->async_shutdown(std::forward<CompletionToken>(token));
- }
- // Reading
- std::size_t read_some(asio::mutable_buffer buff, bool use_ssl, error_code& ec)
- {
- if (use_ssl)
- {
- BOOST_ASSERT(ssl_.has_value());
- return ssl_->read_some(buff, ec);
- }
- else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
- {
- return tcp_sock->sock.read_some(buff, ec);
- }
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
- {
- return unix_sock->read_some(buff, ec);
- }
- #endif
- else
- {
- BOOST_ASSERT(false);
- return 0u;
- }
- }
- template <class CompletionToken>
- void async_read_some(asio::mutable_buffer buff, bool use_ssl, CompletionToken&& token)
- {
- if (use_ssl)
- {
- BOOST_ASSERT(ssl_.has_value());
- ssl_->async_read_some(buff, std::forward<CompletionToken>(token));
- }
- else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
- {
- tcp_sock->sock.async_read_some(buff, std::forward<CompletionToken>(token));
- }
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
- {
- unix_sock->async_read_some(buff, std::forward<CompletionToken>(token));
- }
- #endif
- else
- {
- BOOST_ASSERT(false);
- }
- }
- // Writing
- std::size_t write_some(boost::asio::const_buffer buff, bool use_ssl, error_code& ec)
- {
- if (use_ssl)
- {
- BOOST_ASSERT(ssl_.has_value());
- return ssl_->write_some(buff, ec);
- }
- else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
- {
- return tcp_sock->sock.write_some(buff, ec);
- }
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
- {
- return unix_sock->write_some(buff, ec);
- }
- #endif
- else
- {
- BOOST_ASSERT(false);
- return 0u;
- }
- }
- template <class CompletionToken>
- void async_write_some(boost::asio::const_buffer buff, bool use_ssl, CompletionToken&& token)
- {
- if (use_ssl)
- {
- BOOST_ASSERT(ssl_.has_value());
- return ssl_->async_write_some(buff, std::forward<CompletionToken>(token));
- }
- else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
- {
- return tcp_sock->sock.async_write_some(buff, std::forward<CompletionToken>(token));
- }
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
- {
- return unix_sock->async_write_some(buff, std::forward<CompletionToken>(token));
- }
- #endif
- else
- {
- BOOST_ASSERT(false);
- }
- }
- // Connect and close
- void connect(error_code& ec)
- {
- ec = setup_stream();
- if (ec)
- return;
- if (address_->type() == address_type::host_and_port)
- {
- // Resolve endpoints
- auto& tcp_sock = variant2::unsafe_get<1>(sock_);
- auto endpoints = tcp_sock.resolv.resolve(
- cast_asio_sv_param(address_->hostname()),
- std::to_string(address_->port()),
- ec
- );
- if (ec)
- return;
- // Connect stream
- asio::connect(tcp_sock.sock, std::move(endpoints), ec);
- if (ec)
- return;
- // Disable Naggle's algorithm
- set_tcp_nodelay();
- }
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- else
- {
- BOOST_ASSERT(address_->type() == address_type::unix_path);
- // Just connect the stream
- auto& unix_sock = variant2::unsafe_get<2>(sock_);
- unix_sock.connect(cast_asio_sv_param(address_->unix_socket_path()), ec);
- }
- #endif
- }
- template <class CompletionToken>
- void async_connect(CompletionToken&& token)
- {
- asio::async_compose<CompletionToken, void(error_code)>(connect_op(*this), token, ex_);
- }
- void close(error_code& ec)
- {
- if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
- {
- tcp_sock->sock.close(ec);
- }
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
- {
- unix_sock->close(ec);
- }
- #endif
- }
- // Exposed for testing
- const asio::ip::tcp::socket& tcp_socket() const { return variant2::get<socket_and_resolver>(sock_).sock; }
- private:
- struct socket_and_resolver
- {
- asio::ip::tcp::socket sock;
- asio::ip::tcp::resolver resolv;
- socket_and_resolver(asio::any_io_executor ex) : sock(ex), resolv(std::move(ex)) {}
- };
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- using unix_socket = asio::local::stream_protocol::socket;
- #endif
- const any_address* address_{};
- asio::any_io_executor ex_;
- variant2::variant<
- variant2::monostate,
- socket_and_resolver
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- ,
- unix_socket
- #endif
- >
- sock_;
- ssl_context_with_default ssl_ctx_;
- boost::optional<asio::ssl::stream<asio::ip::tcp::socket&>> ssl_;
- error_code setup_stream()
- {
- if (address_->type() == address_type::host_and_port)
- {
- // Clean up any previous state
- sock_.emplace<socket_and_resolver>(ex_);
- }
- else if (address_->type() == address_type::unix_path)
- {
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- // Clean up any previous state
- sock_.emplace<unix_socket>(ex_);
- #else
- return asio::error::operation_not_supported;
- #endif
- }
- return error_code();
- }
- void set_tcp_nodelay() { variant2::unsafe_get<1u>(sock_).sock.set_option(asio::ip::tcp::no_delay(true)); }
- void create_ssl_stream()
- {
- // The stream object must be re-created even if it already exists, since
- // once used for a connection (anytime after ssl::stream::handshake is called),
- // it can't be re-used for any subsequent connections
- BOOST_ASSERT(variant2::holds_alternative<socket_and_resolver>(sock_));
- ssl_.emplace(variant2::unsafe_get<1>(sock_).sock, ssl_ctx_.get());
- }
- struct connect_op
- {
- int resume_point_{0};
- variant_stream& this_obj_;
- error_code stored_ec_;
- connect_op(variant_stream& this_obj) noexcept : this_obj_(this_obj) {}
- template <class Self>
- void operator()(Self& self, error_code ec = {}, asio::ip::tcp::resolver::results_type endpoints = {})
- {
- if (ec)
- {
- self.complete(ec);
- return;
- }
- switch (resume_point_)
- {
- case 0:
- // Setup stream
- stored_ec_ = this_obj_.setup_stream();
- if (stored_ec_)
- {
- BOOST_MYSQL_YIELD(resume_point_, 1, asio::post(this_obj_.ex_, std::move(self)))
- self.complete(stored_ec_);
- return;
- }
- if (this_obj_.address_->type() == address_type::host_and_port)
- {
- // Resolve endpoints
- BOOST_MYSQL_YIELD(
- resume_point_,
- 2,
- variant2::unsafe_get<1>(this_obj_.sock_)
- .resolv.async_resolve(
- cast_asio_sv_param(this_obj_.address_->hostname()),
- std::to_string(this_obj_.address_->port()),
- std::move(self)
- )
- )
- // Connect stream
- BOOST_MYSQL_YIELD(
- resume_point_,
- 3,
- asio::async_connect(
- variant2::unsafe_get<1>(this_obj_.sock_).sock,
- std::move(endpoints),
- std::move(self)
- )
- )
- // The final handler requires a void(error_code, tcp::endpoint signature),
- // which this function can't implement. See operator() overload below.
- }
- #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
- else
- {
- BOOST_ASSERT(this_obj_.address_->type() == address_type::unix_path);
- // Just connect the stream
- BOOST_MYSQL_YIELD(
- resume_point_,
- 4,
- variant2::unsafe_get<2>(this_obj_.sock_)
- .async_connect(
- cast_asio_sv_param(this_obj_.address_->unix_socket_path()),
- std::move(self)
- )
- )
- self.complete(error_code());
- }
- #endif
- }
- }
- template <class Self>
- void operator()(Self& self, error_code ec, asio::ip::tcp::endpoint)
- {
- if (!ec)
- {
- // Disable Naggle's algorithm
- this_obj_.set_tcp_nodelay();
- }
- self.complete(ec);
- }
- };
- };
- } // namespace detail
- } // namespace mysql
- } // namespace boost
- #endif
|