/* * Copyright (c) 2017-2023 zhllxt * * author : zhllxt * email : 37792738@qq.com * * Distributed under the Boost Software License, Version 1.0. (See accompanying * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) */ #if defined(ASIO2_ENABLE_SSL) || defined(ASIO2_USE_SSL) #ifndef __ASIO2_SSL_STREAM_COMPONENT_HPP__ #define __ASIO2_SSL_STREAM_COMPONENT_HPP__ #if defined(_MSC_VER) && (_MSC_VER >= 1200) #pragma once #endif // defined(_MSC_VER) && (_MSC_VER >= 1200) #include #include #include #include #include #include #include #include namespace asio2::detail { template class ssl_stream_cp : public detail::ssl_stream_tag { public: using ssl_socket_type = typename args_t::socket_t; using ssl_stream_type = asio::ssl::stream; using ssl_handshake_type = typename asio::ssl::stream_base::handshake_type; ssl_stream_cp(asio::ssl::context& ctx, ssl_handshake_type type) noexcept : ssl_ctx_(ctx) , ssl_type_(type) { } ~ssl_stream_cp() = default; /** * @brief get the ssl socket object reference */ inline ssl_stream_type & ssl_stream() noexcept { ASIO2_ASSERT(bool(this->ssl_stream_)); return (*(this->ssl_stream_)); } /** * @brief get the ssl socket object reference */ inline ssl_stream_type const& ssl_stream() const noexcept { ASIO2_ASSERT(bool(this->ssl_stream_)); return (*(this->ssl_stream_)); } protected: template inline void _ssl_init(std::shared_ptr>& ecs, ssl_socket_type& socket, asio::ssl::context& ctx) { derived_t& derive = static_cast(*this); detail::ignore_unused(derive, ecs, socket, ctx); if constexpr (args_t::is_client) { ASIO2_ASSERT(derive.io_->running_in_this_thread()); } else { ASIO2_ASSERT(derive.sessions_.io_->running_in_this_thread()); } // Why put the initialization code of ssl stream here ? // Why not put it in the constructor ? // ----------------------------------------------------------------------- // Beacuse SSL_CTX_use_certificate_chain_file,SSL_CTX_use_PrivateKey and // other SSL_CTX_... functions must be called before SSL_new, otherwise, // those SSL_CTX_... function calls have no effect. // When construct a tcps_client object, beacuse the tcps_client is derived // from ssl_stream_cp, so the ssl_stream_cp's constructor will be called, // but at this time, the SSL_CTX_... function has not been called. this->ssl_stream_ = std::make_unique(socket, ctx); } template inline void _ssl_start( std::shared_ptr& this_ptr, std::shared_ptr>& ecs, ssl_socket_type& socket, asio::ssl::context& ctx) noexcept { derived_t& derive = static_cast(*this); detail::ignore_unused(derive, this_ptr, ecs, socket, ctx); ASIO2_ASSERT(derive.io_->running_in_this_thread()); } template inline void _ssl_stop(std::shared_ptr this_ptr, DeferEvent chain) { derived_t& derive = static_cast(*this); ASIO2_ASSERT(derive.io_->running_in_this_thread()); if (!this->ssl_stream_) return; derive.disp_event([this, &derive, this_ptr = std::move(this_ptr), e = chain.move_event()] (event_queue_guard g) mutable { // must construct a new chain defer_event chain(std::move(e), std::move(g)); struct SSL_clear_op { ssl_stream_type* s{}; // SSL_clear : // Reset ssl to allow another connection. All settings (method, ciphers, BIOs) are kept. // When the client auto reconnect, SSL_clear must be called, // otherwise the SSL handshake will failed. SSL_clear_op(ssl_stream_type* p) : s(p) { } ~SSL_clear_op() { if (s) SSL_clear(s->native_handle()); } }; // use "std::shared_ptr" to enusre that the SSL_clear(...) function is // called after "socket.shutdown, socket.close, ssl_stream.async_shutdown". std::shared_ptr SSL_clear_ptr = std::make_shared(this->ssl_stream_.get()); // if the client socket is not closed forever,this async_shutdown // callback also can't be called forever, so we use a timer to // force close the socket,then the async_shutdown callback will // be called. std::shared_ptr timer = std::make_shared(derive.io_->context()); timer->expires_after(derive.get_disconnect_timeout()); timer->async_wait( [this_ptr, chain = std::move(chain), SSL_clear_ptr] (const error_code& ec) mutable { // note : lambda [chain = std::move(chain), SSL_clear_ptr] // SSL_clear_ptr will destroyed first // chain will destroyed second after SSL_clear_ptr. detail::ignore_unused(this_ptr, chain, SSL_clear_ptr); set_last_error(ec); }); #if defined(_DEBUG) || defined(DEBUG) ASIO2_ASSERT(derive.post_send_counter_.load() == 0); derive.post_send_counter_++; #endif // https://stackoverflow.com/questions/52990455/boost-asio-ssl-stream-shutdownec-always-had-error-which-is-boostasiossl error_code ec_ignore{}; derive.socket().cancel(ec_ignore); ASIO2_LOG_DEBUG("ssl_stream_cp enter async_shutdown"); // when server call ssl stream sync shutdown first,if the client socket is // not closed forever,then here shutdowm will blocking forever. this->ssl_stream_->async_shutdown( [&derive, p = std::move(this_ptr), timer = std::move(timer), clear_ptr = std::move(SSL_clear_ptr)] (const error_code& ec) mutable { #if defined(_DEBUG) || defined(DEBUG) derive.post_send_counter_--; #endif detail::ignore_unused(derive, p, clear_ptr); ASIO2_LOG_DEBUG("ssl_stream_cp leave async_shutdown: {} {}", ec.value(), ec.message()); set_last_error(ec); detail::cancel_timer(*timer); }); }, chain.move_guard()); } template inline void _post_handshake( std::shared_ptr this_ptr, std::shared_ptr> ecs, DeferEvent chain) { derived_t& derive = static_cast(*this); ASIO2_ASSERT(bool(this->ssl_stream_)); ASIO2_ASSERT(derive.io_->running_in_this_thread()); // Used to chech whether the ssl handshake is timeout std::shared_ptr flag_ptr = std::make_shared(); flag_ptr->clear(); std::shared_ptr timer = std::make_shared(derive.io_->context()); timer->expires_after(derive.get_connect_timeout()); timer->async_wait( [&derive, this_ptr, flag_ptr](const error_code& ec) mutable { detail::ignore_unused(this_ptr); // no errors indicating timed out if (!ec) { flag_ptr->test_and_set(); error_code ec_ignore{}; if (derive.socket().is_open()) { error_code oldec = get_last_error(); asio::socket_base::linger linger = derive.get_linger(); // the get_linger maybe change the last error value. set_last_error(oldec); // we close the socket, so the async_handshake will returned // with operation_aborted. if (!(linger.enabled() == true && linger.timeout() == 0)) { derive.socket().shutdown(asio::socket_base::shutdown_both, ec_ignore); } } derive.socket().cancel(ec_ignore); derive.socket().close(ec_ignore); } }); #if defined(_DEBUG) || defined(DEBUG) ASIO2_ASSERT(derive.post_send_counter_.load() == 0); derive.post_send_counter_++; #endif this->ssl_stream_->async_handshake(this->ssl_type_, make_allocator(derive.wallocator(), [&derive, this_ptr = std::move(this_ptr), ecs = std::move(ecs), flag_ptr = std::move(flag_ptr), timer = std::move(timer), chain = std::move(chain)] (const error_code& ec) mutable { #if defined(_DEBUG) || defined(DEBUG) derive.post_send_counter_--; #endif detail::cancel_timer(*timer); if (flag_ptr->test_and_set()) derive._handle_handshake(asio::error::timed_out, std::move(this_ptr), std::move(ecs), std::move(chain)); else derive._handle_handshake(ec, std::move(this_ptr), std::move(ecs), std::move(chain)); })); } template inline void _session_handle_handshake( const error_code& ec, std::shared_ptr this_ptr, std::shared_ptr> ecs, DeferEvent chain) { derived_t& derive = static_cast(*this); // Use "sessions_.dispatch" to ensure that the _fire_accept function and the _fire_handshake // function are fired in the same thread derive.sessions_.dispatch( [&derive, ec, this_ptr = std::move(this_ptr), ecs = std::move(ecs), chain = std::move(chain)] () mutable { set_last_error(ec); derive._fire_handshake(this_ptr); if (ec) { ASIO2_LOG_DEBUG("ssl_stream_cp::handle_handshake {} {}", ec.value(), ec.message()); derive._do_disconnect(ec, std::move(this_ptr), std::move(chain)); return; } derive._done_connect(ec, std::move(this_ptr), std::move(ecs), std::move(chain)); }); } template inline void _client_handle_handshake( const error_code& ec, std::shared_ptr this_ptr, std::shared_ptr> ecs, DeferEvent chain) { derived_t& derive = static_cast(*this); set_last_error(ec); derive._fire_handshake(this_ptr); derive._done_connect(ec, std::move(this_ptr), std::move(ecs), std::move(chain)); } template inline void _handle_handshake( const error_code& ec, std::shared_ptr this_ptr, std::shared_ptr> ecs, DeferEvent chain) { derived_t& derive = static_cast(*this); if constexpr (args_t::is_session) { derive._session_handle_handshake(ec, std::move(this_ptr), std::move(ecs), std::move(chain)); } else { derive._client_handle_handshake(ec, std::move(this_ptr), std::move(ecs), std::move(chain)); } } protected: asio::ssl::context & ssl_ctx_; ssl_handshake_type ssl_type_; std::unique_ptr ssl_stream_; }; } #endif // !__ASIO2_SSL_STREAM_COMPONENT_HPP__ #endif