mqtt_session.hpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. /*
  2. * Copyright (c) 2017-2023 zhllxt
  3. *
  4. * author : zhllxt
  5. * email : 37792738@qq.com
  6. *
  7. * Distributed under the Boost Software License, Version 1.0. (See accompanying
  8. * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  9. */
  10. #ifndef __ASIO2_MQTT_SESSION_HPP__
  11. #define __ASIO2_MQTT_SESSION_HPP__
  12. #if defined(_MSC_VER) && (_MSC_VER >= 1200)
  13. #pragma once
  14. #endif // defined(_MSC_VER) && (_MSC_VER >= 1200)
  15. #include <asio2/base/detail/push_options.hpp>
  16. #include <asio2/tcp/tcp_session.hpp>
  17. #include <asio2/mqtt/impl/mqtt_recv_connect_op.hpp>
  18. #include <asio2/mqtt/impl/mqtt_send_op.hpp>
  19. #include <asio2/mqtt/detail/mqtt_handler.hpp>
  20. #include <asio2/mqtt/detail/mqtt_invoker.hpp>
  21. #include <asio2/mqtt/detail/mqtt_topic_alias.hpp>
  22. #include <asio2/mqtt/detail/mqtt_session_state.hpp>
  23. #include <asio2/mqtt/detail/mqtt_broker_state.hpp>
  24. #include <asio2/mqtt/idmgr.hpp>
  25. #include <asio2/mqtt/options.hpp>
  26. namespace asio2::detail
  27. {
  28. struct template_args_mqtt_session : public template_args_tcp_session
  29. {
  30. static constexpr bool rdc_call_cp_enabled = false;
  31. template<class caller_t>
  32. struct subnode
  33. {
  34. explicit subnode(
  35. std::weak_ptr<caller_t> c,
  36. mqtt::subscription s,
  37. mqtt::v5::properties_set p = mqtt::v5::properties_set{}
  38. )
  39. : caller(std::move(c))
  40. , sub (std::move(s))
  41. , props (std::move(p))
  42. {
  43. }
  44. inline std::string_view share_name () { return sub.share_name (); }
  45. inline std::string_view topic_filter() { return sub.topic_filter(); }
  46. //
  47. std::weak_ptr<caller_t> caller;
  48. // subscription info
  49. mqtt::subscription sub;
  50. // subscription properties
  51. mqtt::v5::properties_set props;
  52. };
  53. };
  // NOTE(review): the macro definitions are not visible here. They presumably forward-declare
  // the asio2 base / tcp base / tcp server / tcp session classes that the matching
  // ASIO2_CLASS_FRIEND_DECLARE_* lines inside mqtt_session_impl_t befriend.
  ASIO2_CLASS_FORWARD_DECLARE_BASE;
  ASIO2_CLASS_FORWARD_DECLARE_TCP_BASE;
  ASIO2_CLASS_FORWARD_DECLARE_TCP_SERVER;
  ASIO2_CLASS_FORWARD_DECLARE_TCP_SESSION;
  58. template<class derived_t, class args_t = template_args_mqtt_session>
  59. class mqtt_session_impl_t
  60. : public tcp_session_impl_t<derived_t, args_t>
  61. , public mqtt_options
  62. , public mqtt_handler_t <derived_t, args_t>
  63. , public mqtt_topic_alias_t<derived_t, args_t>
  64. , public mqtt_send_op <derived_t, args_t>
  65. , public mqtt::session_state
  66. {
  67. ASIO2_CLASS_FRIEND_DECLARE_BASE;
  68. ASIO2_CLASS_FRIEND_DECLARE_TCP_BASE;
  69. ASIO2_CLASS_FRIEND_DECLARE_TCP_SERVER;
  70. ASIO2_CLASS_FRIEND_DECLARE_TCP_SESSION;
  71. template <class> friend class mqtt::shared_target;
  72. public:
  73. using super = tcp_session_impl_t <derived_t, args_t>;
  74. using self = mqtt_session_impl_t<derived_t, args_t>;
  75. using args_type = args_t;
  76. using key_type = std::size_t;
  77. using subnode_type = typename args_type::template subnode<derived_t>;
  78. using super::send;
  79. using super::async_send;
  80. public:
  81. /**
  82. * @brief constructor
  83. */
  84. explicit mqtt_session_impl_t(
  85. mqtt::broker_state<derived_t, args_t>& broker_state,
  86. session_mgr_t <derived_t>& sessions,
  87. listener_t & listener,
  88. std::shared_ptr<io_t> rwio,
  89. std::size_t init_buf_size,
  90. std::size_t max_buf_size
  91. )
  92. : super(sessions, listener, std::move(rwio), init_buf_size, max_buf_size)
  93. , mqtt_options ()
  94. , mqtt_handler_t <derived_t, args_t>()
  95. , mqtt_topic_alias_t<derived_t, args_t>()
  96. , mqtt_send_op <derived_t, args_t>()
  97. , broker_state_(broker_state)
  98. {
  99. this->set_silence_timeout(std::chrono::milliseconds(mqtt_silence_timeout));
  100. }
  101. /**
  102. * @brief destructor
  103. */
  104. ~mqtt_session_impl_t()
  105. {
  106. }
  107. public:
  108. /**
  109. * @brief get this object hash key,used for session map
  110. */
  111. inline key_type hash_key() const
  112. {
  113. return reinterpret_cast<key_type>(this);
  114. }
  115. /**
  116. * @brief get the mqtt version number
  117. */
  118. inline mqtt::version version() const
  119. {
  120. return this->get_version();
  121. }
  122. /**
  123. * @brief get the mqtt version number
  124. */
  125. inline mqtt::version get_version() const
  126. {
  127. return this->version_;
  128. }
  129. /**
  130. * @brief get the mqtt client id
  131. */
  132. inline std::string_view client_id() const
  133. {
  134. return this->get_client_id();
  135. }
  136. /**
  137. * @brief get the mqtt client id
  138. */
  139. inline std::string_view get_client_id() const
  140. {
  141. std::string_view id{};
  142. if (!this->connect_message_.empty())
  143. {
  144. if /**/ (std::holds_alternative<mqtt::v3::connect>(connect_message_.base()))
  145. {
  146. id = this->connect_message_.template get_if<mqtt::v3::connect>()->client_id();
  147. }
  148. else if (std::holds_alternative<mqtt::v4::connect>(connect_message_.base()))
  149. {
  150. id = this->connect_message_.template get_if<mqtt::v4::connect>()->client_id();
  151. }
  152. else if (std::holds_alternative<mqtt::v5::connect>(connect_message_.base()))
  153. {
  154. id = this->connect_message_.template get_if<mqtt::v5::connect>()->client_id();
  155. }
  156. }
  157. return id;
  158. }
  159. inline void remove_subscribed_topic(std::string_view topic_filter)
  160. {
  161. this->subs_map().erase(topic_filter, this->client_id());
  162. }
  163. inline void remove_all_subscribed_topic()
  164. {
  165. this->subs_map().erase(this->client_id());
  166. }
  167. protected:
  168. template<typename E = defer_event<void, derived_t>>
  169. inline void _do_disconnect(
  170. const error_code& ec, std::shared_ptr<derived_t> this_ptr, E chain = defer_event<void, derived_t>{})
  171. {
  172. state_t expected = state_t::started;
  173. if (this->derived().state_.compare_exchange_strong(expected, state_t::started))
  174. {
  175. mqtt::version ver = this->derived().version();
  176. if /**/ (ver == mqtt::version::v3)
  177. {
  178. mqtt::v3::disconnect disconnect;
  179. this->derived().internal_async_send(std::move(this_ptr), std::move(disconnect),
  180. [this, ec, e = chain.move_event()]
  181. (std::shared_ptr<derived_t> this_ptr, const error_code&,
  182. std::size_t, event_queue_guard<derived_t> g) mutable
  183. {
  184. defer_event chain(std::move(e), std::move(g));
  185. super::_do_disconnect(ec, std::move(this_ptr), std::move(chain));
  186. }, chain.move_guard());
  187. return;
  188. }
  189. else if (ver == mqtt::version::v4)
  190. {
  191. mqtt::v4::disconnect disconnect;
  192. this->derived().internal_async_send(std::move(this_ptr), std::move(disconnect),
  193. [this, ec, e = chain.move_event()]
  194. (std::shared_ptr<derived_t> this_ptr, const error_code&,
  195. std::size_t, event_queue_guard<derived_t> g) mutable
  196. {
  197. defer_event chain(std::move(e), std::move(g));
  198. super::_do_disconnect(ec, std::move(this_ptr), std::move(chain));
  199. }, chain.move_guard());
  200. return;
  201. }
  202. else if (ver == mqtt::version::v5)
  203. {
  204. // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901208
  205. mqtt::v5::disconnect disconnect;
  206. if (ec.value() != 4)
  207. disconnect.reason_code(static_cast<std::uint8_t>(ec.value()));
  208. this->derived().internal_async_send(std::move(this_ptr), std::move(disconnect),
  209. [this, ec, e = chain.move_event()]
  210. (std::shared_ptr<derived_t> this_ptr, const error_code&,
  211. std::size_t, event_queue_guard<derived_t> g) mutable
  212. {
  213. defer_event chain(std::move(e), std::move(g));
  214. super::_do_disconnect(ec, std::move(this_ptr), std::move(chain));
  215. }, chain.move_guard());
  216. return;
  217. }
  218. else
  219. {
  220. ASIO2_ASSERT(false);
  221. }
  222. }
  223. super::_do_disconnect(ec, std::move(this_ptr), std::move(chain));
  224. }
  225. template<typename C, typename DeferEvent>
  226. inline void _handle_connect(
  227. const error_code& ec,
  228. std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs, DeferEvent chain)
  229. {
  230. detail::ignore_unused(ec);
  231. ASIO2_ASSERT(!ec);
  232. ASIO2_ASSERT(this->derived().sessions_.io_->running_in_this_thread());
  233. asio::dispatch(this->derived().io_->context(), make_allocator(this->derived().wallocator(),
  234. [this, this_ptr = std::move(this_ptr), ecs = std::move(ecs), chain = std::move(chain)]
  235. () mutable
  236. {
  237. derived_t& derive = this->derived();
  238. ASIO2_ASSERT(derive.io_->running_in_this_thread());
  239. // wait for the connect message which send by the client.
  240. mqtt_recv_connect_op
  241. {
  242. derive.io_->context(),
  243. derive.stream(),
  244. [this, this_ptr = std::move(this_ptr), ecs = std::move(ecs), chain = std::move(chain)]
  245. (error_code ec, std::unique_ptr<asio::streambuf> stream) mutable
  246. {
  247. this->derived()._handle_mqtt_connect_message(ec,
  248. std::move(this_ptr),std::move(ecs), std::move(stream), std::move(chain));
  249. }
  250. };
  251. }));
  252. }
  253. template<typename C, typename DeferEvent>
  254. inline void _handle_mqtt_connect_message(
  255. error_code ec,
  256. std::shared_ptr<derived_t> this_ptr, std::shared_ptr<ecs_t<C>> ecs,
  257. std::unique_ptr<asio::streambuf> stream, DeferEvent chain)
  258. {
  259. do
  260. {
  261. if (ec)
  262. break;
  263. std::string_view data{ reinterpret_cast<std::string_view::const_pointer>(
  264. static_cast<const char*>(stream->data().data())), stream->size() };
  265. mqtt::control_packet_type type = mqtt::message_type_from_data(data);
  266. // If the server does not receive a CONNECT message within a reasonable amount of time
  267. // after the TCP/IP connection is established, the server should close the connection.
  268. if (type != mqtt::control_packet_type::connect)
  269. {
  270. ec = mqtt::make_error_code(mqtt::error::malformed_packet);
  271. break;
  272. }
  273. // parse the connect message to get the mqtt version
  274. mqtt::version ver = mqtt::version_from_connect_data(data);
  275. if (ver != mqtt::version::v3 && ver != mqtt::version::v4 && ver != mqtt::version::v5)
  276. {
  277. ec = mqtt::make_error_code(mqtt::error::unsupported_protocol_version);
  278. break;
  279. }
  280. this->version_ = ver;
  281. // If the client sends an invalid CONNECT message, the server should close the connection.
  282. // This includes CONNECT messages that provide invalid Protocol Name or Protocol Version Numbers.
  283. // If the server can parse enough of the CONNECT message to determine that an invalid protocol
  284. // has been requested, it may try to send a CONNACK containing the "Connection Refused:
  285. // unacceptable protocol version" code before dropping the connection.
  286. this->invoker()._call_mqtt_handler(type, ec, this_ptr, static_cast<derived_t*>(this), data);
  287. } while (false);
  288. this->derived().sessions_.dispatch(
  289. [this, ec, this_ptr = std::move(this_ptr), ecs = std::move(ecs), chain = std::move(chain)]
  290. () mutable
  291. {
  292. super::_handle_connect(ec, std::move(this_ptr), std::move(ecs), std::move(chain));
  293. });
  294. }
  295. template<typename DeferEvent>
  296. inline void _handle_disconnect(const error_code& ec, std::shared_ptr<derived_t> this_ptr, DeferEvent chain)
  297. {
  298. std::string_view clientid = this->client_id();
  299. this->subs_map().erase(clientid);
  300. this->mqtt_sessions().erase_mqtt_session(clientid, static_cast<derived_t*>(this));
  301. super::_handle_disconnect(ec, std::move(this_ptr), std::move(chain));
  302. }
  303. protected:
  304. template<class Data, class Callback>
  305. inline bool _do_send(Data& data, Callback&& callback)
  306. {
  307. return this->derived()._mqtt_send(data, std::forward<Callback>(callback));
  308. }
  309. protected:
  310. template<typename C>
  311. inline void _fire_recv(
  312. std::shared_ptr<derived_t>& this_ptr, std::shared_ptr<ecs_t<C>>& ecs, std::string_view data)
  313. {
  314. data = detail::call_data_filter_before_recv(this->derived(), data);
  315. this->listener_.notify(event_type::recv, this_ptr, data);
  316. this->derived()._rdc_handle_recv(this_ptr, ecs, data);
  317. mqtt::control_packet_type type = mqtt::message_type_from_data(data);
  318. if (type > mqtt::control_packet_type::auth)
  319. {
  320. ASIO2_ASSERT(false);
  321. this->derived()._do_disconnect(mqtt::make_error_code(mqtt::error::malformed_packet), this_ptr);
  322. return;
  323. }
  324. error_code ec;
  325. this->invoker()._call_mqtt_handler(type, ec, this_ptr, static_cast<derived_t*>(this), data);
  326. if (ec)
  327. {
  328. this->derived()._do_disconnect(ec, this_ptr);
  329. }
  330. }
  331. inline auto& invoker () noexcept { return this->broker_state_.invoker_ ; }
  332. inline auto& mqtt_sessions () noexcept { return this->broker_state_.mqtt_sessions_ ; }
  333. inline auto& subs_map () noexcept { return this->broker_state_.subs_map_ ; }
  334. inline auto& shared_targets () noexcept { return this->broker_state_.shared_targets_ ; }
  335. inline auto& retained_messages() noexcept { return this->broker_state_.retained_messages_; }
  336. inline auto& security () noexcept { return this->broker_state_.security_ ; }
  337. inline auto const& invoker () const noexcept { return this->broker_state_.invoker_ ; }
  338. inline auto const& mqtt_sessions () const noexcept { return this->broker_state_.mqtt_sessions_ ; }
  339. inline auto const& subs_map () const noexcept { return this->broker_state_.subs_map_ ; }
  340. inline auto const& shared_targets () const noexcept { return this->broker_state_.shared_targets_ ; }
  341. inline auto const& retained_messages() const noexcept { return this->broker_state_.retained_messages_; }
  342. inline auto const& security () const noexcept { return this->broker_state_.security_ ; }
  343. inline void set_preauthed_username(std::optional<std::string> username)
  344. {
  345. preauthed_username_ = std::move(username);
  346. }
  347. inline std::optional<std::string> get_preauthed_username() const
  348. {
  349. return preauthed_username_;
  350. }
  351. protected:
  352. ///
  353. mqtt::broker_state<derived_t, args_t> & broker_state_;
  354. /// packet id manager
  355. mqtt::idmgr<std::atomic<mqtt::two_byte_integer::value_type>> idmgr_;
  356. /// user to find session for shared targets
  357. std::chrono::nanoseconds::rep shared_target_key_;
  358. mqtt::message connect_message_{};
  359. std::optional<std::string> preauthed_username_;
  360. mqtt::version version_ = static_cast<mqtt::version>(0);
  361. };
  362. }
  363. namespace asio2
  364. {
  365. using mqtt_session_args = detail::template_args_mqtt_session;
  366. template<class derived_t, class args_t>
  367. using mqtt_session_impl_t = detail::mqtt_session_impl_t<derived_t, args_t>;
  368. template<class derived_t>
  369. class mqtt_session_t : public detail::mqtt_session_impl_t<derived_t, detail::template_args_mqtt_session>
  370. {
  371. public:
  372. using detail::mqtt_session_impl_t<derived_t, detail::template_args_mqtt_session>::mqtt_session_impl_t;
  373. };
  /**
   * @brief The ready-to-use mqtt session type. It passes itself as derived_t to
   *        mqtt_session_t (CRTP) and inherits all of its constructors.
   */
  class mqtt_session : public mqtt_session_t<mqtt_session>
  {
  public:
    using mqtt_session_t<mqtt_session>::mqtt_session_t;
  };
  379. }
  380. #include <asio2/base/detail/pop_options.hpp>
  381. #endif // !__ASIO2_MQTT_SESSION_HPP__