mqtt_security.hpp 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361
  1. /*
  2. * Copyright (c) 2017-2023 zhllxt
  3. *
  4. * author : zhllxt
  5. * email : 37792738@qq.com
  6. *
  7. * Distributed under the Boost Software License, Version 1.0. (See accompanying
  8. * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  9. *
  10. * refrenced from : mqtt_cpp/include/mqtt/broker/security.hpp
  11. */
  12. #ifndef __ASIO2_MQTT_SECURITY_HPP__
  13. #define __ASIO2_MQTT_SECURITY_HPP__
  14. #if defined(_MSC_VER) && (_MSC_VER >= 1200)
  15. #pragma once
  16. #endif // defined(_MSC_VER) && (_MSC_VER >= 1200)
  17. #include <cstdint>
  18. #include <string>
  19. #include <string_view>
  20. #include <type_traits>
  21. #include <unordered_map>
  22. #include <algorithm>
  23. #include <variant>
  24. #include <map>
  25. #include <set>
  26. #include <optional>
  27. #include <iterator> // for std::iterator_traits
  28. #include <stdexcept>
  29. #include <asio2/base/iopool.hpp>
  30. #include <asio2/util/string.hpp>
  31. #include <asio2/base/detail/shared_mutex.hpp>
  32. #if !defined(INCLUDE_NLOHMANN_JSON_HPP_) && !defined(NLOHMANN_JSON_HPP)
  33. #include <asio2/external/json.hpp>
  34. #endif
  35. #include <asio2/mqtt/message.hpp>
  36. #include <asio2/mqtt/detail/mqtt_topic_util.hpp>
  37. #include <asio2/mqtt/detail/mqtt_subscription_map.hpp>
  38. #if defined(ASIO2_ENABLE_SSL) || defined(ASIO2_USE_SSL)
  39. #include <openssl/evp.h>
  40. #endif
  41. namespace asio2::mqtt
  42. {
  43. namespace detail { namespace algorithm {
  44. /*!
  45. \struct hex_decode_error
  46. \brief Base exception class for all hex decoding errors
  47. */ /*!
  48. \struct non_hex_input
  49. \brief Thrown when a non-hex value (0-9, A-F) encountered when decoding.
  50. Contains the offending character
  51. */ /*!
  52. \struct not_enough_input
  53. \brief Thrown when the input sequence unexpectedly ends
  54. */
  55. struct hex_decode_error : virtual std::exception {};
  56. struct not_enough_input : virtual hex_decode_error {};
  57. struct non_hex_input : virtual hex_decode_error {};
  58. namespace detail {
  59. /// \cond DOXYGEN_HIDE
  60. template <typename T, typename OutputIterator>
  61. OutputIterator encode_one ( T val, OutputIterator out, const char * hexDigits ) {
  62. const std::size_t num_hex_digits = 2 * sizeof ( T );
  63. char res [ num_hex_digits ];
  64. char *p = res + num_hex_digits;
  65. for ( std::size_t i = 0; i < num_hex_digits; ++i, val >>= 4 )
  66. *--p = hexDigits [ val & 0x0F ];
  67. return std::copy ( res, res + num_hex_digits, out );
  68. }
  69. template <typename T>
  70. unsigned char hex_char_to_int ( T val ) {
  71. char c = static_cast<char> ( val );
  72. unsigned retval = 0;
  73. if ( c >= '0' && c <= '9' ) retval = c - '0';
  74. else if ( c >= 'A' && c <= 'F' ) retval = c - 'A' + 10;
  75. else if ( c >= 'a' && c <= 'f' ) retval = c - 'a' + 10;
  76. else throw non_hex_input();
  77. return static_cast<char>(retval);
  78. }
  79. // My own iterator_traits class.
  80. // It is here so that I can "reach inside" some kinds of output iterators
  81. // and get the type to write.
  82. template <typename Iterator>
  83. struct hex_iterator_traits {
  84. typedef typename std::iterator_traits<Iterator>::value_type value_type;
  85. };
  86. template<typename Container>
  87. struct hex_iterator_traits< std::back_insert_iterator<Container> > {
  88. typedef typename Container::value_type value_type;
  89. };
  90. template<typename Container>
  91. struct hex_iterator_traits< std::front_insert_iterator<Container> > {
  92. typedef typename Container::value_type value_type;
  93. };
  94. template<typename Container>
  95. struct hex_iterator_traits< std::insert_iterator<Container> > {
  96. typedef typename Container::value_type value_type;
  97. };
  98. // ostream_iterators have three template parameters.
  99. // The first one is the output type, the second one is the character type of
  100. // the underlying stream, the third is the character traits.
  101. // We only care about the first one.
  102. template<typename T, typename charType, typename traits>
  103. struct hex_iterator_traits< std::ostream_iterator<T, charType, traits> > {
  104. typedef T value_type;
  105. };
  106. template <typename Iterator>
  107. bool iter_end ( Iterator current, Iterator last ) { return current == last; }
  108. template <typename T>
  109. bool ptr_end ( const T* ptr, const T* /*end*/ ) { return *ptr == '\0'; }
  110. // What can we assume here about the inputs?
  111. // is std::iterator_traits<InputIterator>::value_type always 'char' ?
  112. // Could it be wchar_t, say? Does it matter?
  113. // We are assuming ASCII for the values - but what about the storage?
  114. template <typename InputIterator, typename OutputIterator, typename EndPred>
  115. typename std::enable_if<std::is_integral_v<typename hex_iterator_traits<OutputIterator>::value_type>, OutputIterator>::type
  116. decode_one ( InputIterator &first, InputIterator last, OutputIterator out, EndPred pred ) {
  117. typedef typename hex_iterator_traits<OutputIterator>::value_type T;
  118. T res (0);
  119. // Need to make sure that we get can read that many chars here.
  120. for ( std::size_t i = 0; i < 2 * sizeof ( T ); ++i, ++first ) {
  121. if ( pred ( first, last ))
  122. throw not_enough_input();
  123. res = ( 16 * res ) + hex_char_to_int (*first);
  124. }
  125. *out = res;
  126. return ++out;
  127. }
  128. /// \endcond
  129. }
  130. /// \fn hex ( InputIterator first, InputIterator last, OutputIterator out )
  131. /// \brief Converts a sequence of integral types into a hexadecimal sequence of characters.
  132. ///
  133. /// \param first The start of the input sequence
  134. /// \param last One past the end of the input sequence
  135. /// \param out An output iterator to the results into
  136. /// \return The updated output iterator
  137. /// \note Based on the MySQL function of the same name
  138. template <typename InputIterator, typename OutputIterator>
  139. typename std::enable_if<std::is_integral_v<typename detail::hex_iterator_traits<InputIterator>::value_type>, OutputIterator>::type
  140. hex ( InputIterator first, InputIterator last, OutputIterator out ) {
  141. for ( ; first != last; ++first )
  142. out = detail::encode_one ( *first, out, "0123456789ABCDEF" );
  143. return out;
  144. }
  145. /// \fn hex_lower ( InputIterator first, InputIterator last, OutputIterator out )
  146. /// \brief Converts a sequence of integral types into a lower case hexadecimal sequence of characters.
  147. ///
  148. /// \param first The start of the input sequence
  149. /// \param last One past the end of the input sequence
  150. /// \param out An output iterator to the results into
  151. /// \return The updated output iterator
  152. /// \note Based on the MySQL function of the same name
  153. template <typename InputIterator, typename OutputIterator>
  154. typename std::enable_if<std::is_integral_v<typename detail::hex_iterator_traits<InputIterator>::value_type>, OutputIterator>::type
  155. hex_lower ( InputIterator first, InputIterator last, OutputIterator out ) {
  156. for ( ; first != last; ++first )
  157. out = detail::encode_one ( *first, out, "0123456789abcdef" );
  158. return out;
  159. }
  160. /// \fn hex ( const T *ptr, OutputIterator out )
  161. /// \brief Converts a sequence of integral types into a hexadecimal sequence of characters.
  162. ///
  163. /// \param ptr A pointer to a 0-terminated sequence of data.
  164. /// \param out An output iterator to the results into
  165. /// \return The updated output iterator
  166. /// \note Based on the MySQL function of the same name
  167. template <typename T, typename OutputIterator>
  168. typename std::enable_if<std::is_integral_v<T>, OutputIterator>::type
  169. hex ( const T *ptr, OutputIterator out ) {
  170. while ( *ptr )
  171. out = detail::encode_one ( *ptr++, out, "0123456789ABCDEF" );
  172. return out;
  173. }
  174. /// \fn hex_lower ( const T *ptr, OutputIterator out )
  175. /// \brief Converts a sequence of integral types into a lower case hexadecimal sequence of characters.
  176. ///
  177. /// \param ptr A pointer to a 0-terminated sequence of data.
  178. /// \param out An output iterator to the results into
  179. /// \return The updated output iterator
  180. /// \note Based on the MySQL function of the same name
  181. template <typename T, typename OutputIterator>
  182. typename std::enable_if<std::is_integral_v<T>, OutputIterator>::type
  183. hex_lower ( const T *ptr, OutputIterator out ) {
  184. while ( *ptr )
  185. out = detail::encode_one ( *ptr++, out, "0123456789abcdef" );
  186. return out;
  187. }
  188. /// \fn hex ( const Range &r, OutputIterator out )
  189. /// \brief Converts a sequence of integral types into a hexadecimal sequence of characters.
  190. ///
  191. /// \param r The input range
  192. /// \param out An output iterator to the results into
  193. /// \return The updated output iterator
  194. /// \note Based on the MySQL function of the same name
  195. template <typename Range, typename OutputIterator>
  196. typename std::enable_if<std::is_integral_v<typename detail::hex_iterator_traits<typename Range::iterator>::value_type>, OutputIterator>::type
  197. hex ( const Range &r, OutputIterator out ) {
  198. return hex (std::begin(r), std::end(r), out);
  199. }
  200. /// \fn hex_lower ( const Range &r, OutputIterator out )
  201. /// \brief Converts a sequence of integral types into a lower case hexadecimal sequence of characters.
  202. ///
  203. /// \param r The input range
  204. /// \param out An output iterator to the results into
  205. /// \return The updated output iterator
  206. /// \note Based on the MySQL function of the same name
  207. template <typename Range, typename OutputIterator>
  208. typename std::enable_if<std::is_integral_v<typename detail::hex_iterator_traits<typename Range::iterator>::value_type>, OutputIterator>::type
  209. hex_lower ( const Range &r, OutputIterator out ) {
  210. return hex_lower (std::begin(r), std::end(r), out);
  211. }
  212. /// \fn unhex ( InputIterator first, InputIterator last, OutputIterator out )
  213. /// \brief Converts a sequence of hexadecimal characters into a sequence of integers.
  214. ///
  215. /// \param first The start of the input sequence
  216. /// \param last One past the end of the input sequence
  217. /// \param out An output iterator to the results into
  218. /// \return The updated output iterator
  219. /// \note Based on the MySQL function of the same name
  220. template <typename InputIterator, typename OutputIterator>
  221. OutputIterator unhex ( InputIterator first, InputIterator last, OutputIterator out ) {
  222. while ( first != last )
  223. out = detail::decode_one ( first, last, out, detail::iter_end<InputIterator> );
  224. return out;
  225. }
  226. /// \fn unhex ( const T *ptr, OutputIterator out )
  227. /// \brief Converts a sequence of hexadecimal characters into a sequence of integers.
  228. ///
  229. /// \param ptr A pointer to a null-terminated input sequence.
  230. /// \param out An output iterator to the results into
  231. /// \return The updated output iterator
  232. /// \note Based on the MySQL function of the same name
  233. template <typename T, typename OutputIterator>
  234. OutputIterator unhex ( const T *ptr, OutputIterator out ) {
  235. // If we run into the terminator while decoding, we will throw a
  236. // malformed input exception. It would be nicer to throw a 'Not enough input'
  237. // exception - but how much extra work would that require?
  238. while ( *ptr )
  239. out = detail::decode_one ( ptr, (const T *) NULL, out, detail::ptr_end<T> );
  240. return out;
  241. }
  242. /// \fn OutputIterator unhex ( const Range &r, OutputIterator out )
  243. /// \brief Converts a sequence of hexadecimal characters into a sequence of integers.
  244. ///
  245. /// \param r The input range
  246. /// \param out An output iterator to the results into
  247. /// \return The updated output iterator
  248. /// \note Based on the MySQL function of the same name
  249. template <typename Range, typename OutputIterator>
  250. OutputIterator unhex ( const Range &r, OutputIterator out ) {
  251. return unhex (std::begin(r), std::end(r), out);
  252. }
  253. /// \fn String hex ( const String &input )
  254. /// \brief Converts a sequence of integral types into a hexadecimal sequence of characters.
  255. ///
  256. /// \param input A container to be converted
  257. /// \return A container with the encoded text
  258. template<typename String>
  259. String hex ( const String &input ) {
  260. String output;
  261. output.reserve (input.size () * (2 * sizeof (typename String::value_type)));
  262. (void) hex (input, std::back_inserter (output));
  263. return output;
  264. }
  265. /// \fn String hex_lower ( const String &input )
  266. /// \brief Converts a sequence of integral types into a lower case hexadecimal sequence of characters.
  267. ///
  268. /// \param input A container to be converted
  269. /// \return A container with the encoded text
  270. template<typename String>
  271. String hex_lower ( const String &input ) {
  272. String output;
  273. output.reserve (input.size () * (2 * sizeof (typename String::value_type)));
  274. (void) hex_lower (input, std::back_inserter (output));
  275. return output;
  276. }
  277. /// \fn String unhex ( const String &input )
  278. /// \brief Converts a sequence of hexadecimal characters into a sequence of characters.
  279. ///
  280. /// \param input A container to be converted
  281. /// \return A container with the decoded text
  282. template<typename String>
  283. String unhex ( const String &input ) {
  284. String output;
  285. output.reserve (input.size () / (2 * sizeof (typename String::value_type)));
  286. (void) unhex (input, std::back_inserter (output));
  287. return output;
  288. }
  289. }}
  290. namespace detail {
  291. namespace iterators {
  292. template <class UnaryFunction>
  293. class function_output_iterator {
  294. typedef function_output_iterator self;
  295. public:
  296. typedef std::output_iterator_tag iterator_category;
  297. typedef void value_type;
  298. typedef void difference_type;
  299. typedef void pointer;
  300. typedef void reference;
  301. explicit function_output_iterator() {}
  302. explicit function_output_iterator(const UnaryFunction& f)
  303. : m_f(f) {}
  304. struct output_proxy {
  305. output_proxy(UnaryFunction& f) : m_f(f) { }
  306. template <class T> output_proxy& operator=(const T& value) {
  307. m_f(value);
  308. return *this;
  309. }
  310. UnaryFunction& m_f;
  311. };
  312. output_proxy operator*() { return output_proxy(m_f); }
  313. self& operator++() { return *this; }
  314. self& operator++(int) { return *this; }
  315. private:
  316. UnaryFunction m_f;
  317. };
  318. template <class UnaryFunction>
  319. inline function_output_iterator<UnaryFunction>
  320. make_function_output_iterator(const UnaryFunction& f = UnaryFunction()) {
  321. return function_output_iterator<UnaryFunction>(f);
  322. }
  323. } // namespace iterators
  324. using iterators::function_output_iterator;
  325. using iterators::make_function_output_iterator;
  326. } // namespace detail
  327. struct security
  328. {
  329. static constexpr char const* any_group_name = "@any";
  330. struct authentication
  331. {
  332. enum class method : std::uint8_t
  333. {
  334. sha256,
  335. plain_password,
  336. client_cert,
  337. anonymous,
  338. unauthenticated
  339. };
  340. authentication(
  341. method auth_method = method::sha256,
  342. std::optional<std::string> digest = std::nullopt,
  343. std::string salt = std::string()
  344. )
  345. : auth_method(auth_method)
  346. , digest(std::move(digest))
  347. , salt(std::move(salt))
  348. {
  349. }
  350. method auth_method;
  351. std::optional<std::string> digest;
  352. std::string salt;
  353. std::vector<std::string> groups;
  354. };
  355. struct authorization
  356. {
  357. enum class type : std::uint8_t
  358. {
  359. deny, allow, none
  360. };
  361. authorization(std::string_view topic, std::size_t rule_nr)
  362. : topic(topic)
  363. , rule_nr(rule_nr)
  364. , sub_type(type::none)
  365. , pub_type(type::none)
  366. {
  367. }
  368. std::vector<std::string> topic_tokens;
  369. std::string topic;
  370. std::size_t rule_nr;
  371. type sub_type;
  372. std::set<std::string> sub;
  373. type pub_type;
  374. std::set<std::string> pub;
  375. };
  376. struct group
  377. {
  378. std::string name;
  379. std::vector<std::string> members;
  380. };
  381. /** Return username of anonymous user */
  382. std::optional<std::string> const& login_anonymous() const
  383. {
  384. asio2::shared_locker g(this->security_mutex_);
  385. return this->anonymous_;
  386. }
  387. /** Return username of unauthorized user */
  388. std::optional<std::string> const& login_unauthenticated() const
  389. {
  390. asio2::shared_locker g(this->security_mutex_);
  391. return this->unauthenticated_;
  392. }
  393. template<typename T>
  394. static std::string to_hex(T start, T end)
  395. {
  396. std::string result;
  397. detail::algorithm::hex(start, end, std::back_inserter(result));
  398. return result;
  399. }
  400. #if defined(ASIO2_ENABLE_SSL) || defined(ASIO2_USE_SSL)
  401. static std::string sha256hash(std::string_view msg) {
  402. std::shared_ptr<EVP_MD_CTX> mdctx(EVP_MD_CTX_new(), EVP_MD_CTX_free);
  403. EVP_DigestInit_ex(mdctx.get(), EVP_sha256(), NULL);
  404. EVP_DigestUpdate(mdctx.get(), msg.data(), msg.size());
  405. std::vector<unsigned char> digest(static_cast<std::size_t>(EVP_MD_size(EVP_sha256())));
  406. unsigned int digest_size = static_cast<unsigned int>(digest.size());
  407. EVP_DigestFinal_ex(mdctx.get(), digest.data(), &digest_size);
  408. return to_hex(digest.data(), digest.data() + digest_size);
  409. }
  410. #else
  411. static std::string sha256hash(std::string_view msg)
  412. {
  413. return std::string(msg);
  414. }
  415. #endif
  416. bool login_cert(std::string_view username) const
  417. {
  418. asio2::shared_locker g(this->security_mutex_);
  419. auto i = authentication_.find(std::string(username));
  420. return
  421. i != authentication_.end() &&
  422. i->second.auth_method == security::authentication::method::client_cert;
  423. }
  424. std::optional<std::string> login(std::string_view username, std::string_view password) const
  425. {
  426. asio2::shared_locker g(this->security_mutex_);
  427. auto i = authentication_.find(std::string(username));
  428. if (i != authentication_.end() &&
  429. i->second.auth_method == security::authentication::method::sha256)
  430. {
  431. return [&]() -> std::optional<std::string>
  432. {
  433. if (asio2::iequals(
  434. i->second.digest.value(),
  435. sha256hash(i->second.salt + std::string(password))))
  436. {
  437. return std::string(username);
  438. }
  439. else
  440. {
  441. return std::nullopt;
  442. }
  443. } ();
  444. }
  445. else if (
  446. i != authentication_.end() &&
  447. i->second.auth_method == security::authentication::method::plain_password)
  448. {
  449. return [&]() -> std::optional<std::string>
  450. {
  451. if (i->second.digest.value() == password)
  452. {
  453. return std::string(username);
  454. }
  455. else
  456. {
  457. return std::nullopt;
  458. }
  459. } ();
  460. }
  461. return std::nullopt;
  462. }
  463. static authorization::type get_auth_type(std::string_view type)
  464. {
  465. if (type == "allow") return authorization::type::allow;
  466. if (type == "deny") return authorization::type::deny;
  467. throw std::runtime_error(
  468. "An invalid authorization type was specified: " +
  469. std::string(type)
  470. );
  471. }
  472. static bool is_valid_group_name(std::string_view name)
  473. {
  474. return !name.empty() && name[0] == '@'; // TODO: validate utf-8
  475. }
  476. static bool is_valid_user_name(std::string_view name)
  477. {
  478. return !name.empty() && name[0] != '@'; // TODO: validate utf-8
  479. }
  480. std::size_t get_next_rule_nr() const
  481. {
  482. asio2::shared_locker g(this->security_mutex_);
  483. return this->get_next_rule_nr_impl();
  484. }
  485. void default_config()
  486. {
  487. asio2::unique_locker g(this->security_mutex_);
  488. char const* username = "anonymous";
  489. authentication login(authentication::method::anonymous);
  490. authentication_.emplace(username, login);
  491. anonymous_ = username;
  492. char const* topic = "#";
  493. authorization auth(topic, get_next_rule_nr_impl());
  494. auth.topic_tokens = get_topic_filter_tokens("#");
  495. auth.sub_type = authorization::type::allow;
  496. auth.sub.emplace(username);
  497. auth.pub_type = authorization::type::allow;
  498. auth.pub.emplace(username);
  499. authorization_.push_back(auth);
  500. groups_.emplace(std::string(any_group_name), group());
  501. validate();
  502. }
  503. std::size_t add_auth(
  504. std::string const& topic_filter,
  505. std::set<std::string> const& pub,
  506. authorization::type auth_pub_type,
  507. std::set<std::string> const& sub,
  508. authorization::type auth_sub_type
  509. )
  510. {
  511. asio2::unique_locker g(this->security_mutex_);
  512. for(auto const& j : pub)
  513. {
  514. if (!is_valid_user_name(j) && !is_valid_group_name(j))
  515. {
  516. throw std::runtime_error(
  517. "An invalid username or groupname was specified for the authorization: " + j
  518. );
  519. }
  520. validate_entry("topic " + topic_filter, j);
  521. }
  522. for(auto const& j : sub)
  523. {
  524. if (!is_valid_user_name(j) && !is_valid_group_name(j))
  525. {
  526. throw std::runtime_error(
  527. "An invalid username or groupname was specified for the authorization: " + j
  528. );
  529. }
  530. validate_entry("topic " + topic_filter, j);
  531. }
  532. std::size_t rule_nr = get_next_rule_nr_impl();
  533. authorization auth(topic_filter, rule_nr);
  534. auth.topic_tokens = get_topic_filter_tokens(topic_filter);
  535. auth.pub = pub;
  536. auth.pub_type = auth_pub_type;
  537. auth.sub = sub;
  538. auth.sub_type = auth_sub_type;
  539. for (auto const& j: sub)
  540. {
  541. auth_sub_map_.insert_or_assign(
  542. topic_filter,
  543. j,
  544. std::make_pair(auth_sub_type, rule_nr)
  545. );
  546. }
  547. for (auto const& j: pub)
  548. {
  549. auth_pub_map_.insert_or_assign(
  550. topic_filter,
  551. j,
  552. std::make_pair(auth_pub_type, rule_nr)
  553. );
  554. }
  555. authorization_.push_back(auth);
  556. return rule_nr;
  557. }
  558. void remove_auth(std::size_t rule_nr)
  559. {
  560. asio2::unique_locker g(this->security_mutex_);
  561. for (auto i = authorization_.begin(); i != authorization_.end(); ++i)
  562. {
  563. if (i->rule_nr == rule_nr)
  564. {
  565. for (auto const& j : i->sub)
  566. {
  567. auth_sub_map_.erase(i->topic, j);
  568. }
  569. for (auto const& j : i->pub)
  570. {
  571. auth_pub_map_.erase(i->topic, j);
  572. }
  573. authorization_.erase(i);
  574. return;
  575. }
  576. }
  577. }
  578. void add_sha256_authentication(std::string username, std::string digest, std::string salt)
  579. {
  580. asio2::unique_locker g(this->security_mutex_);
  581. authentication auth(authentication::method::sha256, std::move(digest), std::move(salt));
  582. authentication_.emplace(std::move(username), std::move(auth));
  583. }
  584. void add_plain_password_authentication(std::string username, std::string password)
  585. {
  586. asio2::unique_locker g(this->security_mutex_);
  587. authentication auth(authentication::method::plain_password, std::move(password));
  588. authentication_.emplace(std::move(username), std::move(auth));
  589. }
  590. void add_certificate_authentication(std::string username)
  591. {
  592. asio2::unique_locker g(this->security_mutex_);
  593. authentication auth(authentication::method::client_cert);
  594. authentication_.emplace(std::move(username), std::move(auth));
  595. }
  596. void load_json(std::istream& input)
  597. {
  598. using json = nlohmann::json;
  599. json j;
  600. input >> j;
  601. asio2::unique_locker g(this->security_mutex_);
  602. groups_.emplace(std::string(any_group_name), group());
  603. if (auto& j_authentication = j["authentication"]; j_authentication.is_array())
  604. {
  605. for (auto& i : j_authentication)
  606. {
  607. auto& j_name = i["name"];
  608. if (!j_name.is_string())
  609. {
  610. ASIO2_ASSERT(false);
  611. continue;
  612. }
  613. std::string name = j_name.get<std::string>();
  614. if (!is_valid_user_name(name))
  615. {
  616. ASIO2_ASSERT(false);
  617. continue;
  618. }
  619. auto& j_method = i["method"];
  620. if (!j_method.is_string())
  621. {
  622. ASIO2_ASSERT(false);
  623. continue;
  624. }
  625. std::string method = j_method.get<std::string>();
  626. if (method == "sha256")
  627. {
  628. auto& j_digest = i["digest"];
  629. auto& j_salt = i["salt"];
  630. if (j_digest.is_string())
  631. {
  632. std::string digest = j_digest.get<std::string>();
  633. std::string salt;
  634. if (j_salt.is_string())
  635. salt = j_salt.get<std::string>();
  636. authentication auth(authentication::method::sha256, std::move(digest), std::move(salt));
  637. authentication_.emplace(std::move(name), std::move(auth));
  638. }
  639. }
  640. else if (method == "plain_password")
  641. {
  642. auto& j_password = i["password"];
  643. if (j_password.is_string())
  644. {
  645. std::string digest = j_password.get<std::string>();
  646. authentication auth(authentication::method::plain_password, digest);
  647. authentication_.emplace(std::move(name), std::move(auth));
  648. }
  649. }
  650. else if (method == "client_cert")
  651. {
  652. authentication auth(authentication::method::client_cert);
  653. authentication_.emplace(std::move(name), std::move(auth));
  654. }
  655. else if (method == "anonymous")
  656. {
  657. if (anonymous_)
  658. {
  659. ASIO2_ASSERT(false && "Only a single anonymous user can be configured");
  660. }
  661. else
  662. {
  663. anonymous_ = name;
  664. authentication auth(authentication::method::anonymous);
  665. authentication_.emplace(std::move(name), std::move(auth));
  666. }
  667. }
  668. else if (method == "unauthenticated")
  669. {
  670. if (unauthenticated_)
  671. {
  672. ASIO2_ASSERT(false && "Only a single unauthenticated user can be configured");
  673. }
  674. else
  675. {
  676. unauthenticated_ = name;
  677. authentication auth(authentication::method::unauthenticated);
  678. authentication_.emplace(std::move(name), std::move(auth));
  679. }
  680. }
  681. else
  682. {
  683. ASIO2_ASSERT(false);
  684. }
  685. }
  686. }
  687. if (auto& j_groups = j["groups"]; j_groups.is_array())
  688. {
  689. for (auto& i : j_groups)
  690. {
  691. auto& j_name = i["name"];
  692. if (!j_name.is_string())
  693. {
  694. ASIO2_ASSERT(false);
  695. continue;
  696. }
  697. std::string name = j_name.get<std::string>();
  698. if (!is_valid_group_name(name))
  699. {
  700. ASIO2_ASSERT(false);
  701. continue;
  702. }
  703. group group;
  704. if (auto& j_members = j["members"]; j_members.is_array())
  705. {
  706. for (auto& j_username : j_members)
  707. {
  708. if (!j_username.is_string())
  709. {
  710. ASIO2_ASSERT(false);
  711. continue;
  712. }
  713. auto username = j_username.get<std::string>();
  714. if (!is_valid_user_name(username))
  715. {
  716. ASIO2_ASSERT(false);
  717. continue;
  718. }
  719. group.members.emplace_back(std::move(username));
  720. }
  721. }
  722. else
  723. {
  724. ASIO2_ASSERT(false);
  725. }
  726. groups_.emplace(std::move(name), std::move(group));
  727. }
  728. }
  729. if (auto& j_authorization = j["authorization"]; j_authorization.is_array())
  730. {
  731. for (auto& i : j_authorization)
  732. {
  733. auto& j_name = i["topic"];
  734. if (!j_name.is_string())
  735. {
  736. ASIO2_ASSERT(false);
  737. continue;
  738. }
  739. std::string name = j_name.get<std::string>();
  740. if (!validate_topic_filter(name))
  741. {
  742. ASIO2_ASSERT(false);
  743. continue;
  744. }
  745. authorization auth(name, get_next_rule_nr_impl());
  746. auth.topic_tokens = get_topic_filter_tokens(name);
  747. if (auto& j_allow = i["allow"]; j_allow.is_object())
  748. {
  749. if (auto& j_sub = j_allow["sub"]; j_sub.is_array())
  750. {
  751. for (auto& j_username : j_sub)
  752. {
  753. if (j_username.is_string())
  754. {
  755. auth.sub.emplace(j_username.get<std::string>());
  756. }
  757. }
  758. auth.sub_type = authorization::type::allow;
  759. }
  760. if (auto& j_pub = j_allow["pub"]; j_pub.is_array())
  761. {
  762. for (auto& j_username : j_pub)
  763. {
  764. if (j_username.is_string())
  765. {
  766. auth.pub.emplace(j_username.get<std::string>());
  767. }
  768. }
  769. auth.pub_type = authorization::type::allow;
  770. }
  771. }
  772. if (auto& j_deny = i["deny"]; j_deny.is_object())
  773. {
  774. if (auto& j_sub = j_deny["sub"]; j_sub.is_array())
  775. {
  776. for (auto& j_username : j_sub)
  777. {
  778. if (j_username.is_string())
  779. {
  780. auth.sub.emplace(j_username.get<std::string>());
  781. }
  782. }
  783. auth.sub_type = authorization::type::deny;
  784. }
  785. if (auto& j_pub = j_deny["pub"]; j_pub.is_array())
  786. {
  787. for (auto& j_username : j_pub)
  788. {
  789. auth.pub.emplace(j_username.get<std::string>());
  790. }
  791. auth.pub_type = authorization::type::deny;
  792. }
  793. }
  794. authorization_.emplace_back(std::move(auth));
  795. }
  796. }
  797. validate();
  798. }
  799. template<typename T>
  800. void get_auth_sub_by_user(std::string_view username, T&& callback) const
  801. {
  802. std::set<std::string> username_and_groups;
  803. username_and_groups.insert(std::string(username));
  804. asio2::shared_locker g(this->security_mutex_);
  805. for (auto const& i : groups_)
  806. {
  807. if (i.first == any_group_name ||
  808. std::find(i.second.members.begin(), i.second.members.end(), username) != i.second.members.end())
  809. {
  810. username_and_groups.insert(i.first);
  811. }
  812. }
  813. for (auto const& i : authorization_)
  814. {
  815. if (i.sub_type != authorization::type::none)
  816. {
  817. bool sets_intersect = false;
  818. auto store_intersect = [&sets_intersect](std::string const&) mutable
  819. {
  820. sets_intersect = true;
  821. };
  822. std::set_intersection(
  823. i.sub.begin(),
  824. i.sub.end(),
  825. username_and_groups.begin(),
  826. username_and_groups.end(),
  827. detail::make_function_output_iterator(std::ref(store_intersect))
  828. );
  829. if (sets_intersect)
  830. {
  831. std::forward<T>(callback)(i);
  832. }
  833. }
  834. }
  835. }
  836. authorization::type auth_pub(std::string_view topic, std::string_view username)
  837. {
  838. authorization::type result_type = authorization::type::deny;
  839. std::set<std::string> username_and_groups;
  840. username_and_groups.insert(std::string(username));
  841. asio2::shared_locker g(this->security_mutex_);
  842. for (auto const& i : groups_)
  843. {
  844. if (i.first == any_group_name ||
  845. std::find(i.second.members.begin(), i.second.members.end(), username) != i.second.members.end())
  846. {
  847. username_and_groups.insert(i.first);
  848. }
  849. }
  850. std::size_t priority = 0;
  851. auth_pub_map_.match(topic,
  852. [&](const std::string& allowed_username, std::pair<authorization::type, std::size_t>& entry) mutable
  853. {
  854. if (username_and_groups.find(allowed_username) != username_and_groups.end())
  855. {
  856. if (entry.second >= priority)
  857. {
  858. result_type = entry.first;
  859. priority = entry.second;
  860. }
  861. }
  862. }
  863. );
  864. return result_type;
  865. }
  866. std::map<std::string, authorization::type> auth_sub(std::string_view topic)
  867. {
  868. std::map<std::string, authorization::type> result;
  869. std::size_t priority = 0;
  870. auth_sub_map_.match(topic,
  871. [&](const std::string& allowed_username, std::pair<authorization::type, std::size_t>& entry)
  872. {
  873. if (entry.second >= priority)
  874. {
  875. result[allowed_username] = entry.first;
  876. priority = entry.second;
  877. }
  878. }
  879. );
  880. return result;
  881. }
  882. authorization::type auth_sub_user(
  883. std::map<std::string, authorization::type> const& result, std::string const& username)
  884. {
  885. auto it = result.find(username);
  886. if (it != result.end())
  887. return it->second;
  888. asio2::shared_locker g(this->security_mutex_);
  889. for (auto& [k, v] : groups_)
  890. {
  891. if (k == any_group_name ||
  892. std::find(v.members.begin(), v.members.end(), username) != v.members.end())
  893. {
  894. auto j = result.find(k);
  895. if (j != result.end())
  896. return j->second;
  897. }
  898. }
  899. return authorization::type::deny;
  900. }
  /// @return true if the topic level is the multi-level wildcard "#".
  static bool is_hash(std::string_view level) { return level.size() == 1 && level.front() == '#'; }
  /// @return true if the topic level is the single-level wildcard "+".
  static bool is_plus(std::string_view level) { return level.size() == 1 && level.front() == '+'; }
  /// @return true if the topic level is neither "#" nor "+".
  static bool is_literal(std::string_view level) { return !(is_hash(level) || is_plus(level)); }
  904. static std::optional<std::string> is_subscribe_allowed(
  905. std::vector<std::string> const& authorized_filter,std::string_view subscription_filter)
  906. {
  907. std::optional<std::string> result;
  908. auto append_result = [&result](std::string_view token)
  909. {
  910. if (result)
  911. {
  912. result.value() += topic_filter_separator;
  913. result.value().append(token.data(), token.size());
  914. }
  915. else
  916. {
  917. result = std::string(token);
  918. }
  919. };
  920. auto filter_begin = authorized_filter.begin();
  921. auto subscription_begin = subscription_filter.begin();
  922. auto subscription_next = topic_filter_tokenizer_next(subscription_begin, subscription_filter.end());
  923. while (true)
  924. {
  925. if (filter_begin == authorized_filter.end())
  926. {
  927. return std::nullopt;
  928. }
  929. auto auth = *filter_begin;
  930. ++filter_begin;
  931. if (is_hash(auth))
  932. {
  933. append_result(std::string_view(&(*subscription_begin),
  934. std::distance(subscription_begin, subscription_filter.end())));
  935. return result;
  936. }
  937. auto sub = std::string_view(&(*subscription_begin),
  938. std::distance(subscription_begin, subscription_next));
  939. if (is_hash(sub))
  940. {
  941. append_result(auth);
  942. while (filter_begin < authorized_filter.end())
  943. {
  944. append_result(*filter_begin);
  945. ++filter_begin;
  946. }
  947. return result;
  948. }
  949. if (is_plus(auth))
  950. {
  951. append_result(sub);
  952. }
  953. else if (is_plus(sub))
  954. {
  955. append_result(auth);
  956. }
  957. else
  958. {
  959. if (auth != sub)
  960. {
  961. return std::nullopt;
  962. }
  963. append_result(auth);
  964. }
  965. if (subscription_next == subscription_filter.end())
  966. break;
  967. subscription_begin = std::next(subscription_next);
  968. subscription_next = topic_filter_tokenizer_next(subscription_begin, subscription_filter.end());
  969. }
  970. if (filter_begin < authorized_filter.end())
  971. {
  972. return std::nullopt;
  973. }
  974. return result;
  975. }
  976. static bool is_subscribe_denied(
  977. std::vector<std::string> const& deny_filter, std::string_view subscription_filter)
  978. {
  979. bool result = true;
  980. auto filter_begin = deny_filter.begin();
  981. auto tokens_count = topic_filter_tokenizer(subscription_filter,
  982. [&](auto sub)
  983. {
  984. if (filter_begin == deny_filter.end())
  985. {
  986. result = false;
  987. return false;
  988. };
  989. std::string deny = *filter_begin;
  990. ++filter_begin;
  991. if (deny != sub)
  992. {
  993. if (is_hash(deny))
  994. {
  995. result = true;
  996. return false;
  997. }
  998. if (is_hash(sub))
  999. {
  1000. result = false;
  1001. return false;
  1002. }
  1003. if (is_plus(deny))
  1004. {
  1005. result = true;
  1006. return true;
  1007. }
  1008. result = false;
  1009. return false;
  1010. }
  1011. return true;
  1012. }
  1013. );
  1014. return result && (tokens_count == deny_filter.size());
  1015. }
  1016. std::vector<std::string> get_auth_sub_topics(std::string_view username, std::string_view topic_filter) const
  1017. {
  1018. std::vector<std::string> auth_topics;
  1019. get_auth_sub_by_user(username,
  1020. [&](authorization const& i)
  1021. {
  1022. if (i.sub_type == authorization::type::allow)
  1023. {
  1024. auto entry = is_subscribe_allowed(i.topic_tokens, topic_filter);
  1025. if (entry)
  1026. {
  1027. auth_topics.push_back(entry.value());
  1028. }
  1029. }
  1030. else
  1031. {
  1032. for (auto j = auth_topics.begin(); j != auth_topics.end();)
  1033. {
  1034. if (is_subscribe_denied(i.topic_tokens, topic_filter))
  1035. {
  1036. j = auth_topics.erase(j);
  1037. }
  1038. else
  1039. {
  1040. ++j;
  1041. }
  1042. }
  1043. }
  1044. }
  1045. );
  1046. return auth_topics;
  1047. }
  1048. /**
  1049. * @brief Determine if user is allowed to subscribe to the specified topic filter
  1050. * @param username - The username to check
  1051. * @param topic_filter - Topic filter the user would like to subscribe to
  1052. * @return true if the user is authorized
  1053. */
  1054. bool is_subscribe_authorized(std::string_view username, std::string_view topic_filter) const
  1055. {
  1056. return !get_auth_sub_topics(username, topic_filter).empty();
  1057. }
  1058. // Get the individual path elements of the topic filter
  1059. static std::vector<std::string> get_topic_filter_tokens(std::string_view topic_filter)
  1060. {
  1061. std::vector<std::string> result;
  1062. topic_filter_tokenizer(topic_filter,
  1063. [&result](auto str)
  1064. {
  1065. result.push_back(std::string(str));
  1066. return true;
  1067. }
  1068. );
  1069. return result;
  1070. }
  1071. inline bool enabled() const noexcept { return enabled_; }
  1072. inline void enabled(bool v) noexcept { enabled_ = v; }
  1073. /// use rwlock to make thread safe
  1074. mutable asio2::shared_mutexer security_mutex_;
  1075. bool enabled_ = true;
  1076. std::map<std::string, authentication> authentication_ ASIO2_GUARDED_BY(security_mutex_);
  1077. std::map<std::string, group > groups_ ASIO2_GUARDED_BY(security_mutex_);
  1078. std::vector<authorization> authorization_ ASIO2_GUARDED_BY(security_mutex_);
  1079. std::optional<std::string> anonymous_ ASIO2_GUARDED_BY(security_mutex_);
  1080. std::optional<std::string> unauthenticated_ ASIO2_GUARDED_BY(security_mutex_);
  1081. using auth_map_type = subscription_map<std::string, std::pair<authorization::type, std::size_t>>;
  1082. auth_map_type auth_pub_map_;
  1083. auth_map_type auth_sub_map_;
  1084. protected:
  1085. std::size_t get_next_rule_nr_impl() const ASIO2_NO_THREAD_SAFETY_ANALYSIS
  1086. {
  1087. std::size_t rule_nr = 0;
  1088. for (auto const& i : authorization_)
  1089. {
  1090. rule_nr = (std::max)(rule_nr, i.rule_nr);
  1091. }
  1092. return rule_nr + 1;
  1093. }
  1094. void validate_entry(std::string const& context, std::string const& name) const ASIO2_NO_THREAD_SAFETY_ANALYSIS
  1095. {
  1096. if (is_valid_group_name(name) && groups_.find(name) == groups_.end())
  1097. {
  1098. throw std::runtime_error("An invalid group name was specified for " + context + ": " + name);
  1099. }
  1100. if (is_valid_user_name(name) && authentication_.find(name) == authentication_.end())
  1101. {
  1102. throw std::runtime_error("An invalid username name was specified for " + context + ": " + name);
  1103. }
  1104. }
  1105. void validate() ASIO2_NO_THREAD_SAFETY_ANALYSIS
  1106. {
  1107. for (auto const& i : groups_)
  1108. {
  1109. for (auto const& j : i.second.members)
  1110. {
  1111. auto iter = authentication_.find(j);
  1112. if (is_valid_user_name(j) && iter == authentication_.end())
  1113. throw std::runtime_error("An invalid username name was specified for group " + i.first + ": " + j);
  1114. }
  1115. }
  1116. std::string unsalted;
  1117. for (auto const& i : authentication_)
  1118. {
  1119. if (i.second.auth_method == authentication::method::sha256 && i.second.salt.empty())
  1120. {
  1121. if (!unsalted.empty()) unsalted += ", ";
  1122. unsalted += i.first;
  1123. }
  1124. }
  1125. if (!unsalted.empty())
  1126. {
  1127. //MQTT_LOG("mqtt_broker", warning)
  1128. // << "The following users have no salt specified: "
  1129. // << unsalted;
  1130. }
  1131. for (auto const& i : authorization_)
  1132. {
  1133. for (auto const& j : i.sub)
  1134. {
  1135. validate_entry("topic " + i.topic, j);
  1136. if (is_valid_user_name(j) || is_valid_group_name(j))
  1137. {
  1138. auth_sub_map_.insert_or_assign(i.topic, j, std::make_pair(i.sub_type, i.rule_nr));
  1139. }
  1140. }
  1141. for (auto const& j : i.pub)
  1142. {
  1143. validate_entry("topic " + i.topic, j);
  1144. if (is_valid_user_name(j) || is_valid_group_name(j))
  1145. {
  1146. auth_pub_map_.insert_or_assign(i.topic, j, std::make_pair(i.pub_type, i.rule_nr));
  1147. }
  1148. }
  1149. }
  1150. }
  1151. };
  1152. }
  1153. #endif // __ASIO2_MQTT_SECURITY_HPP__