handshake.hpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. //
  2. // Copyright (c) 2019-2024 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See accompanying
  5. // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. //
  7. #ifndef BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
  8. #define BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
  9. #include <boost/mysql/character_set.hpp>
  10. #include <boost/mysql/diagnostics.hpp>
  11. #include <boost/mysql/error_code.hpp>
  12. #include <boost/mysql/handshake_params.hpp>
  13. #include <boost/mysql/mysql_collations.hpp>
  14. #include <boost/mysql/detail/algo_params.hpp>
  15. #include <boost/mysql/detail/next_action.hpp>
  16. #include <boost/mysql/detail/ok_view.hpp>
  17. #include <boost/mysql/impl/internal/auth/auth.hpp>
  18. #include <boost/mysql/impl/internal/coroutine.hpp>
  19. #include <boost/mysql/impl/internal/protocol/capabilities.hpp>
  20. #include <boost/mysql/impl/internal/protocol/db_flavor.hpp>
  21. #include <boost/mysql/impl/internal/protocol/deserialization.hpp>
  22. #include <boost/mysql/impl/internal/protocol/serialization.hpp>
  23. #include <boost/mysql/impl/internal/sansio/connection_state_data.hpp>
  24. #include <cstdint>
  25. namespace boost {
  26. namespace mysql {
  27. namespace detail {
  28. inline capabilities conditional_capability(bool condition, std::uint32_t cap)
  29. {
  30. return capabilities(condition ? cap : 0);
  31. }
  32. inline error_code process_capabilities(
  33. const handshake_params& params,
  34. const server_hello& hello,
  35. capabilities& negotiated_caps,
  36. bool transport_supports_ssl
  37. )
  38. {
  39. auto ssl = transport_supports_ssl ? params.ssl() : ssl_mode::disable;
  40. capabilities server_caps = hello.server_capabilities;
  41. capabilities required_caps = mandatory_capabilities |
  42. conditional_capability(!params.database().empty(), CLIENT_CONNECT_WITH_DB) |
  43. conditional_capability(params.multi_queries(), CLIENT_MULTI_STATEMENTS) |
  44. conditional_capability(ssl == ssl_mode::require, CLIENT_SSL);
  45. if (required_caps.has(CLIENT_SSL) && !server_caps.has(CLIENT_SSL))
  46. {
  47. // This happens if the server doesn't have SSL configured. This special
  48. // error code helps users diagnosing their problem a lot (server_unsupported doesn't).
  49. return make_error_code(client_errc::server_doesnt_support_ssl);
  50. }
  51. else if (!server_caps.has_all(required_caps))
  52. {
  53. return make_error_code(client_errc::server_unsupported);
  54. }
  55. negotiated_caps = server_caps & (required_caps | optional_capabilities |
  56. conditional_capability(ssl == ssl_mode::enable, CLIENT_SSL));
  57. return error_code();
  58. }
  59. class handshake_algo
  60. {
  61. int resume_point_{0};
  62. diagnostics* diag_;
  63. handshake_params hparams_;
  64. auth_response auth_resp_;
  65. std::uint8_t sequence_number_{0};
  66. bool secure_channel_{false};
  67. // Attempts to map the collection_id to a character set. We try to be conservative
  68. // here, since servers will happily accept unknown collation IDs, silently defaulting
  69. // to the server's default character set (often latin1, which is not Unicode).
  70. static character_set collation_id_to_charset(std::uint16_t collation_id)
  71. {
  72. switch (collation_id)
  73. {
  74. case mysql_collations::utf8mb4_bin:
  75. case mysql_collations::utf8mb4_general_ci: return utf8mb4_charset;
  76. case mysql_collations::ascii_general_ci:
  77. case mysql_collations::ascii_bin: return ascii_charset;
  78. default: return character_set{};
  79. }
  80. }
  81. // Once the handshake is processed, the capabilities are stored in the connection state
  82. bool use_ssl(const connection_state_data& st) const { return st.current_capabilities.has(CLIENT_SSL); }
  83. error_code process_handshake(connection_state_data& st, span<const std::uint8_t> buffer)
  84. {
  85. // Deserialize server hello
  86. server_hello hello{};
  87. auto err = deserialize_server_hello(buffer, hello, *diag_);
  88. if (err)
  89. return err;
  90. // Check capabilities
  91. capabilities negotiated_caps;
  92. err = process_capabilities(hparams_, hello, negotiated_caps, st.supports_ssl());
  93. if (err)
  94. return err;
  95. // Set capabilities & db flavor
  96. st.current_capabilities = negotiated_caps;
  97. st.flavor = hello.server;
  98. // If we're using SSL, mark the channel as secure
  99. secure_channel_ = secure_channel_ || use_ssl(st);
  100. // Compute auth response
  101. return compute_auth_response(
  102. hello.auth_plugin_name,
  103. hparams_.password(),
  104. hello.auth_plugin_data.to_span(),
  105. secure_channel_,
  106. auth_resp_
  107. );
  108. }
  109. // Response to that initial greeting
  110. ssl_request compose_ssl_request(const connection_state_data& st)
  111. {
  112. return ssl_request{
  113. st.current_capabilities,
  114. static_cast<std::uint32_t>(max_packet_size),
  115. hparams_.connection_collation(),
  116. };
  117. }
  118. login_request compose_login_request(const connection_state_data& st)
  119. {
  120. return login_request{
  121. st.current_capabilities,
  122. static_cast<std::uint32_t>(max_packet_size),
  123. hparams_.connection_collation(),
  124. hparams_.username(),
  125. auth_resp_.data,
  126. hparams_.database(),
  127. auth_resp_.plugin_name,
  128. };
  129. }
  130. // Processes auth_switch and auth_more_data messages, and leaves the result in auth_resp_
  131. error_code process_auth_switch(auth_switch msg)
  132. {
  133. return compute_auth_response(
  134. msg.plugin_name,
  135. hparams_.password(),
  136. msg.auth_data,
  137. secure_channel_,
  138. auth_resp_
  139. );
  140. }
  141. error_code process_auth_more_data(span<const std::uint8_t> data)
  142. {
  143. return compute_auth_response(
  144. auth_resp_.plugin_name,
  145. hparams_.password(),
  146. data,
  147. secure_channel_,
  148. auth_resp_
  149. );
  150. }
  151. // Composes an auth_switch_response message with the contents of auth_resp_
  152. auth_switch_response compose_auth_switch_response() const
  153. {
  154. return auth_switch_response{auth_resp_.data};
  155. }
  156. void on_success(connection_state_data& st, const ok_view& ok)
  157. {
  158. st.is_connected = true;
  159. st.backslash_escapes = ok.backslash_escapes();
  160. st.current_charset = collation_id_to_charset(hparams_.connection_collation());
  161. }
  162. error_code process_ok(connection_state_data& st)
  163. {
  164. ok_view res{};
  165. auto ec = deserialize_ok_packet(st.reader.message(), res);
  166. if (ec)
  167. return ec;
  168. on_success(st, res);
  169. return error_code();
  170. }
  171. public:
  172. handshake_algo(handshake_algo_params params) noexcept
  173. : diag_(params.diag), hparams_(params.hparams), secure_channel_(params.secure_channel)
  174. {
  175. }
  176. diagnostics& diag() { return *diag_; }
  177. next_action resume(connection_state_data& st, error_code ec)
  178. {
  179. if (ec)
  180. return ec;
  181. handhake_server_response resp(error_code{});
  182. switch (resume_point_)
  183. {
  184. case 0:
  185. // Setup
  186. diag_->clear();
  187. st.reset();
  188. // Read server greeting
  189. BOOST_MYSQL_YIELD(resume_point_, 1, st.read(sequence_number_))
  190. // Process server greeting
  191. ec = process_handshake(st, st.reader.message());
  192. if (ec)
  193. return ec;
  194. // SSL
  195. if (use_ssl(st))
  196. {
  197. // Send SSL request
  198. BOOST_MYSQL_YIELD(resume_point_, 2, st.write(compose_ssl_request(st), sequence_number_))
  199. // SSL handshake
  200. BOOST_MYSQL_YIELD(resume_point_, 3, next_action::ssl_handshake())
  201. // Mark the connection as using ssl
  202. st.ssl = ssl_state::active;
  203. }
  204. // Compose and send handshake response
  205. BOOST_MYSQL_YIELD(resume_point_, 4, st.write(compose_login_request(st), sequence_number_))
  206. // Auth message exchange
  207. while (true)
  208. {
  209. // Receive response
  210. BOOST_MYSQL_YIELD(resume_point_, 5, st.read(sequence_number_))
  211. // Process it
  212. resp = deserialize_handshake_server_response(st.reader.message(), st.flavor, *diag_);
  213. if (resp.type == handhake_server_response::type_t::ok)
  214. {
  215. // Auth success, quit
  216. on_success(st, resp.data.ok);
  217. return next_action();
  218. }
  219. else if (resp.type == handhake_server_response::type_t::error)
  220. {
  221. // Error, quit
  222. return resp.data.err;
  223. }
  224. else if (resp.type == handhake_server_response::type_t::auth_switch)
  225. {
  226. // Compute response
  227. ec = process_auth_switch(resp.data.auth_sw);
  228. if (ec)
  229. return ec;
  230. BOOST_MYSQL_YIELD(
  231. resume_point_,
  232. 6,
  233. st.write(compose_auth_switch_response(), sequence_number_)
  234. )
  235. }
  236. else if (resp.type == handhake_server_response::type_t::ok_follows)
  237. {
  238. // The next packet must be an OK packet. Read it
  239. BOOST_MYSQL_YIELD(resume_point_, 7, st.read(sequence_number_))
  240. // Process it
  241. // Regardless of whether we succeeded or not, we're done
  242. return process_ok(st);
  243. }
  244. else
  245. {
  246. BOOST_ASSERT(resp.type == handhake_server_response::type_t::auth_more_data);
  247. // Compute response
  248. ec = process_auth_more_data(resp.data.more_data);
  249. if (ec)
  250. return ec;
  251. // Write response
  252. BOOST_MYSQL_YIELD(
  253. resume_point_,
  254. 8,
  255. st.write(compose_auth_switch_response(), sequence_number_)
  256. )
  257. }
  258. }
  259. }
  260. return next_action();
  261. }
  262. };
  263. } // namespace detail
  264. } // namespace mysql
  265. } // namespace boost
  266. #endif