/* * Copyright (c) 2017-2023 zhllxt * * author : zhllxt * email : 37792738@qq.com * * Distributed under the Boost Software License, Version 1.0. (See accompanying * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) */ #ifndef __ASIO2_MQTT_MESSAGE_ROUTER_HPP__ #define __ASIO2_MQTT_MESSAGE_ROUTER_HPP__ #if defined(_MSC_VER) && (_MSC_VER >= 1200) #pragma once #endif // defined(_MSC_VER) && (_MSC_VER >= 1200) #include <optional> #include <asio2/base/iopool.hpp> #include <asio2/base/define.hpp> #include <asio2/base/detail/function_traits.hpp> #include <asio2/base/detail/util.hpp> #include <asio2/base/detail/shared_mutex.hpp> #include <asio2/mqtt/detail/mqtt_topic_util.hpp> #include <asio2/mqtt/message.hpp> namespace asio2::detail { ASIO2_CLASS_FORWARD_DECLARE_BASE; ASIO2_CLASS_FORWARD_DECLARE_TCP_BASE; ASIO2_CLASS_FORWARD_DECLARE_TCP_SERVER; ASIO2_CLASS_FORWARD_DECLARE_TCP_SESSION; ASIO2_CLASS_FORWARD_DECLARE_TCP_CLIENT; /** * used for: * * bool ret = client.subscribe("/usr/topic1", 0, [](mqtt::message& msg){}); * util recvd the suback message, then the ret is true. * * bool ret = client.publish("/usr/topic1", "...payload...", 0); * util recvd the puback message, then the ret is true. * * and so on... */ template<class derived_t, class args_t> class mqtt_message_router_t { friend derived_t; ASIO2_CLASS_FRIEND_DECLARE_BASE; ASIO2_CLASS_FRIEND_DECLARE_TCP_BASE; ASIO2_CLASS_FRIEND_DECLARE_TCP_SERVER; ASIO2_CLASS_FRIEND_DECLARE_TCP_SESSION; ASIO2_CLASS_FRIEND_DECLARE_TCP_CLIENT; public: using self = mqtt_message_router_t<derived_t, args_t>; using args_type = args_t; using subnode_type = typename args_type::template subnode<derived_t>; using key_type = std::pair<mqtt::control_packet_type, mqtt::two_byte_integer::value_type>; using val_type = detail::function<void(mqtt::message&)>; struct hasher { inline std::size_t operator()(key_type const& pair) const noexcept { std::size_t v = detail::fnv1a_hash<std::size_t>( (const unsigned char*)(std::addressof(pair.first)), sizeof(pair.first)); return detail::fnv1a_hash<std::size_t>(v, (const unsigned char*)(std::addressof(pair.second)), sizeof(pair.second)); } }; /** * @brief constructor */ mqtt_message_router_t() = default; /** * @brief destructor */ ~mqtt_message_router_t() = default; protected: template<class FunctionT> inline bool _add_router(mqtt::message& msg, FunctionT&& callback) { derived_t& derive = static_cast<derived_t&>(*this); bool r = false; std::visit([&derive, &callback, &r](auto& m) mutable { r = derive._add_router(m, std::forward<FunctionT>(callback)); }, msg.base()); return r; } template<class Message, class FunctionT> typename std::enable_if_t<mqtt::is_rawmsg<Message>(), bool> inline _add_router(Message& msg, FunctionT&& callback) { derived_t& derive = static_cast<derived_t&>(*this); using message_type = typename detail::remove_cvref_t<Message>; if constexpr (!mqtt::has_packet_id<message_type>::value) { static_assert(detail::always_false_v<Message> && "This mqtt message don't has Packet Identifier"); return false; } else { ASIO2_ASSERT( msg.packet_type() >= mqtt::control_packet_type::connect && msg.packet_type() <= mqtt::control_packet_type::auth); if (!( msg.packet_type() >= mqtt::control_packet_type::connect && msg.packet_type() <= mqtt::control_packet_type::auth)) { return false; } key_type key = { msg.packet_type(), msg.packet_id() }; return derive._add_router(std::move(key), std::forward<FunctionT>(callback)); } } template<class FunctionT> inline bool _add_router(key_type key, FunctionT&& callback) { derived_t& derive = static_cast<derived_t&>(*this); derive.dispatch([&derive, key, cb = std::forward<FunctionT>(callback)]() mutable { derive._do_add_router(std::move(key), std::move(cb)); }); return true; } template<class FunctionT> inline bool _do_add_router(key_type key, FunctionT&& callback) { using fun_traits_type = function_traits<FunctionT>; using arg0_type = typename std::remove_cv_t<std::remove_reference_t< typename fun_traits_type::template args<0>::type>>; asio2::unique_locker g(this->message_router_mutex_); if constexpr (std::is_same_v<arg0_type, mqtt::message>) { auto[_1, inserted] = this->message_router_.insert_or_assign(std::move(key), std::forward<FunctionT>(callback)); ASIO2_ASSERT(inserted); asio2::ignore_unused(_1, inserted); } else { auto[_1, inserted] = this->message_router_.insert_or_assign(std::move(key), [cb = std::forward<FunctionT>(callback)](mqtt::message& msg) mutable { arg0_type* p = std::get_if<arg0_type>(std::addressof(msg.base())); if (p) { cb(*p); } }); ASIO2_ASSERT(inserted); asio2::ignore_unused(_1, inserted); } return true; } inline void _del_router(mqtt::message& msg) { derived_t& derive = static_cast<derived_t&>(*this); std::visit([&derive](auto& m) mutable { derive._del_router(m); }, msg.base()); } template<class Message> typename std::enable_if_t<mqtt::is_rawmsg<Message>(), void> inline _del_router(Message& msg) { derived_t& derive = static_cast<derived_t&>(*this); using message_type = typename detail::remove_cvref_t<Message>; if constexpr (!mqtt::has_packet_id<message_type>::value) { static_assert(detail::always_false_v<Message> && "This mqtt message don't has Packet Identifier"); return; } else { ASIO2_ASSERT( msg.packet_type() >= mqtt::control_packet_type::connect && msg.packet_type() <= mqtt::control_packet_type::auth); if (!( msg.packet_type() >= mqtt::control_packet_type::connect && msg.packet_type() <= mqtt::control_packet_type::auth)) { return; } key_type key = { msg.packet_type(), msg.packet_id() }; derive._del_router(std::move(key)); } } inline void _del_router(key_type key) { derived_t& derive = static_cast<derived_t&>(*this); derive.dispatch([this, key = std::move(key)]() mutable { asio2::unique_locker g(this->message_router_mutex_); this->message_router_.erase(key); }); } inline bool _match_router(mqtt::message& msg) { derived_t& derive = static_cast<derived_t&>(*this); std::optional<key_type> key = derive._generate_key(msg); if (!key.has_value()) return false; return derive._call_router(key.value(), msg); } inline bool _call_router(key_type key, mqtt::message& msg) { derived_t& derive = static_cast<derived_t&>(*this); derive.dispatch([this, msg, key = std::move(key)]() mutable { asio2::unique_locker g(this->message_router_mutex_); auto it = this->message_router_.find(key); if (it == this->message_router_.end()) return; (it->second)(msg); this->message_router_.erase(it); }); return true; } inline std::optional<key_type> _generate_key(mqtt::message& msg) { derived_t& derive = static_cast<derived_t&>(*this); std::optional<key_type> r; std::visit([&derive, &r](auto& m) mutable { r = derive._generate_key(m); }, msg.base()); return r; } template<class Message> typename std::enable_if_t<mqtt::is_rawmsg<Message>(), std::optional<key_type>> inline _generate_key(Message& msg) { using message_type = typename detail::remove_cvref_t<Message>; if constexpr (!mqtt::has_packet_id<message_type>::value) { return std::nullopt; } else { ASIO2_ASSERT( msg.packet_type() >= mqtt::control_packet_type::connect && msg.packet_type() <= mqtt::control_packet_type::auth); if (!( msg.packet_type() >= mqtt::control_packet_type::connect && msg.packet_type() <= mqtt::control_packet_type::auth)) { return std::nullopt; } std::optional<key_type> key; if /**/ constexpr (mqtt::is_puback_message<message_type>()) { key = { mqtt::control_packet_type::publish, msg.packet_id() }; } else if constexpr (mqtt::is_pubcomp_message<message_type>()) { key = { mqtt::control_packet_type::publish, msg.packet_id() }; } else if constexpr (mqtt::is_suback_message<message_type>()) { key = { mqtt::control_packet_type::subscribe, msg.packet_id() }; } else if constexpr (mqtt::is_unsuback_message<message_type>()) { key = { mqtt::control_packet_type::unsubscribe, msg.packet_id() }; } else { return std::nullopt; } return key; } } protected: /// use rwlock to make thread safe mutable asio2::shared_mutexer message_router_mutex_; /// router map, key - pair<mqtt::control_packet_type, packet id> std::unordered_map<key_type, val_type, hasher> message_router_ ASIO2_GUARDED_BY(message_router_mutex_); }; } #endif // !__ASIO2_MQTT_MESSAGE_ROUTER_HPP__