kcp_stream_cp.hpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. /*
  2. * Copyright (c) 2017-2023 zhllxt
  3. *
  4. * author : zhllxt
  5. * email : 37792738@qq.com
  6. *
  7. * Distributed under the Boost Software License, Version 1.0. (See accompanying
  8. * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  9. */
  10. #ifndef __ASIO2_KCP_STREAM_CP_HPP__
  11. #define __ASIO2_KCP_STREAM_CP_HPP__
  12. #if defined(_MSC_VER) && (_MSC_VER >= 1200)
  13. #pragma once
  14. #endif // defined(_MSC_VER) && (_MSC_VER >= 1200)
  15. #include <asio2/base/iopool.hpp>
  16. #include <asio2/base/define.hpp>
  17. #include <asio2/base/listener.hpp>
  18. #include <asio2/base/session_mgr.hpp>
  19. #include <asio2/base/detail/object.hpp>
  20. #include <asio2/base/detail/allocator.hpp>
  21. #include <asio2/base/detail/util.hpp>
  22. #include <asio2/base/detail/buffer_wrap.hpp>
  23. #include <asio2/udp/detail/kcp_util.hpp>
  24. namespace asio2::detail
  25. {
  26. ASIO2_CLASS_FORWARD_DECLARE_UDP_BASE;
  27. ASIO2_CLASS_FORWARD_DECLARE_UDP_CLIENT;
  28. ASIO2_CLASS_FORWARD_DECLARE_UDP_SERVER;
  29. ASIO2_CLASS_FORWARD_DECLARE_UDP_SESSION;
  30. /*
  31. * because udp is connectionless, in order to simplify the code logic, KCP shakes
  32. * hands only twice (compared with TCP handshakes three times)
  33. * 1 : client send syn to server
  34. * 2 : server send synack to client
  35. */
  36. template<class derived_t, class args_t>
  37. class kcp_stream_cp
  38. {
  39. friend derived_t; // C++11
  40. ASIO2_CLASS_FRIEND_DECLARE_UDP_BASE;
  41. ASIO2_CLASS_FRIEND_DECLARE_UDP_CLIENT;
  42. ASIO2_CLASS_FRIEND_DECLARE_UDP_SERVER;
  43. ASIO2_CLASS_FRIEND_DECLARE_UDP_SESSION;
  44. public:
  45. /**
  46. * @brief constructor
  47. */
  48. kcp_stream_cp(derived_t& d, asio::io_context& ioc)
  49. : derive(d), kcp_timer_(ioc)
  50. {
  51. }
  52. /**
  53. * @brief destructor
  54. */
  55. ~kcp_stream_cp() noexcept
  56. {
  57. if (this->kcp_)
  58. {
  59. kcp::ikcp_release(this->kcp_);
  60. this->kcp_ = nullptr;
  61. }
  62. }
  63. /**
  64. * @brief The udp protocol may received some illegal data. Use this function
  65. * to set up an illegal data handler. The default illegal data handler is empty,
  66. * when recvd illegal data, the default illegal data handler will do nothing.
  67. */
  68. inline void set_illegal_response_handler(std::function<void(std::string_view)> fn) noexcept
  69. {
  70. this->illegal_response_handler_ = std::move(fn);
  71. }
  72. protected:
  73. void _kcp_start(std::shared_ptr<derived_t> this_ptr, std::uint32_t conv)
  74. {
  75. // used to restore configs
  76. kcp::ikcpcb* old = this->kcp_;
  77. struct old_kcp_destructor
  78. {
  79. old_kcp_destructor(kcp::ikcpcb* p) : p_(p) {}
  80. ~old_kcp_destructor()
  81. {
  82. if (p_)
  83. kcp::ikcp_release(p_);
  84. }
  85. kcp::ikcpcb* p_ = nullptr;
  86. } old_kcp_destructor_guard(old);
  87. ASIO2_ASSERT(conv != 0);
  88. this->kcp_ = kcp::ikcp_create(conv, (void*)this);
  89. this->kcp_->output = &kcp_stream_cp<derived_t, args_t>::_kcp_output;
  90. if (old)
  91. {
  92. // ikcp_setmtu
  93. kcp::ikcp_setmtu(this->kcp_, old->mtu);
  94. // ikcp_wndsize
  95. kcp::ikcp_wndsize(this->kcp_, old->snd_wnd, old->rcv_wnd);
  96. // ikcp_nodelay
  97. kcp::ikcp_nodelay(this->kcp_, old->nodelay, old->interval, old->fastresend, old->nocwnd);
  98. }
  99. else
  100. {
  101. kcp::ikcp_nodelay(this->kcp_, 1, 10, 2, 1);
  102. kcp::ikcp_wndsize(this->kcp_, 128, 512);
  103. }
  104. // if call kcp_timer_.cancel first, then call _post_kcp_timer second immediately,
  105. // use asio::post to avoid start timer failed.
  106. asio::post(derive.io_->context(), make_allocator(derive.wallocator(),
  107. [this, this_ptr = std::move(this_ptr)]() mutable
  108. {
  109. this->_post_kcp_timer(std::move(this_ptr));
  110. }));
  111. }
  112. void _kcp_stop()
  113. {
  114. error_code ec_ignore{};
  115. // if is kcp mode, send FIN handshake before close
  116. if (this->send_fin_)
  117. this->_kcp_send_hdr(kcp::make_kcphdr_fin(0), ec_ignore);
  118. detail::cancel_timer(this->kcp_timer_);
  119. }
  120. inline void _kcp_reset()
  121. {
  122. kcp::ikcp_reset(this->kcp_);
  123. }
  124. protected:
  125. inline std::size_t _kcp_send_hdr(kcp::kcphdr hdr, error_code& ec)
  126. {
  127. std::string msg = kcp::to_string(hdr);
  128. std::size_t sent_bytes = 0;
  129. #if defined(_DEBUG) || defined(DEBUG)
  130. ASIO2_ASSERT(derive.post_send_counter_.load() == 0);
  131. derive.post_send_counter_++;
  132. #endif
  133. if constexpr (args_t::is_session)
  134. sent_bytes = derive.stream().send_to(asio::buffer(msg), derive.remote_endpoint_, 0, ec);
  135. else
  136. sent_bytes = derive.stream().send(asio::buffer(msg), 0, ec);
  137. #if defined(_DEBUG) || defined(DEBUG)
  138. derive.post_send_counter_--;
  139. #endif
  140. return sent_bytes;
  141. }
  142. inline std::size_t _kcp_send_syn(std::uint32_t seq, error_code& ec)
  143. {
  144. kcp::kcphdr syn = kcp::make_kcphdr_syn(derive.kcp_conv_, seq);
  145. return this->_kcp_send_hdr(syn, ec);
  146. }
  147. inline std::size_t _kcp_send_synack(kcp::kcphdr syn, error_code& ec)
  148. {
  149. // the syn.th_ack is the kcp conv
  150. kcp::kcphdr synack = kcp::make_kcphdr_synack(syn.th_ack, syn.th_seq);
  151. return this->_kcp_send_hdr(synack, ec);
  152. }
  153. template<class Data, class Callback>
  154. inline bool _kcp_send(Data& data, Callback&& callback)
  155. {
  156. #if defined(ASIO2_ENABLE_LOG)
  157. static_assert(decltype(tallocator_)::storage_size == 168);
  158. #endif
  159. auto buffer = asio::buffer(data);
  160. #if defined(_DEBUG) || defined(DEBUG)
  161. ASIO2_ASSERT(derive.post_send_counter_.load() == 0);
  162. derive.post_send_counter_++;
  163. #endif
  164. int ret = kcp::ikcp_send(this->kcp_, (const char *)buffer.data(), (int)buffer.size());
  165. #if defined(_DEBUG) || defined(DEBUG)
  166. derive.post_send_counter_--;
  167. #endif
  168. switch (ret)
  169. {
  170. case 0: set_last_error(error_code{} ); break;
  171. case -1: set_last_error(asio::error::invalid_argument ); break;
  172. case -2: set_last_error(asio::error::no_memory ); break;
  173. default: set_last_error(asio::error::operation_not_supported); break;
  174. }
  175. if (ret == 0)
  176. {
  177. kcp::ikcp_flush(this->kcp_);
  178. }
  179. callback(get_last_error(), ret < 0 ? 0 : buffer.size());
  180. return (ret == 0);
  181. }
  182. void _post_kcp_timer(std::shared_ptr<derived_t> this_ptr)
  183. {
  184. std::uint32_t clock1 = static_cast<std::uint32_t>(std::chrono::duration_cast<
  185. std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count());
  186. std::uint32_t clock2 = kcp::ikcp_check(this->kcp_, clock1);
  187. this->kcp_timer_.expires_after(std::chrono::milliseconds(clock2 - clock1));
  188. this->kcp_timer_.async_wait(make_allocator(this->tallocator_,
  189. [this, this_ptr = std::move(this_ptr)](const error_code & ec) mutable
  190. {
  191. this->_handle_kcp_timer(ec, std::move(this_ptr));
  192. }));
  193. }
  194. void _handle_kcp_timer(const error_code & ec, std::shared_ptr<derived_t> this_ptr)
  195. {
  196. if (ec == asio::error::operation_aborted) return;
  197. std::uint32_t clock = static_cast<std::uint32_t>(std::chrono::duration_cast<
  198. std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count());
  199. kcp::ikcp_update(this->kcp_, clock);
  200. if (this->kcp_->state == (kcp::IUINT32)-1)
  201. {
  202. if (derive.state_ == state_t::started)
  203. {
  204. derive._do_disconnect(asio::error::network_reset, std::move(this_ptr));
  205. }
  206. return;
  207. }
  208. if (derive.is_started())
  209. this->_post_kcp_timer(std::move(this_ptr));
  210. }
  211. template<typename C>
  212. void _kcp_recv(
  213. std::shared_ptr<derived_t>& this_ptr, std::shared_ptr<ecs_t<C>>& ecs, std::string_view data)
  214. {
  215. auto& buffer = derive.buffer();
  216. int len = kcp::ikcp_input(this->kcp_, (const char *)data.data(), (long)data.size());
  217. buffer.consume(buffer.size());
  218. if (len != 0)
  219. {
  220. set_last_error(asio::error::message_size);
  221. this->_call_illegal_data_callback(data);
  222. return;
  223. }
  224. for (;;)
  225. {
  226. len = kcp::ikcp_recv(this->kcp_, (char *)buffer.prepare(
  227. buffer.pre_size()).data(), (int)buffer.pre_size());
  228. if /**/ (len >= 0)
  229. {
  230. buffer.commit(len);
  231. derive._fire_recv(this_ptr, ecs, std::string_view(static_cast
  232. <std::string_view::const_pointer>(buffer.data().data()), len));
  233. buffer.consume(len);
  234. }
  235. else if (len == -3)
  236. {
  237. buffer.pre_size((std::min)(buffer.pre_size() * 2, buffer.max_size()));
  238. }
  239. else
  240. {
  241. break;
  242. }
  243. }
  244. kcp::ikcp_flush(this->kcp_);
  245. }
  246. template<typename C>
  247. inline void _kcp_handle_recv(
  248. error_code ec, std::string_view data,
  249. std::shared_ptr<derived_t>& this_ptr, std::shared_ptr<ecs_t<C>>& ecs)
  250. {
  251. // the kcp message header length is 24
  252. // the kcphdr length is 12
  253. if /**/ (data.size() > kcp::kcphdr::required_size())
  254. {
  255. this->_kcp_recv(this_ptr, ecs, data);
  256. }
  257. else if (data.size() == kcp::kcphdr::required_size())
  258. {
  259. // Check whether the packet is SYN handshake
  260. // It is possible that the client did not receive the synack package, then the client
  261. // will resend the syn package, so we just need to reply to the syncack package directly.
  262. // If the client is disconnect without send a "fin" or the server has't recvd the
  263. // "fin", and then the client connect again a later, at this time, the client
  264. // is in the session map already, and we need check whether the first message is fin
  265. if /**/ (kcp::is_kcphdr_syn(data))
  266. {
  267. ASIO2_ASSERT(this->kcp_);
  268. if (this->kcp_)
  269. {
  270. kcp::kcphdr syn = kcp::to_kcphdr(data);
  271. std::uint32_t conv = syn.th_ack;
  272. if (conv == 0)
  273. {
  274. conv = this->kcp_->conv;
  275. syn.th_ack = conv;
  276. }
  277. // If the client is forced disconnect after sent some messages, and the server
  278. // has recvd the messages already, we must recreated the kcp object, otherwise
  279. // the client and server will can't handle the next messages correctly.
  280. #if 0
  281. // set send_fin_ = false to make the _kcp_stop don't sent the fin frame.
  282. this->send_fin_ = false;
  283. this->_kcp_stop();
  284. this->_kcp_start(this_ptr, conv);
  285. #else
  286. this->_kcp_reset();
  287. #endif
  288. this->send_fin_ = true;
  289. // every time we recv kcp syn, we sent synack to the client
  290. this->_kcp_send_synack(syn, ec);
  291. if (ec)
  292. {
  293. derive._do_disconnect(ec, this_ptr);
  294. }
  295. }
  296. else
  297. {
  298. derive._do_disconnect(asio::error::operation_aborted, this_ptr);
  299. }
  300. }
  301. else if (kcp::is_kcphdr_synack(data, 0, true))
  302. {
  303. ASIO2_ASSERT(this->kcp_);
  304. }
  305. else if (kcp::is_kcphdr_ack(data, 0, true))
  306. {
  307. ASIO2_ASSERT(this->kcp_);
  308. }
  309. else if (kcp::is_kcphdr_fin(data))
  310. {
  311. ASIO2_ASSERT(this->kcp_);
  312. this->send_fin_ = false;
  313. derive._do_disconnect(asio::error::connection_reset, this_ptr);
  314. }
  315. else
  316. {
  317. this->_call_illegal_data_callback(data);
  318. }
  319. }
  320. else
  321. {
  322. this->_call_illegal_data_callback(data);
  323. }
  324. }
  325. template<typename C, typename DeferEvent>
  326. void _session_post_handshake(
  327. std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs, DeferEvent chain)
  328. {
  329. error_code ec;
  330. std::string& data = *(derive.first_data_);
  331. // step 3 : server recvd syn from client (the first data is the syn)
  332. kcp::kcphdr syn = kcp::to_kcphdr(data);
  333. std::uint32_t conv = syn.th_ack;
  334. if (conv == 0)
  335. {
  336. conv = derive.kcp_conv_;
  337. syn.th_ack = conv;
  338. }
  339. // step 4 : server send synack to client
  340. this->_kcp_send_synack(syn, ec);
  341. if (ec)
  342. {
  343. derive._do_disconnect(ec, std::move(this_ptr), std::move(chain));
  344. return;
  345. }
  346. this->_kcp_start(this_ptr, conv);
  347. this->_handle_handshake(ec, std::move(this_ptr), std::move(ecs), std::move(chain));
  348. }
  349. template<typename C, typename DeferEvent>
  350. void _client_post_handshake(
  351. std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs, DeferEvent chain)
  352. {
  353. error_code ec;
  354. // step 1 : client send syn to server
  355. std::uint32_t seq = static_cast<std::uint32_t>(
  356. std::chrono::duration_cast<std::chrono::milliseconds>(
  357. std::chrono::system_clock::now().time_since_epoch()).count());
  358. this->_kcp_send_syn(seq, ec);
  359. if (ec)
  360. {
  361. derive._do_disconnect(ec, std::move(this_ptr), defer_event(chain.move_guard()));
  362. return;
  363. }
  364. // use a loop timer to execute "client send syn to server" until the server
  365. // has recvd the syn packet and this client recvd reply.
  366. std::shared_ptr<detail::safe_timer> timer =
  367. mktimer(derive.io_->context(), std::chrono::milliseconds(500),
  368. [this, this_ptr, seq](error_code ec) mutable
  369. {
  370. if (ec == asio::error::operation_aborted)
  371. return false;
  372. this->_kcp_send_syn(seq, ec);
  373. if (ec)
  374. {
  375. set_last_error(ec);
  376. if (derive.state_ == state_t::started)
  377. {
  378. derive._do_disconnect(ec, std::move(this_ptr));
  379. }
  380. return false;
  381. }
  382. // return true : let the timer continue execute.
  383. // return false : kill the timer.
  384. return true;
  385. });
  386. #if defined(_DEBUG) || defined(DEBUG)
  387. ASIO2_ASSERT(derive.post_recv_counter_.load() == 0);
  388. derive.post_recv_counter_++;
  389. #endif
  390. // step 2 : client wait for recv synack util connect timeout or recvd some data
  391. derive.socket().async_receive(derive.buffer().prepare(derive.buffer().pre_size()),
  392. make_allocator(derive.rallocator(),
  393. [this, this_ptr = std::move(this_ptr), ecs = std::move(ecs), chain = std::move(chain),
  394. seq, timer = std::move(timer)]
  395. (const error_code & ec, std::size_t bytes_recvd) mutable
  396. {
  397. #if defined(_DEBUG) || defined(DEBUG)
  398. derive.post_recv_counter_--;
  399. #endif
  400. ASIO2_ASSERT(derive.io_->running_in_this_thread());
  401. timer->cancel();
  402. if (ec)
  403. {
  404. // if connect_timeout_timer_ is empty, it means that the connect timeout timer is
  405. // timeout and the callback has called already, so reset the error to timed_out.
  406. // note : when the async_resolve is failed, the socket is invalid to.
  407. this->_handle_handshake(
  408. derive.connect_timeout_timer_ ? ec : asio::error::timed_out,
  409. std::move(this_ptr), std::move(ecs), std::move(chain));
  410. return;
  411. }
  412. derive.buffer().commit(bytes_recvd);
  413. std::string_view data = std::string_view(static_cast<std::string_view::const_pointer>
  414. (derive.buffer().data().data()), bytes_recvd);
  415. // Check whether the data is the correct handshake information
  416. if (kcp::is_kcphdr_synack(data, seq))
  417. {
  418. kcp::kcphdr hdr = kcp::to_kcphdr(data);
  419. std::uint32_t conv = hdr.th_seq;
  420. if (derive.kcp_conv_ != 0)
  421. {
  422. ASIO2_ASSERT(derive.kcp_conv_ == conv);
  423. }
  424. this->_kcp_start(this_ptr, conv);
  425. this->_handle_handshake(ec, std::move(this_ptr), std::move(ecs), std::move(chain));
  426. }
  427. else
  428. {
  429. this->_handle_handshake(asio::error::address_family_not_supported,
  430. std::move(this_ptr), std::move(ecs), std::move(chain));
  431. }
  432. derive.buffer().consume(bytes_recvd);
  433. }));
  434. }
  435. template<typename C, typename DeferEvent>
  436. void _post_handshake(std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs, DeferEvent chain)
  437. {
  438. if constexpr (args_t::is_session)
  439. {
  440. this->_session_post_handshake(std::move(this_ptr), std::move(ecs), std::move(chain));
  441. }
  442. else
  443. {
  444. this->_client_post_handshake(std::move(this_ptr), std::move(ecs), std::move(chain));
  445. }
  446. }
  447. template<typename C, typename DeferEvent>
  448. void _handle_handshake(
  449. const error_code& ec, std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs,
  450. DeferEvent chain)
  451. {
  452. set_last_error(ec);
  453. if constexpr (args_t::is_session)
  454. {
  455. derive._fire_handshake(this_ptr);
  456. if (ec)
  457. {
  458. derive._do_disconnect(ec, std::move(this_ptr), std::move(chain));
  459. return;
  460. }
  461. derive._done_connect(ec, std::move(this_ptr), std::move(ecs), std::move(chain));
  462. }
  463. else
  464. {
  465. derive._fire_handshake(this_ptr);
  466. derive._done_connect(ec, std::move(this_ptr), std::move(ecs), std::move(chain));
  467. }
  468. }
  469. static int _kcp_output(const char *buf, int len, kcp::ikcpcb *kcp, void *user)
  470. {
  471. std::ignore = kcp;
  472. kcp_stream_cp * zhis = ((kcp_stream_cp*)user);
  473. derived_t & derive = zhis->derive;
  474. error_code ec;
  475. if constexpr (args_t::is_session)
  476. derive.stream().send_to(asio::buffer(buf, len), derive.remote_endpoint_, 0, ec);
  477. else
  478. derive.stream().send(asio::buffer(buf, len), 0, ec);
  479. return 0;
  480. }
  481. inline void _call_illegal_data_callback(std::string_view data)
  482. {
  483. if (this->illegal_response_handler_)
  484. {
  485. this->illegal_response_handler_(data);
  486. }
  487. }
  488. protected:
  489. derived_t & derive;
  490. kcp::ikcpcb * kcp_ = nullptr;
  491. bool send_fin_ = true;
  492. asio::steady_timer kcp_timer_;
  493. handler_memory<std::true_type, allocator_fixed_size_op<168>> tallocator_;
  494. std::function<void(std::string_view)> illegal_response_handler_;
  495. };
  496. }
  497. #endif // !__ASIO2_KCP_STREAM_CP_HPP__