handshake.hpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. //
  2. // Copyright (c) 2019-2025 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See accompanying
  5. // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. //
  7. #ifndef BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
  8. #define BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
  9. #include <boost/mysql/character_set.hpp>
  10. #include <boost/mysql/client_errc.hpp>
  11. #include <boost/mysql/diagnostics.hpp>
  12. #include <boost/mysql/error_code.hpp>
  13. #include <boost/mysql/handshake_params.hpp>
  14. #include <boost/mysql/mysql_collations.hpp>
  15. #include <boost/mysql/string_view.hpp>
  16. #include <boost/mysql/detail/algo_params.hpp>
  17. #include <boost/mysql/detail/next_action.hpp>
  18. #include <boost/mysql/detail/ok_view.hpp>
  19. #include <boost/mysql/impl/internal/coroutine.hpp>
  20. #include <boost/mysql/impl/internal/protocol/capabilities.hpp>
  21. #include <boost/mysql/impl/internal/protocol/db_flavor.hpp>
  22. #include <boost/mysql/impl/internal/protocol/deserialization.hpp>
  23. #include <boost/mysql/impl/internal/protocol/serialization.hpp>
  24. #include <boost/mysql/impl/internal/protocol/static_buffer.hpp>
  25. #include <boost/mysql/impl/internal/sansio/auth_plugin_common.hpp>
  26. #include <boost/mysql/impl/internal/sansio/caching_sha2_password.hpp>
  27. #include <boost/mysql/impl/internal/sansio/connection_state_data.hpp>
  28. #include <boost/mysql/impl/internal/sansio/mysql_native_password.hpp>
  29. #include <boost/core/span.hpp>
  30. #include <boost/system/result.hpp>
  31. #include <boost/variant2/variant.hpp>
  32. #include <array>
  33. #include <cstdint>
  34. #include <cstring>
  35. namespace boost {
  36. namespace mysql {
  37. namespace detail {
  38. // Stores which authentication plugin we're using, plus any required state. Variant-like
  39. class any_authentication_plugin
  40. {
  41. enum class type_t
  42. {
  43. mnp,
  44. csha2p
  45. };
  46. // Which authentication plugin are we using?
  47. type_t type_{type_t::mnp};
  48. // State for algorithms that require stateful exchanges.
  49. // mysql_native_password is stateless, so only caching_sha2_password has an entry here
  50. csha2p_algo csha2p_;
  51. public:
  52. any_authentication_plugin() = default;
  53. // Emplaces a plugin of the type given by plugin_name. Errors on unknown plugin
  54. error_code emplace_by_name(string_view plugin_name)
  55. {
  56. if (plugin_name == mnp_plugin_name)
  57. {
  58. type_ = type_t::mnp;
  59. return error_code();
  60. }
  61. else if (plugin_name == csha2p_plugin_name)
  62. {
  63. type_ = type_t::csha2p;
  64. csha2p_ = csha2p_algo(); // Reset any leftover state, just in case
  65. return error_code();
  66. }
  67. else
  68. {
  69. return client_errc::unknown_auth_plugin;
  70. }
  71. }
  72. // Hashes the password with the selected plugin
  73. static_buffer<max_hash_size> hash_password(
  74. string_view password,
  75. span<const std::uint8_t, scramble_size> scramble
  76. ) const
  77. {
  78. switch (type_)
  79. {
  80. case type_t::mnp: return mnp_hash_password(password, scramble);
  81. case type_t::csha2p: return csha2p_hash_password(password, scramble);
  82. default: BOOST_ASSERT(false); return {}; // LCOV_EXCL_LINE
  83. }
  84. }
  85. // Invokes the plugin action. Use when a more_data packet is received.
  86. next_action resume(
  87. connection_state_data& st,
  88. span<const std::uint8_t> server_data,
  89. string_view password,
  90. span<const std::uint8_t, scramble_size> scramble,
  91. bool secure_channel,
  92. std::uint8_t& seqnum
  93. )
  94. {
  95. switch (type_)
  96. {
  97. case type_t::mnp:
  98. // This algorithm doesn't allow more data frames
  99. return error_code(client_errc::bad_handshake_packet_type);
  100. case type_t::csha2p:
  101. return csha2p_.resume(st, server_data, password, scramble, secure_channel, seqnum);
  102. default:
  103. BOOST_ASSERT(false);
  104. return next_action(client_errc::bad_handshake_packet_type); // LCOV_EXCL_LINE
  105. }
  106. }
  107. string_view name() const
  108. {
  109. switch (type_)
  110. {
  111. case type_t::mnp: return mnp_plugin_name;
  112. case type_t::csha2p: return csha2p_plugin_name;
  113. default: BOOST_ASSERT(false); return {}; // LCOV_EXCL_LINE
  114. }
  115. }
  116. };
  117. class handshake_algo
  118. {
  119. int resume_point_{0};
  120. handshake_params hparams_;
  121. any_authentication_plugin plugin_;
  122. std::array<std::uint8_t, scramble_size> scramble_;
  123. std::uint8_t sequence_number_{0};
  124. bool secure_channel_{false};
  125. static capabilities conditional_capability(bool condition, capabilities cap)
  126. {
  127. return condition ? cap : capabilities{};
  128. }
  129. // Given our params and the capabilities that the server sent us,
  130. // performs capability negotiation, returning either the capabilities to
  131. // send to the server or an error
  132. static system::result<capabilities> negotiate_capabilities(
  133. const handshake_params& params,
  134. capabilities server_caps,
  135. bool transport_supports_ssl
  136. )
  137. {
  138. // The capabilities that we absolutely require. These are always set except in extremely old servers
  139. constexpr capabilities mandatory_capabilities =
  140. // We don't speak the older protocol
  141. capabilities::protocol_41 |
  142. // We only know how to deserialize the hello frame if this is set
  143. capabilities::plugin_auth |
  144. // Same as above
  145. capabilities::plugin_auth_lenenc_data |
  146. // This makes processing execute responses easier
  147. capabilities::deprecate_eof |
  148. // Used in MariaDB to signal 4.1 protocol. Always set in MySQL, too
  149. capabilities::secure_connection;
  150. // The capabilities that we support but don't require
  151. constexpr capabilities optional_capabilities = capabilities::multi_results |
  152. capabilities::ps_multi_results;
  153. auto ssl = transport_supports_ssl ? params.ssl() : ssl_mode::disable;
  154. capabilities required_caps = mandatory_capabilities |
  155. conditional_capability(
  156. !params.database().empty(),
  157. capabilities::connect_with_db
  158. ) |
  159. conditional_capability(
  160. params.multi_queries(),
  161. capabilities::multi_statements
  162. ) |
  163. conditional_capability(ssl == ssl_mode::require, capabilities::ssl);
  164. if (has_capabilities(required_caps, capabilities::ssl) &&
  165. !has_capabilities(server_caps, capabilities::ssl))
  166. {
  167. // This happens if the server doesn't have SSL configured. This special
  168. // error code helps users diagnosing their problem a lot (server_unsupported doesn't).
  169. return make_error_code(client_errc::server_doesnt_support_ssl);
  170. }
  171. else if (!has_capabilities(server_caps, required_caps))
  172. {
  173. return make_error_code(client_errc::server_unsupported);
  174. }
  175. return server_caps & (required_caps | optional_capabilities |
  176. conditional_capability(ssl == ssl_mode::enable, capabilities::ssl));
  177. }
  178. // Attempts to map the collection_id to a character set. We try to be conservative
  179. // here, since servers will happily accept unknown collation IDs, silently defaulting
  180. // to the server's default character set (often latin1, which is not Unicode).
  181. static character_set collation_id_to_charset(std::uint16_t collation_id)
  182. {
  183. switch (collation_id)
  184. {
  185. case mysql_collations::utf8mb4_bin:
  186. case mysql_collations::utf8mb4_general_ci: return utf8mb4_charset;
  187. case mysql_collations::ascii_general_ci:
  188. case mysql_collations::ascii_bin: return ascii_charset;
  189. default: return character_set{};
  190. }
  191. }
  192. // Saves the scramble, checking that it has the right size
  193. error_code save_scramble(span<const std::uint8_t> value)
  194. {
  195. // All scrambles must have exactly this size. Otherwise, it's a protocol violation error
  196. if (value.size() != scramble_size)
  197. return client_errc::protocol_value_error;
  198. // Store the scramble
  199. std::memcpy(scramble_.data(), value.data(), scramble_size);
  200. // Done
  201. return error_code();
  202. }
  203. error_code process_hello(connection_state_data& st, diagnostics& diag, span<const std::uint8_t> buffer)
  204. {
  205. // Deserialize server hello
  206. server_hello hello{};
  207. auto err = deserialize_server_hello(buffer, hello, diag);
  208. if (err)
  209. return err;
  210. // Check capabilities
  211. auto negotiated_caps = negotiate_capabilities(hparams_, hello.server_capabilities, st.tls_supported);
  212. if (negotiated_caps.has_error())
  213. return negotiated_caps.error();
  214. // Set capabilities, db flavor and connection ID
  215. st.current_capabilities = *negotiated_caps;
  216. st.flavor = hello.server;
  217. st.connection_id = hello.connection_id;
  218. // If we're using SSL, mark the channel as secure
  219. secure_channel_ = secure_channel_ || has_capabilities(*negotiated_caps, capabilities::ssl);
  220. // Save which authentication plugin we're using. Do this before saving the scramble,
  221. // as an unknown plugin might have a scramble size different to what we know
  222. err = plugin_.emplace_by_name(hello.auth_plugin_name);
  223. if (err)
  224. return err;
  225. // Save the scramble for later
  226. return save_scramble(hello.auth_plugin_data);
  227. }
  228. // Response to that initial greeting
  229. ssl_request compose_ssl_request(const connection_state_data& st)
  230. {
  231. return ssl_request{
  232. st.current_capabilities,
  233. static_cast<std::uint32_t>(max_packet_size),
  234. hparams_.connection_collation(),
  235. };
  236. }
  237. next_action serialize_login_request(connection_state_data& st)
  238. {
  239. auto hashed_password = plugin_.hash_password(hparams_.password(), scramble_);
  240. return st.write(
  241. login_request{
  242. st.current_capabilities,
  243. static_cast<std::uint32_t>(max_packet_size),
  244. hparams_.connection_collation(),
  245. hparams_.username(),
  246. hashed_password,
  247. hparams_.database(),
  248. plugin_.name(),
  249. },
  250. sequence_number_
  251. );
  252. }
  253. // Processes auth_switch and auth_more_data messages, and leaves the result in auth_resp_
  254. next_action process_auth_switch(connection_state_data& st, auth_switch msg)
  255. {
  256. // Emplace the new authentication plugin
  257. auto ec = plugin_.emplace_by_name(msg.plugin_name);
  258. if (ec)
  259. return ec;
  260. // Store the scramble for later (required by caching_sha2_password, for instance)
  261. ec = save_scramble(msg.auth_data);
  262. if (ec)
  263. return ec;
  264. // Hash the password
  265. auto hashed_password = plugin_.hash_password(hparams_.password(), scramble_);
  266. // Serialize the response
  267. return st.write(auth_switch_response{hashed_password}, sequence_number_);
  268. }
  269. void on_success(connection_state_data& st, const ok_view& ok)
  270. {
  271. st.status = connection_status::ready;
  272. st.backslash_escapes = ok.backslash_escapes();
  273. st.current_charset = collation_id_to_charset(hparams_.connection_collation());
  274. }
  275. next_action resume_impl(connection_state_data& st, diagnostics& diag, error_code ec)
  276. {
  277. if (ec)
  278. return ec;
  279. handshake_server_response resp(error_code{});
  280. next_action act;
  281. switch (resume_point_)
  282. {
  283. case 0:
  284. // Handshake wipes out state, so no state checks are performed.
  285. // Setup
  286. st.reset();
  287. // Read server greeting
  288. BOOST_MYSQL_YIELD(resume_point_, 1, st.read(sequence_number_))
  289. // Process server greeting
  290. ec = process_hello(st, diag, st.reader.message());
  291. if (ec)
  292. return ec;
  293. // SSL
  294. if (has_capabilities(st.current_capabilities, capabilities::ssl))
  295. {
  296. // Send SSL request
  297. BOOST_MYSQL_YIELD(resume_point_, 2, st.write(compose_ssl_request(st), sequence_number_))
  298. // SSL handshake
  299. BOOST_MYSQL_YIELD(resume_point_, 3, next_action::ssl_handshake())
  300. // Mark the connection as using ssl
  301. st.tls_active = true;
  302. }
  303. // Compose and send handshake response
  304. BOOST_MYSQL_YIELD(resume_point_, 4, serialize_login_request(st))
  305. // Receive the response
  306. BOOST_MYSQL_YIELD(resume_point_, 5, st.read(sequence_number_))
  307. // Process it
  308. resp = deserialize_handshake_server_response(st.reader.message(), st.flavor, diag);
  309. // Auth switches are only legal at this point. Handle the case here
  310. if (resp.type == handshake_server_response::type_t::auth_switch)
  311. {
  312. // Write our packet
  313. BOOST_MYSQL_YIELD(resume_point_, 6, process_auth_switch(st, resp.data.auth_sw))
  314. // Read another packet
  315. BOOST_MYSQL_YIELD(resume_point_, 7, st.read(sequence_number_))
  316. // Deserialize it
  317. resp = deserialize_handshake_server_response(st.reader.message(), st.flavor, diag);
  318. }
  319. // Now we will send/receive raw data packets from the server until an OK or error happens.
  320. // Packets requiring responses are auth_more_data packets
  321. while (resp.type == handshake_server_response::type_t::auth_more_data)
  322. {
  323. // Invoke the authentication plugin algorithm
  324. act = plugin_.resume(
  325. st,
  326. resp.data.more_data,
  327. hparams_.password(),
  328. scramble_,
  329. secure_channel_,
  330. sequence_number_
  331. );
  332. // Do what the plugin says
  333. if (act.type() == next_action_type::none)
  334. {
  335. // The plugin signalled an error. Exit
  336. BOOST_ASSERT(act.error());
  337. return act;
  338. }
  339. else if (act.type() == next_action_type::write)
  340. {
  341. // The plugin wants us to first write the message in the write buffer, then read
  342. BOOST_MYSQL_YIELD(resume_point_, 8, act)
  343. BOOST_MYSQL_YIELD(resume_point_, 9, st.read(sequence_number_))
  344. }
  345. else
  346. {
  347. // The plugin wants us to read another packet
  348. BOOST_ASSERT(act.type() == next_action_type::read);
  349. BOOST_MYSQL_YIELD(resume_point_, 10, act)
  350. }
  351. // If we got here, we've successfully read a packet. Deserialize it
  352. resp = deserialize_handshake_server_response(st.reader.message(), st.flavor, diag);
  353. }
  354. // If we got here, we've received a packet that terminates the algorithm
  355. if (resp.type == handshake_server_response::type_t::ok)
  356. {
  357. // Auth success, quit
  358. on_success(st, resp.data.ok);
  359. return next_action();
  360. }
  361. else if (resp.type == handshake_server_response::type_t::error)
  362. {
  363. // Error, quit
  364. return resp.data.err;
  365. }
  366. else
  367. {
  368. // Auth switches are no longer allowed at this point
  369. BOOST_ASSERT(resp.type == handshake_server_response::type_t::auth_switch);
  370. return error_code(client_errc::bad_handshake_packet_type);
  371. }
  372. }
  373. // We should never get here
  374. BOOST_ASSERT(false);
  375. return next_action(); // LCOV_EXCL_LINE
  376. }
  377. public:
  378. handshake_algo(handshake_algo_params params) noexcept
  379. : hparams_(params.hparams), secure_channel_(params.secure_channel)
  380. {
  381. }
  382. next_action resume(connection_state_data& st, diagnostics& diag, error_code ec)
  383. {
  384. // On error, reset the connection's state to well-known values
  385. auto act = resume_impl(st, diag, ec);
  386. if (act.is_done() && act.error())
  387. st.reset();
  388. return act;
  389. }
  390. };
  391. } // namespace detail
  392. } // namespace mysql
  393. } // namespace boost
  394. #endif