variant_stream.hpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. //
  2. // Copyright (c) 2019-2025 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See accompanying
  5. // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. //
  7. #ifndef BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
  8. #define BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
  9. #include <boost/mysql/any_address.hpp>
  10. #include <boost/mysql/error_code.hpp>
  11. #include <boost/mysql/string_view.hpp>
  12. #include <boost/mysql/detail/access.hpp>
  13. #include <boost/mysql/impl/internal/coroutine.hpp>
  14. #include <boost/mysql/impl/internal/ssl_context_with_default.hpp>
  15. #include <boost/asio/any_io_executor.hpp>
  16. #include <boost/asio/associated_immediate_executor.hpp>
  17. #include <boost/asio/cancellation_type.hpp>
  18. #include <boost/asio/compose.hpp>
  19. #include <boost/asio/connect.hpp>
  20. #include <boost/asio/dispatch.hpp>
  21. #include <boost/asio/error.hpp>
  22. #include <boost/asio/generic/stream_protocol.hpp>
  23. #include <boost/asio/ip/tcp.hpp>
  24. #include <boost/asio/local/stream_protocol.hpp>
  25. #include <boost/asio/ssl/context.hpp>
  26. #include <boost/asio/ssl/stream.hpp>
  27. #include <boost/core/span.hpp>
  28. #include <boost/optional/optional.hpp>
  29. #include <memory>
  30. #include <string>
  31. #include <utility>
  32. #include <vector>
  33. namespace boost {
  34. namespace mysql {
  35. namespace detail {
  36. struct variant_stream_state
  37. {
  38. asio::generic::stream_protocol::socket sock;
  39. ssl_context_with_default ssl_ctx;
  40. boost::optional<asio::ssl::stream<asio::generic::stream_protocol::socket&>> ssl;
  41. variant_stream_state(asio::any_io_executor ex, asio::ssl::context* ctx) : sock(ex), ssl_ctx(ctx) {}
  42. asio::ssl::stream<asio::generic::stream_protocol::socket&>& create_ssl_stream()
  43. {
  44. // The stream object must be re-created even if it already exists, since
  45. // once used for a connection (anytime after ssl::stream::handshake is called),
  46. // it can't be re-used for any subsequent connections
  47. ssl.emplace(sock, ssl_ctx.get());
  48. return *ssl;
  49. }
  50. };
  51. enum class vsconnect_action_type
  52. {
  53. none,
  54. resolve,
  55. connect,
  56. immediate, // we'll be performing an immediate completion
  57. };
  58. struct vsconnect_action
  59. {
  60. vsconnect_action_type type;
  61. union data_t
  62. {
  63. error_code err;
  64. struct resolve_t
  65. {
  66. const std::string* hostname;
  67. const std::string* service;
  68. } resolve;
  69. span<const asio::generic::stream_protocol::endpoint> connect;
  70. data_t(error_code v) noexcept : err(v) {}
  71. data_t(resolve_t v) noexcept : resolve(v) {}
  72. data_t(span<const asio::generic::stream_protocol::endpoint> v) noexcept : connect(v) {}
  73. } data;
  74. struct immediate_tag
  75. {
  76. };
  77. vsconnect_action(immediate_tag) noexcept : type(vsconnect_action_type::immediate), data(error_code()) {}
  78. vsconnect_action(error_code v = {}) noexcept : type(vsconnect_action_type::none), data(v) {}
  79. vsconnect_action(data_t::resolve_t v) noexcept : type(vsconnect_action_type::resolve), data(v) {}
  80. vsconnect_action(span<const asio::generic::stream_protocol::endpoint> v) noexcept
  81. : type(vsconnect_action_type::connect), data(v)
  82. {
  83. }
  84. };
  85. class variant_stream_connect_algo
  86. {
  87. variant_stream_state* st_;
  88. const any_address* addr_;
  89. boost::optional<asio::ip::tcp::resolver> resolv_;
  90. std::vector<asio::generic::stream_protocol::endpoint> endpoints_;
  91. std::string service_;
  92. int resume_point_{0};
  93. const std::string& address() const { return access::get_impl(*addr_).address; }
  94. asio::any_io_executor get_executor() const { return st_->sock.get_executor(); }
  95. public:
  96. variant_stream_connect_algo(variant_stream_state& st, const any_address& addr) : st_(&st), addr_(&addr) {}
  97. asio::ip::tcp::resolver& resolver() { return *resolv_; }
  98. asio::generic::stream_protocol::socket& socket() { return st_->sock; }
  99. vsconnect_action resume(
  100. error_code ec,
  101. const asio::ip::tcp::resolver::results_type* resolver_results,
  102. asio::cancellation_type_t cancel_state
  103. )
  104. {
  105. // All errors are considered fatal
  106. if (ec)
  107. return ec;
  108. // If we received a terminal cancellation signal, exit with the appropriate error code.
  109. // In composed async operations, if the cancellation arrives after an intermediate operation
  110. // has completed, but before the handler is called, the operation finishes successfully,
  111. // but the cancellation state is set. This check covers this case.
  112. if (!!(cancel_state & asio::cancellation_type_t::terminal))
  113. return error_code(asio::error::operation_aborted);
  114. switch (resume_point_)
  115. {
  116. case 0:
  117. // Clean up any previous state
  118. st_->sock = asio::generic::stream_protocol::socket(get_executor());
  119. // Set up the endpoints vector
  120. if (addr_->type() == address_type::host_and_port)
  121. {
  122. // Emplace the resolver
  123. resolv_.emplace(get_executor());
  124. // Resolve the endpoints
  125. service_ = std::to_string(addr_->port());
  126. BOOST_MYSQL_YIELD(resume_point_, 1, vsconnect_action({&address(), &service_}));
  127. // Convert them to a vector of type-erased endpoints.
  128. // This workarounds https://github.com/chriskohlhoff/asio/issues/1502
  129. // and makes connect() uniform for TCP and UNIX
  130. endpoints_.reserve(resolver_results->size());
  131. for (const auto& entry : *resolver_results)
  132. {
  133. endpoints_.push_back(entry.endpoint());
  134. }
  135. }
  136. else
  137. {
  138. BOOST_ASSERT(addr_->type() == address_type::unix_path);
  139. #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
  140. endpoints_.push_back(asio::local::stream_protocol::endpoint(address()));
  141. #else
  142. BOOST_MYSQL_YIELD(resume_point_, 3, vsconnect_action::immediate_tag{});
  143. return vsconnect_action(asio::error::operation_not_supported);
  144. #endif
  145. }
  146. // Actually connect
  147. BOOST_MYSQL_YIELD(resume_point_, 2, vsconnect_action{endpoints_});
  148. // If we're doing TCP, disable Naggle's algorithm
  149. if (addr_->type() == address_type::host_and_port)
  150. {
  151. st_->sock.set_option(asio::ip::tcp::no_delay(true));
  152. }
  153. // Done
  154. }
  155. return {};
  156. }
  157. };
  158. // Implements the EngineStream concept (see stream_adaptor)
  159. class variant_stream
  160. {
  161. public:
  162. variant_stream(asio::any_io_executor ex, asio::ssl::context* ctx) : st_(std::move(ex), ctx) {}
  163. bool supports_ssl() const { return true; }
  164. // Executor
  165. using executor_type = asio::any_io_executor;
  166. executor_type get_executor() { return st_.sock.get_executor(); }
  167. // SSL
  168. void ssl_handshake(error_code& ec)
  169. {
  170. st_.create_ssl_stream().handshake(asio::ssl::stream_base::client, ec);
  171. }
  172. template <class CompletionToken>
  173. void async_ssl_handshake(CompletionToken&& token)
  174. {
  175. st_.create_ssl_stream();
  176. st_.ssl->async_handshake(asio::ssl::stream_base::client, std::forward<CompletionToken>(token));
  177. }
  178. void ssl_shutdown(error_code& ec)
  179. {
  180. BOOST_ASSERT(st_.ssl.has_value());
  181. st_.ssl->shutdown(ec);
  182. }
  183. template <class CompletionToken>
  184. void async_ssl_shutdown(CompletionToken&& token)
  185. {
  186. BOOST_ASSERT(st_.ssl.has_value());
  187. st_.ssl->async_shutdown(std::forward<CompletionToken>(token));
  188. }
  189. // Reading
  190. std::size_t read_some(asio::mutable_buffer buff, bool use_ssl, error_code& ec)
  191. {
  192. if (use_ssl)
  193. {
  194. BOOST_ASSERT(st_.ssl.has_value());
  195. return st_.ssl->read_some(buff, ec);
  196. }
  197. else
  198. {
  199. return st_.sock.read_some(buff, ec);
  200. }
  201. }
  202. template <class CompletionToken>
  203. void async_read_some(asio::mutable_buffer buff, bool use_ssl, CompletionToken&& token)
  204. {
  205. if (use_ssl)
  206. {
  207. BOOST_ASSERT(st_.ssl.has_value());
  208. st_.ssl->async_read_some(buff, std::forward<CompletionToken>(token));
  209. }
  210. else
  211. {
  212. st_.sock.async_read_some(buff, std::forward<CompletionToken>(token));
  213. }
  214. }
  215. // Writing
  216. std::size_t write_some(boost::asio::const_buffer buff, bool use_ssl, error_code& ec)
  217. {
  218. if (use_ssl)
  219. {
  220. BOOST_ASSERT(st_.ssl.has_value());
  221. return st_.ssl->write_some(buff, ec);
  222. }
  223. else
  224. {
  225. return st_.sock.write_some(buff, ec);
  226. }
  227. }
  228. template <class CompletionToken>
  229. void async_write_some(boost::asio::const_buffer buff, bool use_ssl, CompletionToken&& token)
  230. {
  231. if (use_ssl)
  232. {
  233. BOOST_ASSERT(st_.ssl.has_value());
  234. return st_.ssl->async_write_some(buff, std::forward<CompletionToken>(token));
  235. }
  236. else
  237. {
  238. return st_.sock.async_write_some(buff, std::forward<CompletionToken>(token));
  239. }
  240. }
  241. // Connect and close
  242. void connect(const void* server_address, error_code& output_ec)
  243. {
  244. // Setup
  245. variant_stream_connect_algo algo(st_, *static_cast<const any_address*>(server_address));
  246. error_code ec;
  247. asio::ip::tcp::resolver::results_type resolver_results;
  248. // Run until complete
  249. while (true)
  250. {
  251. // The sync algorithm doesn't support cancellation
  252. auto act = algo.resume(ec, &resolver_results, asio::cancellation_type_t::none);
  253. switch (act.type)
  254. {
  255. case vsconnect_action_type::connect: asio::connect(st_.sock, act.data.connect, ec); break;
  256. case vsconnect_action_type::resolve:
  257. resolver_results = algo.resolver()
  258. .resolve(*act.data.resolve.hostname, *act.data.resolve.service, ec);
  259. break;
  260. case vsconnect_action_type::immediate: break; // has effect only for async
  261. case vsconnect_action_type::none: output_ec = act.data.err; return;
  262. default: BOOST_ASSERT(false); // LCOV_EXCL_LINE
  263. }
  264. }
  265. }
  266. template <class CompletionToken>
  267. void async_connect(const void* server_address, CompletionToken&& token)
  268. {
  269. asio::async_compose<CompletionToken, void(error_code)>(
  270. connect_op(*this, *static_cast<const any_address*>(server_address)),
  271. token,
  272. get_executor()
  273. );
  274. }
  275. void close(error_code& ec)
  276. {
  277. st_.sock.shutdown(asio::generic::stream_protocol::socket::shutdown_both, ec);
  278. st_.sock.close(ec);
  279. }
  280. // Exposed for testing
  281. const asio::generic::stream_protocol::socket& socket() const { return st_.sock; }
  282. private:
  283. variant_stream_state st_;
  284. struct connect_op
  285. {
  286. std::unique_ptr<variant_stream_connect_algo> algo_;
  287. connect_op(variant_stream& this_obj, const any_address& server_address)
  288. : algo_(new variant_stream_connect_algo(this_obj.st_, server_address))
  289. {
  290. }
  291. template <class Self>
  292. void operator()(
  293. Self& self,
  294. error_code ec = {},
  295. const asio::ip::tcp::resolver::results_type& resolver_results = {}
  296. )
  297. {
  298. auto act = algo_->resume(ec, &resolver_results, self.cancelled());
  299. switch (act.type)
  300. {
  301. case vsconnect_action_type::connect:
  302. asio::async_connect(algo_->socket(), act.data.connect, std::move(self));
  303. break;
  304. case vsconnect_action_type::resolve:
  305. algo_->resolver()
  306. .async_resolve(*act.data.resolve.hostname, *act.data.resolve.service, std::move(self));
  307. break;
  308. case vsconnect_action_type::immediate:
  309. asio::dispatch(
  310. asio::get_associated_immediate_executor(self, self.get_io_executor()),
  311. std::move(self)
  312. );
  313. break;
  314. case vsconnect_action_type::none:
  315. algo_.reset();
  316. self.complete(act.data.err);
  317. break;
  318. default: BOOST_ASSERT(false); // LCOV_EXCL_LINE
  319. }
  320. }
  321. // Signature for range connect
  322. template <class Self>
  323. void operator()(Self& self, error_code ec, asio::generic::stream_protocol::endpoint)
  324. {
  325. (*this)(self, ec, asio::ip::tcp::resolver::results_type{});
  326. }
  327. };
  328. };
  329. } // namespace detail
  330. } // namespace mysql
  331. } // namespace boost
  332. #endif