race.hpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. //
  2. // Copyright (c) 2022 Klemens Morgenstern (klemens.morgenstern@gmx.net)
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See accompanying
  5. // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. //
  7. #ifndef BOOST_COBALT_DETAIL_RACE_HPP
  8. #define BOOST_COBALT_DETAIL_RACE_HPP
  9. #include <boost/cobalt/detail/await_result_helper.hpp>
  10. #include <boost/cobalt/detail/fork.hpp>
  11. #include <boost/cobalt/detail/handler.hpp>
  12. #include <boost/cobalt/detail/forward_cancellation.hpp>
  13. #include <boost/cobalt/result.hpp>
  14. #include <boost/cobalt/this_thread.hpp>
  15. #include <boost/cobalt/detail/util.hpp>
  16. #include <boost/asio/bind_allocator.hpp>
  17. #include <boost/asio/bind_cancellation_slot.hpp>
  18. #include <boost/asio/bind_executor.hpp>
  19. #include <boost/asio/cancellation_signal.hpp>
  20. #include <boost/asio/associated_cancellation_slot.hpp>
  21. #include <boost/core/no_exceptions_support.hpp>
  22. #include <boost/intrusive_ptr.hpp>
  23. #include <boost/core/demangle.hpp>
  24. #include <boost/core/span.hpp>
  25. #include <boost/variant2/variant.hpp>
  26. #include <coroutine>
  27. #include <optional>
  28. #include <algorithm>
  29. namespace boost::cobalt::detail
  30. {
  31. struct left_race_tag {};
// helpers for determining the type of things
  33. template<typename Base, // range of aw
  34. typename Awaitable = Base>
  35. struct race_traits
  36. {
  37. // for a ranges race this is based on the range, not the AW in it.
  38. constexpr static bool is_lvalue = std::is_lvalue_reference_v<Base>;
  39. // what the value is supposed to be cast to before the co_await_operator
  40. using awaitable = std::conditional_t<is_lvalue, std::decay_t<Awaitable> &, Awaitable &&>;
  41. // do we need operator co_await
  42. constexpr static bool is_actual = awaitable_type<awaitable>;
  43. // the type with .await_ functions & interrupt_await
  44. using actual_awaitable
  45. = std::conditional_t<
  46. is_actual,
  47. awaitable,
  48. decltype(get_awaitable_type(std::declval<awaitable>()))>;
  49. // the type to be used with interruptible
  50. using interruptible_type
  51. = std::conditional_t<
  52. std::is_lvalue_reference_v<Base>,
  53. std::decay_t<actual_awaitable> &,
  54. std::decay_t<actual_awaitable> &&>;
  55. constexpr static bool interruptible =
  56. cobalt::interruptible<interruptible_type>;
  57. static void do_interrupt(std::decay_t<actual_awaitable> & aw)
  58. {
  59. if constexpr (interruptible)
  60. static_cast<interruptible_type>(aw).interrupt_await();
  61. }
  62. };
  63. struct interruptible_base
  64. {
  65. virtual void interrupt_await() = 0;
  66. };
  67. template<asio::cancellation_type Ct, typename URBG, typename ... Args>
  68. struct race_variadic_impl
  69. {
  70. template<typename URBG_>
  71. BOOST_COBALT_MSVC_NOINLINE
  72. race_variadic_impl(URBG_ && g, Args && ... args)
  73. : args{std::forward<Args>(args)...}, g(std::forward<URBG_>(g))
  74. {
  75. }
  76. std::tuple<Args...> args;
  77. URBG g;
  78. constexpr static std::size_t tuple_size = sizeof...(Args);
  79. struct awaitable : fork::static_shared_state<256 * tuple_size>
  80. {
  81. boost::source_location loc;
  82. template<std::size_t ... Idx>
  83. awaitable(std::tuple<Args...> & args, URBG & g, std::index_sequence<Idx...>) :
  84. aws{args}
  85. {
  86. if constexpr (!std::is_same_v<URBG, left_race_tag>)
  87. std::shuffle(impls.begin(), impls.end(), g);
  88. std::fill(working.begin(), working.end(), nullptr);
  89. }
  90. std::tuple<Args...> & aws;
  91. std::array<asio::cancellation_signal, tuple_size> cancel_;
  92. template<typename > constexpr static auto make_null() {return nullptr;};
  93. std::array<asio::cancellation_signal*, tuple_size> cancel = {make_null<Args>()...};
  94. std::array<interruptible_base*, tuple_size> working;
  95. std::size_t index{std::numeric_limits<std::size_t>::max()};
  96. constexpr static bool all_void = (std::is_void_v<co_await_result_t<Args>> && ... );
  97. std::optional<variant2::variant<void_as_monostate<co_await_result_t<Args>>...>> result;
  98. std::exception_ptr error;
  99. bool has_result() const
  100. {
  101. return index != std::numeric_limits<std::size_t>::max();
  102. }
  103. void cancel_all()
  104. {
  105. interrupt_await();
  106. for (auto i = 0u; i < tuple_size; i++)
  107. if (auto &r = cancel[i]; r)
  108. std::exchange(r, nullptr)->emit(Ct);
  109. }
  110. void interrupt_await()
  111. {
  112. for (auto i : working)
  113. if (i)
  114. i->interrupt_await();
  115. }
  116. template<typename T, typename Error>
  117. void assign_error(system::result<T, Error> & res)
  118. BOOST_TRY
  119. {
  120. std::move(res).value(loc);
  121. }
  122. BOOST_CATCH(...)
  123. {
  124. error = std::current_exception();
  125. }
  126. BOOST_CATCH_END
  127. template<typename T>
  128. void assign_error(system::result<T, std::exception_ptr> & res)
  129. {
  130. error = std::move(res).error();
  131. }
  132. template<std::size_t Idx>
  133. static detail::fork await_impl(awaitable & this_)
  134. BOOST_TRY
  135. {
  136. using traits = race_traits<mp11::mp_at_c<mp11::mp_list<Args...>, Idx>>;
  137. typename traits::actual_awaitable aw_{
  138. get_awaitable_type(
  139. static_cast<typename traits::awaitable>(std::get<Idx>(this_.aws))
  140. )
  141. };
  142. as_result_t aw{aw_};
  143. struct interruptor final : interruptible_base
  144. {
  145. std::decay_t<typename traits::actual_awaitable> & aw;
  146. interruptor(std::decay_t<typename traits::actual_awaitable> & aw) : aw(aw) {}
  147. void interrupt_await() override
  148. {
  149. traits::do_interrupt(aw);
  150. }
  151. };
  152. interruptor in{aw_};
  153. //if constexpr (traits::interruptible)
  154. this_.working[Idx] = &in;
  155. auto transaction = [&this_, idx = Idx] {
  156. if (this_.has_result())
  157. boost::throw_exception(std::runtime_error("Another transaction already started"));
  158. this_.cancel[idx] = nullptr;
  159. // reserve the index early bc
  160. this_.index = idx;
  161. this_.cancel_all();
  162. };
  163. co_await fork::set_transaction_function(transaction);
  164. // check manually if we're ready
  165. auto rd = aw.await_ready();
  166. if (!rd)
  167. {
  168. this_.cancel[Idx] = &this_.cancel_[Idx];
  169. co_await this_.cancel[Idx]->slot();
  170. // make sure the executor is set
  171. co_await detail::fork::wired_up;
  172. // do the await - this doesn't call await-ready again
  173. if constexpr (std::is_void_v<decltype(aw_.await_resume())>)
  174. {
  175. auto res = co_await aw;
  176. if (!this_.has_result())
  177. {
  178. this_.index = Idx;
  179. if (res.has_error())
  180. this_.assign_error(res);
  181. }
  182. if constexpr(!all_void)
  183. if (this_.index == Idx && !res.has_error())
  184. this_.result.emplace(variant2::in_place_index<Idx>);
  185. }
  186. else
  187. {
  188. auto val = co_await aw;
  189. if (!this_.has_result())
  190. this_.index = Idx;
  191. if (this_.index == Idx)
  192. {
  193. if (val.has_error())
  194. this_.assign_error(val);
  195. else
  196. this_.result.emplace(variant2::in_place_index<Idx>, *std::move(val));
  197. }
  198. }
  199. this_.cancel[Idx] = nullptr;
  200. }
  201. else
  202. {
  203. if (!this_.has_result())
  204. this_.index = Idx;
  205. if constexpr (std::is_void_v<decltype(aw_.await_resume())>)
  206. {
  207. auto res = aw.await_resume();
  208. if (this_.index == Idx)
  209. {
  210. if (res.has_error())
  211. this_.assign_error(res);
  212. else
  213. this_.result.emplace(variant2::in_place_index<Idx>);
  214. }
  215. }
  216. else
  217. {
  218. if (this_.index == Idx)
  219. {
  220. auto res = aw.await_resume();
  221. if (res.has_error())
  222. this_.assign_error(res);
  223. else
  224. this_.result.emplace(variant2::in_place_index<Idx>, *std::move(res));
  225. }
  226. else
  227. aw.await_resume();
  228. }
  229. this_.cancel[Idx] = nullptr;
  230. }
  231. this_.cancel_all();
  232. this_.working[Idx] = nullptr;
  233. }
  234. BOOST_CATCH(...)
  235. {
  236. if (!this_.has_result())
  237. this_.index = Idx;
  238. if (this_.index == Idx)
  239. this_.error = std::current_exception();
  240. this_.working[Idx] = nullptr;
  241. }
  242. BOOST_CATCH_END
  243. std::array<detail::fork(*)(awaitable&), tuple_size> impls {
  244. []<std::size_t ... Idx>(std::index_sequence<Idx...>)
  245. {
  246. return std::array<detail::fork(*)(awaitable&), tuple_size>{&await_impl<Idx>...};
  247. }(std::make_index_sequence<tuple_size>{})
  248. };
  249. detail::fork last_forked;
  250. bool await_ready()
  251. {
  252. last_forked = impls[0](*this);
  253. return last_forked.done();
  254. }
  255. template<typename H>
  256. auto await_suspend(
  257. std::coroutine_handle<H> h,
  258. const boost::source_location & loc = BOOST_CURRENT_LOCATION)
  259. {
  260. this->loc = loc;
  261. this->exec = cobalt::detail::get_executor(h);
  262. last_forked.release().resume();
  263. if (!this->outstanding_work()) // already done, resume rightaway.
  264. return false;
  265. for (std::size_t idx = 1u;
  266. idx < tuple_size; idx++) // we'
  267. {
  268. auto l = impls[idx](*this);
  269. const auto d = l.done();
  270. l.release();
  271. if (d)
  272. break;
  273. }
  274. if (!this->outstanding_work()) // already done, resume rightaway.
  275. return false;
  276. // arm the cancel
  277. assign_cancellation(
  278. h,
  279. [&](asio::cancellation_type ct)
  280. {
  281. for (auto & cs : cancel)
  282. if (cs)
  283. cs->emit(ct);
  284. });
  285. this->coro.reset(h.address());
  286. return true;
  287. }
  288. BOOST_COBALT_MSVC_NOINLINE
  289. auto await_resume()
  290. {
  291. if (error)
  292. std::rethrow_exception(error);
  293. if constexpr (all_void)
  294. return index;
  295. else
  296. return std::move(*result);
  297. }
  298. auto await_resume(const as_tuple_tag &)
  299. {
  300. if constexpr (all_void)
  301. return std::make_tuple(error, index);
  302. else
  303. return std::make_tuple(error, std::move(*result));
  304. }
  305. auto await_resume(const as_result_tag & )
  306. -> system::result<std::conditional_t<all_void, std::size_t, variant2::variant<void_as_monostate<co_await_result_t<Args>>...>>, std::exception_ptr>
  307. {
  308. if (error)
  309. return {system::in_place_error, error};
  310. if constexpr (all_void)
  311. return {system::in_place_value, index};
  312. else
  313. return {system::in_place_value, std::move(*result)};
  314. }
  315. };
  316. awaitable operator co_await() &&
  317. {
  318. return awaitable{args, g, std::make_index_sequence<tuple_size>{}};
  319. }
  320. };
  321. template<asio::cancellation_type Ct, typename URBG, typename Range>
  322. struct race_ranged_impl
  323. {
  324. using result_type = co_await_result_t<std::decay_t<decltype(*std::begin(std::declval<Range>()))>>;
  325. template<typename URBG_>
  326. BOOST_COBALT_MSVC_NOINLINE
  327. race_ranged_impl(URBG_ && g, Range && rng)
  328. : range{std::forward<Range>(rng)}, g(std::forward<URBG_>(g))
  329. {
  330. }
  331. Range range;
  332. URBG g;
  333. struct awaitable : fork::shared_state
  334. {
  335. #if !defined(BOOST_ASIO_ENABLE_HANDLER_TRACKING)
  336. boost::source_location loc;
  337. #endif
  338. using type = std::decay_t<decltype(*std::begin(std::declval<Range>()))>;
  339. using traits = race_traits<Range, type>;
  340. std::size_t index{std::numeric_limits<std::size_t>::max()};
  341. std::conditional_t<
  342. std::is_void_v<result_type>,
  343. variant2::monostate,
  344. std::optional<result_type>> result;
  345. std::exception_ptr error;
  346. #if !defined(BOOST_COBALT_NO_PMR)
  347. pmr::polymorphic_allocator<void> alloc{&resource};
  348. Range &aws;
  349. struct dummy
  350. {
  351. template<typename ... Args>
  352. dummy(Args && ...) {}
  353. };
  354. std::conditional_t<traits::interruptible,
  355. pmr::vector<std::decay_t<typename traits::actual_awaitable>*>,
  356. dummy> working{std::size(aws), alloc};
  357. /* all below `reorder` is reordered
  358. *
  359. * cancel[idx] is for aws[reorder[idx]]
  360. */
  361. pmr::vector<std::size_t> reorder{std::size(aws), alloc};
  362. pmr::vector<asio::cancellation_signal> cancel_{std::size(aws), alloc};
  363. pmr::vector<asio::cancellation_signal*> cancel{std::size(aws), alloc};
  364. #else
  365. Range &aws;
  366. struct dummy
  367. {
  368. template<typename ... Args>
  369. dummy(Args && ...) {}
  370. };
  371. std::conditional_t<traits::interruptible,
  372. std::vector<std::decay_t<typename traits::actual_awaitable>*>,
  373. dummy> working{std::size(aws), std::allocator<void>()};
  374. /* all below `reorder` is reordered
  375. *
  376. * cancel[idx] is for aws[reorder[idx]]
  377. */
  378. std::vector<std::size_t> reorder{std::size(aws), std::allocator<void>()};
  379. std::vector<asio::cancellation_signal> cancel_{std::size(aws), std::allocator<void>()};
  380. std::vector<asio::cancellation_signal*> cancel{std::size(aws), std::allocator<void>()};
  381. #endif
  382. bool has_result() const {return index != std::numeric_limits<std::size_t>::max(); }
  383. awaitable(Range & aws, URBG & g)
  384. : fork::shared_state((256 + sizeof(co_awaitable_type<type>) + sizeof(std::size_t)) * std::size(aws))
  385. , aws(aws)
  386. {
  387. std::generate(reorder.begin(), reorder.end(), [i = std::size_t(0u)]() mutable {return i++;});
  388. if constexpr (traits::interruptible)
  389. std::fill(working.begin(), working.end(), nullptr);
  390. if constexpr (!std::is_same_v<URBG, left_race_tag>)
  391. std::shuffle(reorder.begin(), reorder.end(), g);
  392. }
  393. void cancel_all()
  394. {
  395. interrupt_await();
  396. for (auto & r : cancel)
  397. if (r)
  398. std::exchange(r, nullptr)->emit(Ct);
  399. }
  400. void interrupt_await()
  401. {
  402. if constexpr (traits::interruptible)
  403. for (auto aw : working)
  404. if (aw)
  405. traits::do_interrupt(*aw);
  406. }
  407. template<typename T, typename Error>
  408. void assign_error(system::result<T, Error> & res)
  409. BOOST_TRY
  410. {
  411. std::move(res).value(loc);
  412. }
  413. BOOST_CATCH(...)
  414. {
  415. error = std::current_exception();
  416. }
  417. BOOST_CATCH_END
  418. template<typename T>
  419. void assign_error(system::result<T, std::exception_ptr> & res)
  420. {
  421. error = std::move(res).error();
  422. }
  423. static detail::fork await_impl(awaitable & this_, std::size_t idx)
  424. BOOST_TRY
  425. {
  426. typename traits::actual_awaitable aw_{
  427. get_awaitable_type(
  428. static_cast<typename traits::awaitable>(*std::next(std::begin(this_.aws), idx))
  429. )};
  430. as_result_t aw{aw_};
  431. if constexpr (traits::interruptible)
  432. this_.working[idx] = &aw_;
  433. auto transaction = [&this_, idx = idx] {
  434. if (this_.has_result())
  435. boost::throw_exception(std::runtime_error("Another transaction already started"));
  436. this_.cancel[idx] = nullptr;
  437. // reserve the index early bc
  438. this_.index = idx;
  439. this_.cancel_all();
  440. };
  441. co_await fork::set_transaction_function(transaction);
  442. // check manually if we're ready
  443. auto rd = aw.await_ready();
  444. if (!rd)
  445. {
  446. this_.cancel[idx] = &this_.cancel_[idx];
  447. co_await this_.cancel[idx]->slot();
  448. // make sure the executor is set
  449. co_await detail::fork::wired_up;
  450. // do the await - this doesn't call await-ready again
  451. if constexpr (std::is_void_v<result_type>)
  452. {
  453. auto res = co_await aw;
  454. if (!this_.has_result())
  455. {
  456. if (res.has_error())
  457. this_.assign_error(res);
  458. this_.index = idx;
  459. }
  460. }
  461. else
  462. {
  463. auto val = co_await aw;
  464. if (!this_.has_result())
  465. this_.index = idx;
  466. if (this_.index == idx)
  467. {
  468. if (val.has_error())
  469. this_.assign_error(val);
  470. else
  471. this_.result.emplace(*std::move(val));
  472. }
  473. }
  474. this_.cancel[idx] = nullptr;
  475. }
  476. else
  477. {
  478. if (!this_.has_result())
  479. this_.index = idx;
  480. if constexpr (std::is_void_v<decltype(aw_.await_resume())>)
  481. {
  482. auto val = aw.await_resume();
  483. if (val.has_error())
  484. this_.assign_error(val);
  485. }
  486. else
  487. {
  488. if (this_.index == idx)
  489. {
  490. auto val = aw.await_resume();
  491. if (val.has_error())
  492. this_.assign_error(val);
  493. else
  494. this_.result.emplace(*std::move(val));
  495. }
  496. else
  497. aw.await_resume();
  498. }
  499. this_.cancel[idx] = nullptr;
  500. }
  501. this_.cancel_all();
  502. if constexpr (traits::interruptible)
  503. this_.working[idx] = nullptr;
  504. }
  505. BOOST_CATCH(...)
  506. {
  507. if (!this_.has_result())
  508. this_.index = idx;
  509. if (this_.index == idx)
  510. this_.error = std::current_exception();
  511. if constexpr (traits::interruptible)
  512. this_.working[idx] = nullptr;
  513. }
  514. BOOST_CATCH_END
  515. detail::fork last_forked;
  516. bool await_ready()
  517. {
  518. last_forked = await_impl(*this, reorder.front());
  519. return last_forked.done();
  520. }
  521. template<typename H>
  522. auto await_suspend(std::coroutine_handle<H> h,
  523. const boost::source_location & loc = BOOST_CURRENT_LOCATION)
  524. {
  525. this->loc = loc;
  526. this->exec = detail::get_executor(h);
  527. last_forked.release().resume();
  528. if (!this->outstanding_work()) // already done, resume rightaway.
  529. return false;
  530. for (auto itr = std::next(reorder.begin());
  531. itr < reorder.end(); std::advance(itr, 1)) // we'
  532. {
  533. auto l = await_impl(*this, *itr);
  534. auto d = l.done();
  535. l.release();
  536. if (d)
  537. break;
  538. }
  539. if (!this->outstanding_work()) // already done, resume rightaway.
  540. return false;
  541. // arm the cancel
  542. assign_cancellation(
  543. h,
  544. [&](asio::cancellation_type ct)
  545. {
  546. for (auto & cs : cancel)
  547. if (cs)
  548. cs->emit(ct);
  549. });
  550. this->coro.reset(h.address());
  551. return true;
  552. }
  553. BOOST_COBALT_MSVC_NOINLINE
  554. auto await_resume()
  555. {
  556. if (error)
  557. std::rethrow_exception(error);
  558. if constexpr (std::is_void_v<result_type>)
  559. return index;
  560. else
  561. return std::make_pair(index, *result);
  562. }
  563. auto await_resume(const as_tuple_tag &)
  564. {
  565. if constexpr (std::is_void_v<result_type>)
  566. return std::make_tuple(error, index);
  567. else
  568. return std::make_tuple(error, std::make_pair(index, std::move(*result)));
  569. }
  570. auto await_resume(const as_result_tag & )
  571. -> system::result<std::conditional_t<std::is_void_v<result_type>, std::size_t, std::pair<std::size_t, result_type>>, std::exception_ptr>
  572. {
  573. if (error)
  574. return {system::in_place_error, error};
  575. if constexpr (std::is_void_v<result_type>)
  576. return {system::in_place_value, index};
  577. else
  578. return {system::in_place_value, std::make_pair(index, std::move(*result))};
  579. }
  580. };
  581. awaitable operator co_await() &&
  582. {
  583. return awaitable{range, g};
  584. }
  585. };
  586. }
  587. #endif //BOOST_COBALT_DETAIL_RACE_HPP