variant_stream.hpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. //
  2. // Copyright (c) 2019-2024 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See accompanying
  5. // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. //
  7. #ifndef BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
  8. #define BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
  9. #include <boost/mysql/any_address.hpp>
  10. #include <boost/mysql/error_code.hpp>
  11. #include <boost/mysql/string_view.hpp>
  12. #include <boost/mysql/detail/config.hpp>
  13. #include <boost/mysql/detail/connect_params_helpers.hpp>
  14. #include <boost/mysql/impl/internal/coroutine.hpp>
  15. #include <boost/mysql/impl/internal/ssl_context_with_default.hpp>
  16. #include <boost/asio/any_io_executor.hpp>
  17. #include <boost/asio/compose.hpp>
  18. #include <boost/asio/connect.hpp>
  19. #include <boost/asio/error.hpp>
  20. #include <boost/asio/ip/tcp.hpp>
  21. #include <boost/asio/local/stream_protocol.hpp>
  22. #include <boost/asio/post.hpp>
  23. #include <boost/asio/ssl/context.hpp>
  24. #include <boost/asio/ssl/stream.hpp>
  25. #include <boost/optional/optional.hpp>
  26. #include <boost/variant2/variant.hpp>
  27. #include <string>
  28. #include <utility>
  29. namespace boost {
  30. namespace mysql {
  31. namespace detail {
  32. // Asio defines a "string view parameter" to be either const std::string&,
  33. // std::experimental::string_view or std::string_view. Casting from the Boost
  34. // version doesn't work for std::experimental::string_view
  35. #if defined(BOOST_ASIO_HAS_STD_STRING_VIEW)
  36. inline std::string_view cast_asio_sv_param(string_view input) noexcept { return input; }
  37. #elif defined(BOOST_ASIO_HAS_STD_EXPERIMENTAL_STRING_VIEW)
  38. inline std::experimental::string_view cast_asio_sv_param(string_view input) noexcept
  39. {
  40. return {input.data(), input.size()};
  41. }
  42. #else
  43. inline std::string cast_asio_sv_param(string_view input) { return input; }
  44. #endif
  45. // Implements the EngineStream concept (see stream_adaptor)
  46. class variant_stream
  47. {
  48. public:
  49. variant_stream(asio::any_io_executor ex, asio::ssl::context* ctx) : ex_(std::move(ex)), ssl_ctx_(ctx) {}
  50. bool supports_ssl() const { return true; }
  51. void set_endpoint(const void* value) { address_ = static_cast<const any_address*>(value); }
  52. // Executor
  53. using executor_type = asio::any_io_executor;
  54. executor_type get_executor() { return ex_; }
  55. // SSL
  56. void ssl_handshake(error_code& ec)
  57. {
  58. create_ssl_stream();
  59. ssl_->handshake(asio::ssl::stream_base::client, ec);
  60. }
  61. template <class CompletionToken>
  62. void async_ssl_handshake(CompletionToken&& token)
  63. {
  64. create_ssl_stream();
  65. ssl_->async_handshake(asio::ssl::stream_base::client, std::forward<CompletionToken>(token));
  66. }
  67. void ssl_shutdown(error_code& ec)
  68. {
  69. BOOST_ASSERT(ssl_.has_value());
  70. ssl_->shutdown(ec);
  71. }
  72. template <class CompletionToken>
  73. void async_ssl_shutdown(CompletionToken&& token)
  74. {
  75. BOOST_ASSERT(ssl_.has_value());
  76. ssl_->async_shutdown(std::forward<CompletionToken>(token));
  77. }
  78. // Reading
  79. std::size_t read_some(asio::mutable_buffer buff, bool use_ssl, error_code& ec)
  80. {
  81. if (use_ssl)
  82. {
  83. BOOST_ASSERT(ssl_.has_value());
  84. return ssl_->read_some(buff, ec);
  85. }
  86. else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
  87. {
  88. return tcp_sock->sock.read_some(buff, ec);
  89. }
  90. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  91. else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
  92. {
  93. return unix_sock->read_some(buff, ec);
  94. }
  95. #endif
  96. else
  97. {
  98. BOOST_ASSERT(false);
  99. return 0u;
  100. }
  101. }
  102. template <class CompletionToken>
  103. void async_read_some(asio::mutable_buffer buff, bool use_ssl, CompletionToken&& token)
  104. {
  105. if (use_ssl)
  106. {
  107. BOOST_ASSERT(ssl_.has_value());
  108. ssl_->async_read_some(buff, std::forward<CompletionToken>(token));
  109. }
  110. else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
  111. {
  112. tcp_sock->sock.async_read_some(buff, std::forward<CompletionToken>(token));
  113. }
  114. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  115. else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
  116. {
  117. unix_sock->async_read_some(buff, std::forward<CompletionToken>(token));
  118. }
  119. #endif
  120. else
  121. {
  122. BOOST_ASSERT(false);
  123. }
  124. }
  125. // Writing
  126. std::size_t write_some(boost::asio::const_buffer buff, bool use_ssl, error_code& ec)
  127. {
  128. if (use_ssl)
  129. {
  130. BOOST_ASSERT(ssl_.has_value());
  131. return ssl_->write_some(buff, ec);
  132. }
  133. else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
  134. {
  135. return tcp_sock->sock.write_some(buff, ec);
  136. }
  137. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  138. else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
  139. {
  140. return unix_sock->write_some(buff, ec);
  141. }
  142. #endif
  143. else
  144. {
  145. BOOST_ASSERT(false);
  146. return 0u;
  147. }
  148. }
  149. template <class CompletionToken>
  150. void async_write_some(boost::asio::const_buffer buff, bool use_ssl, CompletionToken&& token)
  151. {
  152. if (use_ssl)
  153. {
  154. BOOST_ASSERT(ssl_.has_value());
  155. return ssl_->async_write_some(buff, std::forward<CompletionToken>(token));
  156. }
  157. else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
  158. {
  159. return tcp_sock->sock.async_write_some(buff, std::forward<CompletionToken>(token));
  160. }
  161. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  162. else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
  163. {
  164. return unix_sock->async_write_some(buff, std::forward<CompletionToken>(token));
  165. }
  166. #endif
  167. else
  168. {
  169. BOOST_ASSERT(false);
  170. }
  171. }
  172. // Connect and close
  173. void connect(error_code& ec)
  174. {
  175. ec = setup_stream();
  176. if (ec)
  177. return;
  178. if (address_->type() == address_type::host_and_port)
  179. {
  180. // Resolve endpoints
  181. auto& tcp_sock = variant2::unsafe_get<1>(sock_);
  182. auto endpoints = tcp_sock.resolv.resolve(
  183. cast_asio_sv_param(address_->hostname()),
  184. std::to_string(address_->port()),
  185. ec
  186. );
  187. if (ec)
  188. return;
  189. // Connect stream
  190. asio::connect(tcp_sock.sock, std::move(endpoints), ec);
  191. if (ec)
  192. return;
  193. // Disable Naggle's algorithm
  194. set_tcp_nodelay();
  195. }
  196. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  197. else
  198. {
  199. BOOST_ASSERT(address_->type() == address_type::unix_path);
  200. // Just connect the stream
  201. auto& unix_sock = variant2::unsafe_get<2>(sock_);
  202. unix_sock.connect(cast_asio_sv_param(address_->unix_socket_path()), ec);
  203. }
  204. #endif
  205. }
  206. template <class CompletionToken>
  207. void async_connect(CompletionToken&& token)
  208. {
  209. asio::async_compose<CompletionToken, void(error_code)>(connect_op(*this), token, ex_);
  210. }
  211. void close(error_code& ec)
  212. {
  213. if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
  214. {
  215. tcp_sock->sock.close(ec);
  216. }
  217. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  218. else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
  219. {
  220. unix_sock->close(ec);
  221. }
  222. #endif
  223. }
  224. // Exposed for testing
  225. const asio::ip::tcp::socket& tcp_socket() const { return variant2::get<socket_and_resolver>(sock_).sock; }
  226. private:
  227. struct socket_and_resolver
  228. {
  229. asio::ip::tcp::socket sock;
  230. asio::ip::tcp::resolver resolv;
  231. socket_and_resolver(asio::any_io_executor ex) : sock(ex), resolv(std::move(ex)) {}
  232. };
  233. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  234. using unix_socket = asio::local::stream_protocol::socket;
  235. #endif
  236. const any_address* address_{};
  237. asio::any_io_executor ex_;
  238. variant2::variant<
  239. variant2::monostate,
  240. socket_and_resolver
  241. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  242. ,
  243. unix_socket
  244. #endif
  245. >
  246. sock_;
  247. ssl_context_with_default ssl_ctx_;
  248. boost::optional<asio::ssl::stream<asio::ip::tcp::socket&>> ssl_;
  249. error_code setup_stream()
  250. {
  251. if (address_->type() == address_type::host_and_port)
  252. {
  253. // Clean up any previous state
  254. sock_.emplace<socket_and_resolver>(ex_);
  255. }
  256. else if (address_->type() == address_type::unix_path)
  257. {
  258. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  259. // Clean up any previous state
  260. sock_.emplace<unix_socket>(ex_);
  261. #else
  262. return asio::error::operation_not_supported;
  263. #endif
  264. }
  265. return error_code();
  266. }
  267. void set_tcp_nodelay() { variant2::unsafe_get<1u>(sock_).sock.set_option(asio::ip::tcp::no_delay(true)); }
  268. void create_ssl_stream()
  269. {
  270. // The stream object must be re-created even if it already exists, since
  271. // once used for a connection (anytime after ssl::stream::handshake is called),
  272. // it can't be re-used for any subsequent connections
  273. BOOST_ASSERT(variant2::holds_alternative<socket_and_resolver>(sock_));
  274. ssl_.emplace(variant2::unsafe_get<1>(sock_).sock, ssl_ctx_.get());
  275. }
  276. struct connect_op
  277. {
  278. int resume_point_{0};
  279. variant_stream& this_obj_;
  280. error_code stored_ec_;
  281. connect_op(variant_stream& this_obj) noexcept : this_obj_(this_obj) {}
  282. template <class Self>
  283. void operator()(Self& self, error_code ec = {}, asio::ip::tcp::resolver::results_type endpoints = {})
  284. {
  285. if (ec)
  286. {
  287. self.complete(ec);
  288. return;
  289. }
  290. switch (resume_point_)
  291. {
  292. case 0:
  293. // Setup stream
  294. stored_ec_ = this_obj_.setup_stream();
  295. if (stored_ec_)
  296. {
  297. BOOST_MYSQL_YIELD(resume_point_, 1, asio::post(this_obj_.ex_, std::move(self)))
  298. self.complete(stored_ec_);
  299. return;
  300. }
  301. if (this_obj_.address_->type() == address_type::host_and_port)
  302. {
  303. // Resolve endpoints
  304. BOOST_MYSQL_YIELD(
  305. resume_point_,
  306. 2,
  307. variant2::unsafe_get<1>(this_obj_.sock_)
  308. .resolv.async_resolve(
  309. cast_asio_sv_param(this_obj_.address_->hostname()),
  310. std::to_string(this_obj_.address_->port()),
  311. std::move(self)
  312. )
  313. )
  314. // Connect stream
  315. BOOST_MYSQL_YIELD(
  316. resume_point_,
  317. 3,
  318. asio::async_connect(
  319. variant2::unsafe_get<1>(this_obj_.sock_).sock,
  320. std::move(endpoints),
  321. std::move(self)
  322. )
  323. )
  324. // The final handler requires a void(error_code, tcp::endpoint signature),
  325. // which this function can't implement. See operator() overload below.
  326. }
  327. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  328. else
  329. {
  330. BOOST_ASSERT(this_obj_.address_->type() == address_type::unix_path);
  331. // Just connect the stream
  332. BOOST_MYSQL_YIELD(
  333. resume_point_,
  334. 4,
  335. variant2::unsafe_get<2>(this_obj_.sock_)
  336. .async_connect(
  337. cast_asio_sv_param(this_obj_.address_->unix_socket_path()),
  338. std::move(self)
  339. )
  340. )
  341. self.complete(error_code());
  342. }
  343. #endif
  344. }
  345. }
  346. template <class Self>
  347. void operator()(Self& self, error_code ec, asio::ip::tcp::endpoint)
  348. {
  349. if (!ec)
  350. {
  351. // Disable Naggle's algorithm
  352. this_obj_.set_tcp_nodelay();
  353. }
  354. self.complete(ec);
  355. }
  356. };
  357. };
  358. } // namespace detail
  359. } // namespace mysql
  360. } // namespace boost
  361. #endif