/* * Copyright (c) 2017-2023 zhllxt * * author : zhllxt * email : 37792738@qq.com * * Distributed under the Boost Software License, Version 1.0. (See accompanying * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) */ #ifndef __ASIO2_KCP_STREAM_CP_HPP__ #define __ASIO2_KCP_STREAM_CP_HPP__ #if defined(_MSC_VER) && (_MSC_VER >= 1200) #pragma once #endif // defined(_MSC_VER) && (_MSC_VER >= 1200) #include #include #include #include #include #include #include #include #include namespace asio2::detail { ASIO2_CLASS_FORWARD_DECLARE_UDP_BASE; ASIO2_CLASS_FORWARD_DECLARE_UDP_CLIENT; ASIO2_CLASS_FORWARD_DECLARE_UDP_SERVER; ASIO2_CLASS_FORWARD_DECLARE_UDP_SESSION; /* * because udp is connectionless, in order to simplify the code logic, KCP shakes * hands only twice (compared with TCP handshakes three times) * 1 : client send syn to server * 2 : server send synack to client */ template class kcp_stream_cp { friend derived_t; // C++11 ASIO2_CLASS_FRIEND_DECLARE_UDP_BASE; ASIO2_CLASS_FRIEND_DECLARE_UDP_CLIENT; ASIO2_CLASS_FRIEND_DECLARE_UDP_SERVER; ASIO2_CLASS_FRIEND_DECLARE_UDP_SESSION; public: /** * @brief constructor */ kcp_stream_cp(derived_t& d, asio::io_context& ioc) : derive(d), kcp_timer_(ioc) { } /** * @brief destructor */ ~kcp_stream_cp() noexcept { if (this->kcp_) { kcp::ikcp_release(this->kcp_); this->kcp_ = nullptr; } } /** * @brief The udp protocol may received some illegal data. Use this function * to set up an illegal data handler. The default illegal data handler is empty, * when recvd illegal data, the default illegal data handler will do nothing. */ inline void set_illegal_response_handler(std::function fn) noexcept { this->illegal_response_handler_ = std::move(fn); } protected: void _kcp_start(std::shared_ptr this_ptr, std::uint32_t conv) { // used to restore configs kcp::ikcpcb* old = this->kcp_; struct old_kcp_destructor { old_kcp_destructor(kcp::ikcpcb* p) : p_(p) {} ~old_kcp_destructor() { if (p_) kcp::ikcp_release(p_); } kcp::ikcpcb* p_ = nullptr; } old_kcp_destructor_guard(old); ASIO2_ASSERT(conv != 0); this->kcp_ = kcp::ikcp_create(conv, (void*)this); this->kcp_->output = &kcp_stream_cp::_kcp_output; if (old) { // ikcp_setmtu kcp::ikcp_setmtu(this->kcp_, old->mtu); // ikcp_wndsize kcp::ikcp_wndsize(this->kcp_, old->snd_wnd, old->rcv_wnd); // ikcp_nodelay kcp::ikcp_nodelay(this->kcp_, old->nodelay, old->interval, old->fastresend, old->nocwnd); } else { kcp::ikcp_nodelay(this->kcp_, 1, 10, 2, 1); kcp::ikcp_wndsize(this->kcp_, 128, 512); } // if call kcp_timer_.cancel first, then call _post_kcp_timer second immediately, // use asio::post to avoid start timer failed. asio::post(derive.io_->context(), make_allocator(derive.wallocator(), [this, this_ptr = std::move(this_ptr)]() mutable { this->_post_kcp_timer(std::move(this_ptr)); })); } void _kcp_stop() { error_code ec_ignore{}; // if is kcp mode, send FIN handshake before close if (this->send_fin_) this->_kcp_send_hdr(kcp::make_kcphdr_fin(0), ec_ignore); detail::cancel_timer(this->kcp_timer_); } inline void _kcp_reset() { kcp::ikcp_reset(this->kcp_); } protected: inline std::size_t _kcp_send_hdr(kcp::kcphdr hdr, error_code& ec) { std::string msg = kcp::to_string(hdr); std::size_t sent_bytes = 0; #if defined(_DEBUG) || defined(DEBUG) ASIO2_ASSERT(derive.post_send_counter_.load() == 0); derive.post_send_counter_++; #endif if constexpr (args_t::is_session) sent_bytes = derive.stream().send_to(asio::buffer(msg), derive.remote_endpoint_, 0, ec); else sent_bytes = derive.stream().send(asio::buffer(msg), 0, ec); #if defined(_DEBUG) || defined(DEBUG) derive.post_send_counter_--; #endif return sent_bytes; } inline std::size_t _kcp_send_syn(std::uint32_t seq, error_code& ec) { kcp::kcphdr syn = kcp::make_kcphdr_syn(derive.kcp_conv_, seq); return this->_kcp_send_hdr(syn, ec); } inline std::size_t _kcp_send_synack(kcp::kcphdr syn, error_code& ec) { // the syn.th_ack is the kcp conv kcp::kcphdr synack = kcp::make_kcphdr_synack(syn.th_ack, syn.th_seq); return this->_kcp_send_hdr(synack, ec); } template inline bool _kcp_send(Data& data, Callback&& callback) { #if defined(ASIO2_ENABLE_LOG) static_assert(decltype(tallocator_)::storage_size == 168); #endif auto buffer = asio::buffer(data); #if defined(_DEBUG) || defined(DEBUG) ASIO2_ASSERT(derive.post_send_counter_.load() == 0); derive.post_send_counter_++; #endif int ret = kcp::ikcp_send(this->kcp_, (const char *)buffer.data(), (int)buffer.size()); #if defined(_DEBUG) || defined(DEBUG) derive.post_send_counter_--; #endif switch (ret) { case 0: set_last_error(error_code{} ); break; case -1: set_last_error(asio::error::invalid_argument ); break; case -2: set_last_error(asio::error::no_memory ); break; default: set_last_error(asio::error::operation_not_supported); break; } if (ret == 0) { kcp::ikcp_flush(this->kcp_); } callback(get_last_error(), ret < 0 ? 0 : buffer.size()); return (ret == 0); } void _post_kcp_timer(std::shared_ptr this_ptr) { std::uint32_t clock1 = static_cast(std::chrono::duration_cast< std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count()); std::uint32_t clock2 = kcp::ikcp_check(this->kcp_, clock1); this->kcp_timer_.expires_after(std::chrono::milliseconds(clock2 - clock1)); this->kcp_timer_.async_wait(make_allocator(this->tallocator_, [this, this_ptr = std::move(this_ptr)](const error_code & ec) mutable { this->_handle_kcp_timer(ec, std::move(this_ptr)); })); } void _handle_kcp_timer(const error_code & ec, std::shared_ptr this_ptr) { if (ec == asio::error::operation_aborted) return; std::uint32_t clock = static_cast(std::chrono::duration_cast< std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count()); kcp::ikcp_update(this->kcp_, clock); if (this->kcp_->state == (kcp::IUINT32)-1) { if (derive.state_ == state_t::started) { derive._do_disconnect(asio::error::network_reset, std::move(this_ptr)); } return; } if (derive.is_started()) this->_post_kcp_timer(std::move(this_ptr)); } template void _kcp_recv( std::shared_ptr& this_ptr, std::shared_ptr>& ecs, std::string_view data) { auto& buffer = derive.buffer(); int len = kcp::ikcp_input(this->kcp_, (const char *)data.data(), (long)data.size()); buffer.consume(buffer.size()); if (len != 0) { set_last_error(asio::error::message_size); this->_call_illegal_data_callback(data); return; } for (;;) { len = kcp::ikcp_recv(this->kcp_, (char *)buffer.prepare( buffer.pre_size()).data(), (int)buffer.pre_size()); if /**/ (len >= 0) { buffer.commit(len); derive._fire_recv(this_ptr, ecs, std::string_view(static_cast (buffer.data().data()), len)); buffer.consume(len); } else if (len == -3) { buffer.pre_size((std::min)(buffer.pre_size() * 2, buffer.max_size())); } else { break; } } kcp::ikcp_flush(this->kcp_); } template inline void _kcp_handle_recv( error_code ec, std::string_view data, std::shared_ptr& this_ptr, std::shared_ptr>& ecs) { // the kcp message header length is 24 // the kcphdr length is 12 if /**/ (data.size() > kcp::kcphdr::required_size()) { this->_kcp_recv(this_ptr, ecs, data); } else if (data.size() == kcp::kcphdr::required_size()) { // Check whether the packet is SYN handshake // It is possible that the client did not receive the synack package, then the client // will resend the syn package, so we just need to reply to the syncack package directly. // If the client is disconnect without send a "fin" or the server has't recvd the // "fin", and then the client connect again a later, at this time, the client // is in the session map already, and we need check whether the first message is fin if /**/ (kcp::is_kcphdr_syn(data)) { ASIO2_ASSERT(this->kcp_); if (this->kcp_) { kcp::kcphdr syn = kcp::to_kcphdr(data); std::uint32_t conv = syn.th_ack; if (conv == 0) { conv = this->kcp_->conv; syn.th_ack = conv; } // If the client is forced disconnect after sent some messages, and the server // has recvd the messages already, we must recreated the kcp object, otherwise // the client and server will can't handle the next messages correctly. #if 0 // set send_fin_ = false to make the _kcp_stop don't sent the fin frame. this->send_fin_ = false; this->_kcp_stop(); this->_kcp_start(this_ptr, conv); #else this->_kcp_reset(); #endif this->send_fin_ = true; // every time we recv kcp syn, we sent synack to the client this->_kcp_send_synack(syn, ec); if (ec) { derive._do_disconnect(ec, this_ptr); } } else { derive._do_disconnect(asio::error::operation_aborted, this_ptr); } } else if (kcp::is_kcphdr_synack(data, 0, true)) { ASIO2_ASSERT(this->kcp_); } else if (kcp::is_kcphdr_ack(data, 0, true)) { ASIO2_ASSERT(this->kcp_); } else if (kcp::is_kcphdr_fin(data)) { ASIO2_ASSERT(this->kcp_); this->send_fin_ = false; derive._do_disconnect(asio::error::connection_reset, this_ptr); } else { this->_call_illegal_data_callback(data); } } else { this->_call_illegal_data_callback(data); } } template void _session_post_handshake( std::shared_ptr this_ptr, std::shared_ptr> ecs, DeferEvent chain) { error_code ec; std::string& data = *(derive.first_data_); // step 3 : server recvd syn from client (the first data is the syn) kcp::kcphdr syn = kcp::to_kcphdr(data); std::uint32_t conv = syn.th_ack; if (conv == 0) { conv = derive.kcp_conv_; syn.th_ack = conv; } // step 4 : server send synack to client this->_kcp_send_synack(syn, ec); if (ec) { derive._do_disconnect(ec, std::move(this_ptr), std::move(chain)); return; } this->_kcp_start(this_ptr, conv); this->_handle_handshake(ec, std::move(this_ptr), std::move(ecs), std::move(chain)); } template void _client_post_handshake( std::shared_ptr this_ptr, std::shared_ptr> ecs, DeferEvent chain) { error_code ec; // step 1 : client send syn to server std::uint32_t seq = static_cast( std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()).count()); this->_kcp_send_syn(seq, ec); if (ec) { derive._do_disconnect(ec, std::move(this_ptr), defer_event(chain.move_guard())); return; } // use a loop timer to execute "client send syn to server" until the server // has recvd the syn packet and this client recvd reply. std::shared_ptr timer = mktimer(derive.io_->context(), std::chrono::milliseconds(500), [this, this_ptr, seq](error_code ec) mutable { if (ec == asio::error::operation_aborted) return false; this->_kcp_send_syn(seq, ec); if (ec) { set_last_error(ec); if (derive.state_ == state_t::started) { derive._do_disconnect(ec, std::move(this_ptr)); } return false; } // return true : let the timer continue execute. // return false : kill the timer. return true; }); #if defined(_DEBUG) || defined(DEBUG) ASIO2_ASSERT(derive.post_recv_counter_.load() == 0); derive.post_recv_counter_++; #endif // step 2 : client wait for recv synack util connect timeout or recvd some data derive.socket().async_receive(derive.buffer().prepare(derive.buffer().pre_size()), make_allocator(derive.rallocator(), [this, this_ptr = std::move(this_ptr), ecs = std::move(ecs), chain = std::move(chain), seq, timer = std::move(timer)] (const error_code & ec, std::size_t bytes_recvd) mutable { #if defined(_DEBUG) || defined(DEBUG) derive.post_recv_counter_--; #endif ASIO2_ASSERT(derive.io_->running_in_this_thread()); timer->cancel(); if (ec) { // if connect_timeout_timer_ is empty, it means that the connect timeout timer is // timeout and the callback has called already, so reset the error to timed_out. // note : when the async_resolve is failed, the socket is invalid to. this->_handle_handshake( derive.connect_timeout_timer_ ? ec : asio::error::timed_out, std::move(this_ptr), std::move(ecs), std::move(chain)); return; } derive.buffer().commit(bytes_recvd); std::string_view data = std::string_view(static_cast (derive.buffer().data().data()), bytes_recvd); // Check whether the data is the correct handshake information if (kcp::is_kcphdr_synack(data, seq)) { kcp::kcphdr hdr = kcp::to_kcphdr(data); std::uint32_t conv = hdr.th_seq; if (derive.kcp_conv_ != 0) { ASIO2_ASSERT(derive.kcp_conv_ == conv); } this->_kcp_start(this_ptr, conv); this->_handle_handshake(ec, std::move(this_ptr), std::move(ecs), std::move(chain)); } else { this->_handle_handshake(asio::error::address_family_not_supported, std::move(this_ptr), std::move(ecs), std::move(chain)); } derive.buffer().consume(bytes_recvd); })); } template void _post_handshake(std::shared_ptr this_ptr, std::shared_ptr> ecs, DeferEvent chain) { if constexpr (args_t::is_session) { this->_session_post_handshake(std::move(this_ptr), std::move(ecs), std::move(chain)); } else { this->_client_post_handshake(std::move(this_ptr), std::move(ecs), std::move(chain)); } } template void _handle_handshake( const error_code& ec, std::shared_ptr this_ptr, std::shared_ptr> ecs, DeferEvent chain) { set_last_error(ec); if constexpr (args_t::is_session) { derive._fire_handshake(this_ptr); if (ec) { derive._do_disconnect(ec, std::move(this_ptr), std::move(chain)); return; } derive._done_connect(ec, std::move(this_ptr), std::move(ecs), std::move(chain)); } else { derive._fire_handshake(this_ptr); derive._done_connect(ec, std::move(this_ptr), std::move(ecs), std::move(chain)); } } static int _kcp_output(const char *buf, int len, kcp::ikcpcb *kcp, void *user) { std::ignore = kcp; kcp_stream_cp * zhis = ((kcp_stream_cp*)user); derived_t & derive = zhis->derive; error_code ec; if constexpr (args_t::is_session) derive.stream().send_to(asio::buffer(buf, len), derive.remote_endpoint_, 0, ec); else derive.stream().send(asio::buffer(buf, len), 0, ec); return 0; } inline void _call_illegal_data_callback(std::string_view data) { if (this->illegal_response_handler_) { this->illegal_response_handler_(data); } } protected: derived_t & derive; kcp::ikcpcb * kcp_ = nullptr; bool send_fin_ = true; asio::steady_timer kcp_timer_; handler_memory> tallocator_; std::function illegal_response_handler_; }; } #endif // !__ASIO2_KCP_STREAM_CP_HPP__