ssl_stream_cp.hpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. /*
  2. * Copyright (c) 2017-2023 zhllxt
  3. *
  4. * author : zhllxt
  5. * email : 37792738@qq.com
  6. *
  7. * Distributed under the Boost Software License, Version 1.0. (See accompanying
  8. * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  9. */
  10. #if defined(ASIO2_ENABLE_SSL) || defined(ASIO2_USE_SSL)
  11. #ifndef __ASIO2_SSL_STREAM_COMPONENT_HPP__
  12. #define __ASIO2_SSL_STREAM_COMPONENT_HPP__
  13. #if defined(_MSC_VER) && (_MSC_VER >= 1200)
  14. #pragma once
  15. #endif // defined(_MSC_VER) && (_MSC_VER >= 1200)
  16. #include <memory>
  17. #include <future>
  18. #include <utility>
  19. #include <string_view>
  20. #include <asio2/base/iopool.hpp>
  21. #include <asio2/base/detail/util.hpp>
  22. #include <asio2/base/detail/allocator.hpp>
  23. #include <asio2/base/detail/ecs.hpp>
  24. namespace asio2::detail
  25. {
  26. template<class derived_t, class args_t>
  27. class ssl_stream_cp : public detail::ssl_stream_tag
  28. {
  29. public:
  30. using ssl_socket_type = typename args_t::socket_t;
  31. using ssl_stream_type = asio::ssl::stream<ssl_socket_type&>;
  32. using ssl_handshake_type = typename asio::ssl::stream_base::handshake_type;
  33. ssl_stream_cp(asio::ssl::context& ctx, ssl_handshake_type type) noexcept
  34. : ssl_ctx_(ctx)
  35. , ssl_type_(type)
  36. {
  37. }
  38. ~ssl_stream_cp() = default;
  39. /**
  40. * @brief get the ssl socket object reference
  41. */
  42. inline ssl_stream_type & ssl_stream() noexcept
  43. {
  44. ASIO2_ASSERT(bool(this->ssl_stream_));
  45. return (*(this->ssl_stream_));
  46. }
  47. /**
  48. * @brief get the ssl socket object reference
  49. */
  50. inline ssl_stream_type const& ssl_stream() const noexcept
  51. {
  52. ASIO2_ASSERT(bool(this->ssl_stream_));
  53. return (*(this->ssl_stream_));
  54. }
  55. protected:
  56. template<typename C>
  57. inline void _ssl_init(std::shared_ptr<ecs_t<C>>& ecs, ssl_socket_type& socket, asio::ssl::context& ctx)
  58. {
  59. derived_t& derive = static_cast<derived_t&>(*this);
  60. detail::ignore_unused(derive, ecs, socket, ctx);
  61. if constexpr (args_t::is_client)
  62. {
  63. ASIO2_ASSERT(derive.io_->running_in_this_thread());
  64. }
  65. else
  66. {
  67. ASIO2_ASSERT(derive.sessions_.io_->running_in_this_thread());
  68. }
  69. // Why put the initialization code of ssl stream here ?
  70. // Why not put it in the constructor ?
  71. // -----------------------------------------------------------------------
  72. // Beacuse SSL_CTX_use_certificate_chain_file,SSL_CTX_use_PrivateKey and
  73. // other SSL_CTX_... functions must be called before SSL_new, otherwise,
  74. // those SSL_CTX_... function calls have no effect.
  75. // When construct a tcps_client object, beacuse the tcps_client is derived
  76. // from ssl_stream_cp, so the ssl_stream_cp's constructor will be called,
  77. // but at this time, the SSL_CTX_... function has not been called.
  78. this->ssl_stream_ = std::make_unique<ssl_stream_type>(socket, ctx);
  79. }
  80. template<typename C>
  81. inline void _ssl_start(
  82. std::shared_ptr<derived_t>& this_ptr, std::shared_ptr<ecs_t<C>>& ecs, ssl_socket_type& socket,
  83. asio::ssl::context& ctx) noexcept
  84. {
  85. derived_t& derive = static_cast<derived_t&>(*this);
  86. detail::ignore_unused(derive, this_ptr, ecs, socket, ctx);
  87. ASIO2_ASSERT(derive.io_->running_in_this_thread());
  88. }
  89. template<typename DeferEvent>
  90. inline void _ssl_stop(std::shared_ptr<derived_t> this_ptr, DeferEvent chain)
  91. {
  92. derived_t& derive = static_cast<derived_t&>(*this);
  93. ASIO2_ASSERT(derive.io_->running_in_this_thread());
  94. if (!this->ssl_stream_)
  95. return;
  96. derive.disp_event([this, &derive, this_ptr = std::move(this_ptr), e = chain.move_event()]
  97. (event_queue_guard<derived_t> g) mutable
  98. {
  99. // must construct a new chain
  100. defer_event chain(std::move(e), std::move(g));
  101. struct SSL_clear_op
  102. {
  103. ssl_stream_type* s{};
  104. // SSL_clear :
  105. // Reset ssl to allow another connection. All settings (method, ciphers, BIOs) are kept.
  106. // When the client auto reconnect, SSL_clear must be called,
  107. // otherwise the SSL handshake will failed.
  108. SSL_clear_op(ssl_stream_type* p) : s(p)
  109. {
  110. }
  111. ~SSL_clear_op()
  112. {
  113. if (s)
  114. SSL_clear(s->native_handle());
  115. }
  116. };
  117. // use "std::shared_ptr<SSL_clear_op>" to enusre that the SSL_clear(...) function is
  118. // called after "socket.shutdown, socket.close, ssl_stream.async_shutdown".
  119. std::shared_ptr<SSL_clear_op> SSL_clear_ptr =
  120. std::make_shared<SSL_clear_op>(this->ssl_stream_.get());
  121. // if the client socket is not closed forever,this async_shutdown
  122. // callback also can't be called forever, so we use a timer to
  123. // force close the socket,then the async_shutdown callback will
  124. // be called.
  125. std::shared_ptr<asio::steady_timer> timer =
  126. std::make_shared<asio::steady_timer>(derive.io_->context());
  127. timer->expires_after(derive.get_disconnect_timeout());
  128. timer->async_wait(
  129. [this_ptr, chain = std::move(chain), SSL_clear_ptr]
  130. (const error_code& ec) mutable
  131. {
  132. // note : lambda [chain = std::move(chain), SSL_clear_ptr]
  133. // SSL_clear_ptr will destroyed first
  134. // chain will destroyed second after SSL_clear_ptr.
  135. detail::ignore_unused(this_ptr, chain, SSL_clear_ptr);
  136. set_last_error(ec);
  137. });
  138. #if defined(_DEBUG) || defined(DEBUG)
  139. ASIO2_ASSERT(derive.post_send_counter_.load() == 0);
  140. derive.post_send_counter_++;
  141. #endif
  142. // https://stackoverflow.com/questions/52990455/boost-asio-ssl-stream-shutdownec-always-had-error-which-is-boostasiossl
  143. error_code ec_ignore{};
  144. derive.socket().cancel(ec_ignore);
  145. ASIO2_LOG_DEBUG("ssl_stream_cp enter async_shutdown");
  146. // when server call ssl stream sync shutdown first,if the client socket is
  147. // not closed forever,then here shutdowm will blocking forever.
  148. this->ssl_stream_->async_shutdown(
  149. [&derive, p = std::move(this_ptr), timer = std::move(timer), clear_ptr = std::move(SSL_clear_ptr)]
  150. (const error_code& ec) mutable
  151. {
  152. #if defined(_DEBUG) || defined(DEBUG)
  153. derive.post_send_counter_--;
  154. #endif
  155. detail::ignore_unused(derive, p, clear_ptr);
  156. ASIO2_LOG_DEBUG("ssl_stream_cp leave async_shutdown: {} {}", ec.value(), ec.message());
  157. set_last_error(ec);
  158. detail::cancel_timer(*timer);
  159. });
  160. }, chain.move_guard());
  161. }
  162. template<typename C, typename DeferEvent>
  163. inline void _post_handshake(
  164. std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs, DeferEvent chain)
  165. {
  166. derived_t& derive = static_cast<derived_t&>(*this);
  167. ASIO2_ASSERT(bool(this->ssl_stream_));
  168. ASIO2_ASSERT(derive.io_->running_in_this_thread());
  169. // Used to chech whether the ssl handshake is timeout
  170. std::shared_ptr<std::atomic_flag> flag_ptr = std::make_shared<std::atomic_flag>();
  171. flag_ptr->clear();
  172. std::shared_ptr<asio::steady_timer> timer =
  173. std::make_shared<asio::steady_timer>(derive.io_->context());
  174. timer->expires_after(derive.get_connect_timeout());
  175. timer->async_wait(
  176. [&derive, this_ptr, flag_ptr](const error_code& ec) mutable
  177. {
  178. detail::ignore_unused(this_ptr);
  179. // no errors indicating timed out
  180. if (!ec)
  181. {
  182. flag_ptr->test_and_set();
  183. error_code ec_ignore{};
  184. if (derive.socket().is_open())
  185. {
  186. error_code oldec = get_last_error();
  187. asio::socket_base::linger linger = derive.get_linger();
  188. // the get_linger maybe change the last error value.
  189. set_last_error(oldec);
  190. // we close the socket, so the async_handshake will returned
  191. // with operation_aborted.
  192. if (!(linger.enabled() == true && linger.timeout() == 0))
  193. {
  194. derive.socket().shutdown(asio::socket_base::shutdown_both, ec_ignore);
  195. }
  196. }
  197. derive.socket().cancel(ec_ignore);
  198. derive.socket().close(ec_ignore);
  199. }
  200. });
  201. #if defined(_DEBUG) || defined(DEBUG)
  202. ASIO2_ASSERT(derive.post_send_counter_.load() == 0);
  203. derive.post_send_counter_++;
  204. #endif
  205. this->ssl_stream_->async_handshake(this->ssl_type_, make_allocator(derive.wallocator(),
  206. [&derive, this_ptr = std::move(this_ptr), ecs = std::move(ecs),
  207. flag_ptr = std::move(flag_ptr), timer = std::move(timer), chain = std::move(chain)]
  208. (const error_code& ec) mutable
  209. {
  210. #if defined(_DEBUG) || defined(DEBUG)
  211. derive.post_send_counter_--;
  212. #endif
  213. detail::cancel_timer(*timer);
  214. if (flag_ptr->test_and_set())
  215. derive._handle_handshake(asio::error::timed_out,
  216. std::move(this_ptr), std::move(ecs), std::move(chain));
  217. else
  218. derive._handle_handshake(ec,
  219. std::move(this_ptr), std::move(ecs), std::move(chain));
  220. }));
  221. }
  222. template<typename C, typename DeferEvent>
  223. inline void _session_handle_handshake(
  224. const error_code& ec,
  225. std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs, DeferEvent chain)
  226. {
  227. derived_t& derive = static_cast<derived_t&>(*this);
  228. // Use "sessions_.dispatch" to ensure that the _fire_accept function and the _fire_handshake
  229. // function are fired in the same thread
  230. derive.sessions_.dispatch(
  231. [&derive, ec, this_ptr = std::move(this_ptr), ecs = std::move(ecs), chain = std::move(chain)]
  232. () mutable
  233. {
  234. set_last_error(ec);
  235. derive._fire_handshake(this_ptr);
  236. if (ec)
  237. {
  238. ASIO2_LOG_DEBUG("ssl_stream_cp::handle_handshake {} {}", ec.value(), ec.message());
  239. derive._do_disconnect(ec, std::move(this_ptr), std::move(chain));
  240. return;
  241. }
  242. derive._done_connect(ec, std::move(this_ptr), std::move(ecs), std::move(chain));
  243. });
  244. }
  245. template<typename C, typename DeferEvent>
  246. inline void _client_handle_handshake(
  247. const error_code& ec,
  248. std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs, DeferEvent chain)
  249. {
  250. derived_t& derive = static_cast<derived_t&>(*this);
  251. set_last_error(ec);
  252. derive._fire_handshake(this_ptr);
  253. derive._done_connect(ec, std::move(this_ptr), std::move(ecs), std::move(chain));
  254. }
  255. template<typename C, typename DeferEvent>
  256. inline void _handle_handshake(
  257. const error_code& ec,
  258. std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs, DeferEvent chain)
  259. {
  260. derived_t& derive = static_cast<derived_t&>(*this);
  261. if constexpr (args_t::is_session)
  262. {
  263. derive._session_handle_handshake(ec, std::move(this_ptr), std::move(ecs), std::move(chain));
  264. }
  265. else
  266. {
  267. derive._client_handle_handshake(ec, std::move(this_ptr), std::move(ecs), std::move(chain));
  268. }
  269. }
  270. protected:
  271. asio::ssl::context & ssl_ctx_;
  272. ssl_handshake_type ssl_type_;
  273. std::unique_ptr<ssl_stream_type> ssl_stream_;
  274. };
  275. }
  276. #endif // !__ASIO2_SSL_STREAM_COMPONENT_HPP__
  277. #endif