connect_op.hpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. //
  2. // Copyright (c) 2023-2025 Ivica Siladic, Bruno Iljazovic, Korina Simicevic
  3. //
  4. // Distributed under the Boost Software License, Version 1.0.
  5. // (See accompanying file LICENSE or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. //
  7. #ifndef BOOST_MQTT5_CONNECT_OP_HPP
  8. #define BOOST_MQTT5_CONNECT_OP_HPP
  9. #include <boost/mqtt5/error.hpp>
  10. #include <boost/mqtt5/reason_codes.hpp>
  11. #include <boost/mqtt5/detail/async_traits.hpp>
  12. #include <boost/mqtt5/detail/control_packet.hpp>
  13. #include <boost/mqtt5/detail/internal_types.hpp>
  14. #include <boost/mqtt5/detail/log_invoke.hpp>
  15. #include <boost/mqtt5/detail/shutdown.hpp>
  16. #include <boost/mqtt5/impl/codecs/message_decoders.hpp>
  17. #include <boost/mqtt5/impl/codecs/message_encoders.hpp>
  18. #include <boost/asio/any_completion_handler.hpp>
  19. #include <boost/asio/append.hpp>
  20. #include <boost/asio/associated_allocator.hpp>
  21. #include <boost/asio/associated_cancellation_slot.hpp>
  22. #include <boost/asio/associated_executor.hpp>
  23. #include <boost/asio/cancellation_state.hpp>
  24. #include <boost/asio/completion_condition.hpp>
  25. #include <boost/asio/consign.hpp>
  26. #include <boost/asio/dispatch.hpp>
  27. #include <boost/asio/error.hpp>
  28. #include <boost/asio/ip/tcp.hpp>
  29. #include <boost/asio/prepend.hpp>
  30. #include <boost/asio/read.hpp>
  31. #include <boost/asio/write.hpp>
  32. #include <cstdint>
  33. #include <memory>
  34. #include <string>
  35. namespace boost::mqtt5::detail {
  36. template <typename Stream, typename LoggerType>
  37. class connect_op {
  38. static constexpr size_t min_packet_sz = 5;
  39. struct on_connect {};
  40. struct on_tls_handshake {};
  41. struct on_ws_handshake {};
  42. struct on_send_connect {};
  43. struct on_fixed_header {};
  44. struct on_read_packet {};
  45. struct on_init_auth_data {};
  46. struct on_auth_data {};
  47. struct on_send_auth {};
  48. struct on_complete_auth {};
  49. struct on_shutdown {};
  50. Stream& _stream;
  51. mqtt_ctx& _ctx;
  52. log_invoke<LoggerType>& _log;
  53. using handler_type = asio::any_completion_handler<void (error_code)>;
  54. handler_type _handler;
  55. std::unique_ptr<std::string> _buffer_ptr;
  56. asio::cancellation_state _cancellation_state;
  57. using endpoint = asio::ip::tcp::endpoint;
  58. public:
  59. template <typename Handler>
  60. connect_op(
  61. Stream& stream, mqtt_ctx& ctx,
  62. log_invoke<LoggerType>& log,
  63. Handler&& handler
  64. ) :
  65. _stream(stream), _ctx(ctx), _log(log),
  66. _handler(std::forward<Handler>(handler)),
  67. _cancellation_state(
  68. asio::get_associated_cancellation_slot(_handler),
  69. asio::enable_total_cancellation {},
  70. asio::enable_total_cancellation {}
  71. )
  72. {}
  73. connect_op(connect_op&&) = default;
  74. connect_op(const connect_op&) = delete;
  75. connect_op& operator=(connect_op&&) = default;
  76. connect_op& operator=(const connect_op&) = delete;
  77. using allocator_type = asio::associated_allocator_t<handler_type>;
  78. allocator_type get_allocator() const noexcept {
  79. return asio::get_associated_allocator(_handler);
  80. }
  81. using cancellation_slot_type = asio::cancellation_slot;
  82. cancellation_slot_type get_cancellation_slot() const noexcept {
  83. return _cancellation_state.slot();
  84. }
  85. using executor_type = asio::associated_executor_t<handler_type>;
  86. executor_type get_executor() const noexcept {
  87. return asio::get_associated_executor(_handler);
  88. }
  89. void perform(const endpoint& ep, authority_path ap) {
  90. lowest_layer(_stream).async_connect(
  91. ep,
  92. asio::append(
  93. asio::prepend(std::move(*this), on_connect {}),
  94. ep, std::move(ap)
  95. )
  96. );
  97. }
  98. void operator()(
  99. on_connect, error_code ec, endpoint ep, authority_path ap
  100. ) {
  101. if (is_cancelled())
  102. return complete(asio::error::operation_aborted);
  103. _log.at_tcp_connect(ec, ep);
  104. if (ec)
  105. return complete(ec);
  106. do_tls_handshake(std::move(ep), std::move(ap));
  107. }
  108. void do_tls_handshake(endpoint ep, authority_path ap) {
  109. if constexpr (has_tls_handshake<Stream>) {
  110. _stream.async_handshake(
  111. tls_handshake_type<Stream>::client,
  112. asio::append(
  113. asio::prepend(std::move(*this), on_tls_handshake {}),
  114. std::move(ep), std::move(ap)
  115. )
  116. );
  117. }
  118. else if constexpr (
  119. has_tls_handshake<next_layer_type<Stream>>
  120. ) {
  121. _stream.next_layer().async_handshake(
  122. tls_handshake_type<next_layer_type<Stream>>::client,
  123. asio::append(
  124. asio::prepend(std::move(*this), on_tls_handshake {}),
  125. std::move(ep), std::move(ap)
  126. )
  127. );
  128. }
  129. else
  130. do_ws_handshake(std::move(ep), std::move(ap));
  131. }
  132. void operator()(
  133. on_tls_handshake, error_code ec, endpoint ep, authority_path ap
  134. ) {
  135. if (is_cancelled())
  136. return complete(asio::error::operation_aborted);
  137. _log.at_tls_handshake(ec, ep);
  138. if (ec)
  139. return complete(ec);
  140. do_ws_handshake(std::move(ep), std::move(ap));
  141. }
  142. void do_ws_handshake(endpoint ep, authority_path ap) {
  143. if constexpr (has_ws_handshake<Stream>)
  144. // If you get a compilation error here,
  145. // it might be because of a missing <boost/mqtt5/websocket.hpp> include
  146. ws_handshake_traits<Stream>::async_handshake(
  147. _stream, std::move(ap),
  148. asio::append(
  149. asio::prepend(std::move(*this), on_ws_handshake {}), ep
  150. )
  151. );
  152. else
  153. (*this)(on_ws_handshake {}, error_code {}, ep);
  154. }
  155. void operator()(on_ws_handshake, error_code ec, endpoint ep) {
  156. if (is_cancelled())
  157. return complete(asio::error::operation_aborted);
  158. if constexpr (has_ws_handshake<Stream>)
  159. _log.at_ws_handshake(ec, ep);
  160. if (ec)
  161. return complete(ec);
  162. auto auth_method = _ctx.authenticator.method();
  163. if (!auth_method.empty()) {
  164. _ctx.co_props[prop::authentication_method] = auth_method;
  165. return _ctx.authenticator.async_auth(
  166. auth_step_e::client_initial, "",
  167. asio::prepend(std::move(*this), on_init_auth_data {})
  168. );
  169. }
  170. send_connect();
  171. }
  172. void operator()(on_init_auth_data, error_code ec, std::string data) {
  173. if (is_cancelled())
  174. return complete(asio::error::operation_aborted);
  175. if (ec)
  176. return do_shutdown(asio::error::try_again);
  177. _ctx.co_props[prop::authentication_data] = std::move(data);
  178. send_connect();
  179. }
  180. void send_connect() {
  181. if (!_ctx.co_props[prop::maximum_packet_size].has_value())
  182. _ctx.co_props[prop::maximum_packet_size] = default_max_recv_size;
  183. auto packet = control_packet<allocator_type>::of(
  184. no_pid, get_allocator(),
  185. encoders::encode_connect,
  186. _ctx.creds.client_id,
  187. _ctx.creds.username, _ctx.creds.password,
  188. _ctx.keep_alive, false, _ctx.co_props, _ctx.will_msg
  189. );
  190. auto wire_data = packet.wire_data();
  191. detail::async_write(
  192. _stream, asio::buffer(wire_data),
  193. asio::consign(
  194. asio::prepend(std::move(*this), on_send_connect {}),
  195. std::move(packet)
  196. )
  197. );
  198. }
  199. void operator()(on_send_connect, error_code ec, size_t) {
  200. if (is_cancelled())
  201. return complete(asio::error::operation_aborted);
  202. if (ec) {
  203. _log.at_transport_error(ec);
  204. return do_shutdown(ec);
  205. }
  206. _buffer_ptr = std::make_unique<std::string>(min_packet_sz, char(0));
  207. auto buff = asio::buffer(_buffer_ptr->data(), min_packet_sz);
  208. asio::async_read(
  209. _stream, buff, asio::transfer_all(),
  210. asio::prepend(std::move(*this), on_fixed_header {})
  211. );
  212. }
  213. void operator()(
  214. on_fixed_header, error_code ec, size_t num_read
  215. ) {
  216. if (is_cancelled())
  217. return complete(asio::error::operation_aborted);
  218. if (ec) {
  219. _log.at_transport_error(ec);
  220. return do_shutdown(ec);
  221. }
  222. auto code = control_code_e((*_buffer_ptr)[0] & 0b11110000);
  223. if (code != control_code_e::auth && code != control_code_e::connack)
  224. return do_shutdown(asio::error::try_again);
  225. auto varlen_ptr = _buffer_ptr->cbegin() + 1;
  226. auto varlen = decoders::type_parse(
  227. varlen_ptr, _buffer_ptr->cend(), decoders::basic::varint_
  228. );
  229. if (!varlen)
  230. return do_shutdown(asio::error::try_again);
  231. auto varlen_sz = std::distance(_buffer_ptr->cbegin() + 1, varlen_ptr);
  232. auto remain_len = *varlen -
  233. std::distance(varlen_ptr, _buffer_ptr->cbegin() + num_read);
  234. if (num_read + remain_len > _buffer_ptr->size())
  235. _buffer_ptr->resize(num_read + remain_len);
  236. auto buff = asio::buffer(_buffer_ptr->data() + num_read, remain_len);
  237. auto first = _buffer_ptr->cbegin() + varlen_sz + 1;
  238. auto last = first + *varlen;
  239. asio::async_read(
  240. _stream, buff, asio::transfer_all(),
  241. asio::prepend(
  242. asio::append(std::move(*this), code, first, last),
  243. on_read_packet {}
  244. )
  245. );
  246. }
  247. void operator()(
  248. on_read_packet, error_code ec, size_t, control_code_e code,
  249. byte_citer first, byte_citer last
  250. ) {
  251. if (is_cancelled())
  252. return complete(asio::error::operation_aborted);
  253. if (ec) {
  254. _log.at_transport_error(ec);
  255. return do_shutdown(ec);
  256. }
  257. if (code == control_code_e::connack)
  258. return on_connack(first, last);
  259. if (!_ctx.co_props[prop::authentication_method].has_value())
  260. return do_shutdown(client::error::malformed_packet);
  261. on_auth(first, last);
  262. }
  263. void on_connack(byte_citer first, byte_citer last) {
  264. auto packet_length = static_cast<uint32_t>(std::distance(first, last));
  265. auto rv = decoders::decode_connack(packet_length, first);
  266. if (!rv.has_value())
  267. return do_shutdown(client::error::malformed_packet);
  268. const auto& [session_present, reason_code, ca_props] = *rv;
  269. _ctx.ca_props = ca_props;
  270. _ctx.state.session_present(session_present);
  271. // Unexpected result handling:
  272. // - If we don't have a Session State, and we get session_present = true,
  273. // we must close the network connection (and restart with a clean start)
  274. // - If we have a Session State, and we get session_present = false,
  275. // we must discard our Session State
  276. auto rc = to_reason_code<reason_codes::category::connack>(reason_code);
  277. if (!rc.has_value()) // reason code not allowed in CONNACK
  278. return do_shutdown(client::error::malformed_packet);
  279. _log.at_connack(*rc, session_present, ca_props);
  280. if (*rc)
  281. return do_shutdown(asio::error::try_again);
  282. if (_ctx.co_props[prop::authentication_method].has_value())
  283. return _ctx.authenticator.async_auth(
  284. auth_step_e::server_final,
  285. ca_props[prop::authentication_data].value_or(""),
  286. asio::prepend(std::move(*this), on_complete_auth {})
  287. );
  288. complete(error_code {});
  289. }
  290. void on_auth(byte_citer first, byte_citer last) {
  291. auto packet_length = static_cast<uint32_t>(std::distance(first, last));
  292. auto rv = decoders::decode_auth(packet_length, first);
  293. if (!rv.has_value())
  294. return do_shutdown(client::error::malformed_packet);
  295. const auto& [reason_code, auth_props] = *rv;
  296. auto rc = to_reason_code<reason_codes::category::auth>(reason_code);
  297. if (
  298. !rc.has_value() ||
  299. auth_props[prop::authentication_method]
  300. != _ctx.co_props[prop::authentication_method]
  301. )
  302. return do_shutdown(client::error::malformed_packet);
  303. _ctx.authenticator.async_auth(
  304. auth_step_e::server_challenge,
  305. auth_props[prop::authentication_data].value_or(""),
  306. asio::prepend(std::move(*this), on_auth_data {})
  307. );
  308. }
  309. void operator()(on_auth_data, error_code ec, std::string data) {
  310. if (is_cancelled())
  311. return complete(asio::error::operation_aborted);
  312. if (ec)
  313. return do_shutdown(asio::error::try_again);
  314. auth_props props;
  315. props[prop::authentication_method] =
  316. _ctx.co_props[prop::authentication_method];
  317. props[prop::authentication_data] = std::move(data);
  318. auto packet = control_packet<allocator_type>::of(
  319. no_pid, get_allocator(),
  320. encoders::encode_auth,
  321. reason_codes::continue_authentication.value(), props
  322. );
  323. auto wire_data = packet.wire_data();
  324. detail::async_write(
  325. _stream, asio::buffer(wire_data),
  326. asio::consign(
  327. asio::prepend(std::move(*this), on_send_auth {}),
  328. std::move(packet)
  329. )
  330. );
  331. }
  332. void operator()(on_send_auth, error_code ec, size_t) {
  333. if (is_cancelled())
  334. return complete(asio::error::operation_aborted);
  335. if (ec) {
  336. _log.at_transport_error(ec);
  337. return do_shutdown(ec);
  338. }
  339. auto buff = asio::buffer(_buffer_ptr->data(), min_packet_sz);
  340. asio::async_read(
  341. _stream, buff, asio::transfer_all(),
  342. asio::prepend(std::move(*this), on_fixed_header {})
  343. );
  344. }
  345. void operator()(on_complete_auth, error_code ec, std::string) {
  346. if (is_cancelled())
  347. return complete(asio::error::operation_aborted);
  348. if (ec)
  349. return do_shutdown(asio::error::try_again);
  350. complete(error_code {});
  351. }
  352. void do_shutdown(error_code connect_ec) {
  353. auto init_shutdown = [&stream = _stream](auto handler) {
  354. async_shutdown(stream, std::move(handler));
  355. };
  356. auto token = asio::prepend(std::move(*this), on_shutdown{}, connect_ec);
  357. return asio::async_initiate<decltype(token), void(error_code)>(
  358. init_shutdown, token
  359. );
  360. }
  361. void operator()(on_shutdown, error_code connect_ec, error_code) {
  362. // ignore shutdown error_code
  363. complete(connect_ec);
  364. }
  365. private:
  366. bool is_cancelled() const {
  367. return _cancellation_state.cancelled() != asio::cancellation_type::none;
  368. }
  369. void complete(error_code ec) {
  370. asio::get_associated_cancellation_slot(_handler).clear();
  371. std::move(_handler)(ec);
  372. }
  373. };
  374. } // end namespace boost::mqtt5::detail
  375. #endif // !BOOST_MQTT5_CONNECT_OP_HPP