// // Copyright (c) 2022 Klemens Morgenstern (klemens.morgenstern@gmx.net) // // Distributed under the Boost Software License, Version 1.0. (See accompanying // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) // #ifndef BOOST_COBALT_DETAIL_RACE_HPP #define BOOST_COBALT_DETAIL_RACE_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace boost::cobalt::detail { struct left_race_tag {}; // helpers it determining the type of things; template struct race_traits { // for a ranges race this is based on the range, not the AW in it. constexpr static bool is_lvalue = std::is_lvalue_reference_v; // what the value is supposed to be cast to before the co_await_operator using awaitable = std::conditional_t &, Awaitable &&>; // do we need operator co_await constexpr static bool is_actual = awaitable_type; // the type with .await_ functions & interrupt_await using actual_awaitable = std::conditional_t< is_actual, awaitable, decltype(get_awaitable_type(std::declval()))>; // the type to be used with interruptible using interruptible_type = std::conditional_t< std::is_lvalue_reference_v, std::decay_t &, std::decay_t &&>; constexpr static bool interruptible = cobalt::interruptible; static void do_interrupt(std::decay_t & aw) { if constexpr (interruptible) static_cast(aw).interrupt_await(); } }; struct interruptible_base { virtual void interrupt_await() = 0; }; template struct race_variadic_impl { template race_variadic_impl(URBG_ && g, Args && ... args) : args{std::forward(args)...}, g(std::forward(g)) { } std::tuple args; URBG g; constexpr static std::size_t tuple_size = sizeof...(Args); struct awaitable : fork::static_shared_state<256 * tuple_size> { #if !defined(BOOST_ASIO_ENABLE_HANDLER_TRACKING) boost::source_location loc; #endif template awaitable(std::tuple & args, URBG & g, std::index_sequence) : aws{args} { if constexpr (!std::is_same_v) std::shuffle(impls.begin(), impls.end(), g); std::fill(working.begin(), working.end(), nullptr); } std::tuple & aws; std::array cancel_; template constexpr static auto make_null() {return nullptr;}; std::array cancel = {make_null()...}; std::array working; std::size_t index{std::numeric_limits::max()}; constexpr static bool all_void = (std::is_void_v> && ... ); std::optional>...>> result; std::exception_ptr error; bool has_result() const { return index != std::numeric_limits::max(); } void cancel_all() { interrupt_await(); for (auto i = 0u; i < tuple_size; i++) if (auto &r = cancel[i]; r) std::exchange(r, nullptr)->emit(Ct); } void interrupt_await() { for (auto i : working) if (i) i->interrupt_await(); } template void assign_error(system::result & res) BOOST_TRY { std::move(res).value(loc); } BOOST_CATCH(...) { error = std::current_exception(); } BOOST_CATCH_END template void assign_error(system::result & res) { error = std::move(res).error(); } template static detail::fork await_impl(awaitable & this_) BOOST_TRY { using traits = race_traits, Idx>>; typename traits::actual_awaitable aw_{ get_awaitable_type( static_cast(std::get(this_.aws)) ) }; as_result_t aw{aw_}; struct interruptor final : interruptible_base { std::decay_t & aw; interruptor(std::decay_t & aw) : aw(aw) {} void interrupt_await() override { traits::do_interrupt(aw); } }; interruptor in{aw_}; //if constexpr (traits::interruptible) this_.working[Idx] = ∈ auto transaction = [&this_, idx = Idx] { if (this_.has_result()) boost::throw_exception(std::runtime_error("Another transaction already started")); this_.cancel[idx] = nullptr; // reserve the index early bc this_.index = idx; this_.cancel_all(); }; co_await fork::set_transaction_function(transaction); // check manually if we're ready auto rd = aw.await_ready(); if (!rd) { this_.cancel[Idx] = &this_.cancel_[Idx]; co_await this_.cancel[Idx]->slot(); // make sure the executor is set co_await detail::fork::wired_up; // do the await - this doesn't call await-ready again if constexpr (std::is_void_v) { auto res = co_await aw; if (!this_.has_result()) { this_.index = Idx; if (res.has_error()) this_.assign_error(res); } if constexpr(!all_void) if (this_.index == Idx && !res.has_error()) this_.result.emplace(variant2::in_place_index); } else { auto val = co_await aw; if (!this_.has_result()) this_.index = Idx; if (this_.index == Idx) { if (val.has_error()) this_.assign_error(val); else this_.result.emplace(variant2::in_place_index, *std::move(val)); } } this_.cancel[Idx] = nullptr; } else { if (!this_.has_result()) this_.index = Idx; if constexpr (std::is_void_v) { auto res = aw.await_resume(); if (this_.index == Idx) { if (res.has_error()) this_.assign_error(res); else this_.result.emplace(variant2::in_place_index); } } else { if (this_.index == Idx) { auto res = aw.await_resume(); if (res.has_error()) this_.assign_error(res); else this_.result.emplace(variant2::in_place_index, *std::move(res)); } else aw.await_resume(); } this_.cancel[Idx] = nullptr; } this_.cancel_all(); this_.working[Idx] = nullptr; } BOOST_CATCH(...) { if (!this_.has_result()) this_.index = Idx; if (this_.index == Idx) this_.error = std::current_exception(); this_.working[Idx] = nullptr; } BOOST_CATCH_END std::array impls { [](std::index_sequence) { return std::array{&await_impl...}; }(std::make_index_sequence{}) }; detail::fork last_forked; bool await_ready() { last_forked = impls[0](*this); return last_forked.done(); } template auto await_suspend( std::coroutine_handle h, const boost::source_location & loc = BOOST_CURRENT_LOCATION) { this->loc = loc; this->exec = &cobalt::detail::get_executor(h); last_forked.release().resume(); if (!this->outstanding_work()) // already done, resume rightaway. return false; for (std::size_t idx = 1u; idx < tuple_size; idx++) // we' { auto l = impls[idx](*this); const auto d = l.done(); l.release(); if (d) break; } if (!this->outstanding_work()) // already done, resume rightaway. return false; // arm the cancel assign_cancellation( h, [&](asio::cancellation_type ct) { for (auto & cs : cancel) if (cs) cs->emit(ct); }); this->coro.reset(h.address()); return true; } #if _MSC_VER BOOST_NOINLINE #endif auto await_resume() { if (error) std::rethrow_exception(error); if constexpr (all_void) return index; else return std::move(*result); } auto await_resume(const as_tuple_tag &) { if constexpr (all_void) return std::make_tuple(error, index); else return std::make_tuple(error, std::move(*result)); } auto await_resume(const as_result_tag & ) -> system::result>...>>, std::exception_ptr> { if (error) return {system::in_place_error, error}; if constexpr (all_void) return {system::in_place_value, index}; else return {system::in_place_value, std::move(*result)}; } }; awaitable operator co_await() && { return awaitable{args, g, std::make_index_sequence{}}; } }; template struct race_ranged_impl { using result_type = co_await_result_t()))>>; template race_ranged_impl(URBG_ && g, Range && rng) : range{std::forward(rng)}, g(std::forward(g)) { } Range range; URBG g; struct awaitable : fork::shared_state { #if !defined(BOOST_ASIO_ENABLE_HANDLER_TRACKING) boost::source_location loc; #endif using type = std::decay_t()))>; using traits = race_traits; std::size_t index{std::numeric_limits::max()}; std::conditional_t< std::is_void_v, variant2::monostate, std::optional> result; std::exception_ptr error; #if !defined(BOOST_COBALT_NO_PMR) pmr::monotonic_buffer_resource res; pmr::polymorphic_allocator alloc{&resource}; Range &aws; struct dummy { template dummy(Args && ...) {} }; std::conditional_t*>, dummy> working{std::size(aws), alloc}; /* all below `reorder` is reordered * * cancel[idx] is for aws[reorder[idx]] */ pmr::vector reorder{std::size(aws), alloc}; pmr::vector cancel_{std::size(aws), alloc}; pmr::vector cancel{std::size(aws), alloc}; #else Range &aws; struct dummy { template dummy(Args && ...) {} }; std::conditional_t*>, dummy> working{std::size(aws), std::allocator()}; /* all below `reorder` is reordered * * cancel[idx] is for aws[reorder[idx]] */ std::vector reorder{std::size(aws), std::allocator()}; std::vector cancel_{std::size(aws), std::allocator()}; std::vector cancel{std::size(aws), std::allocator()}; #endif bool has_result() const {return index != std::numeric_limits::max(); } awaitable(Range & aws, URBG & g) : fork::shared_state((256 + sizeof(co_awaitable_type) + sizeof(std::size_t)) * std::size(aws)) , aws(aws) { std::generate(reorder.begin(), reorder.end(), [i = std::size_t(0u)]() mutable {return i++;}); if constexpr (traits::interruptible) std::fill(working.begin(), working.end(), nullptr); if constexpr (!std::is_same_v) std::shuffle(reorder.begin(), reorder.end(), g); } void cancel_all() { interrupt_await(); for (auto & r : cancel) if (r) std::exchange(r, nullptr)->emit(Ct); } void interrupt_await() { if constexpr (traits::interruptible) for (auto aw : working) if (aw) traits::do_interrupt(*aw); } template void assign_error(system::result & res) BOOST_TRY { std::move(res).value(loc); } BOOST_CATCH(...) { error = std::current_exception(); } BOOST_CATCH_END template void assign_error(system::result & res) { error = std::move(res).error(); } static detail::fork await_impl(awaitable & this_, std::size_t idx) BOOST_TRY { typename traits::actual_awaitable aw_{ get_awaitable_type( static_cast(*std::next(std::begin(this_.aws), idx)) )}; as_result_t aw{aw_}; if constexpr (traits::interruptible) this_.working[idx] = &aw_; auto transaction = [&this_, idx = idx] { if (this_.has_result()) boost::throw_exception(std::runtime_error("Another transaction already started")); this_.cancel[idx] = nullptr; // reserve the index early bc this_.index = idx; this_.cancel_all(); }; co_await fork::set_transaction_function(transaction); // check manually if we're ready auto rd = aw.await_ready(); if (!rd) { this_.cancel[idx] = &this_.cancel_[idx]; co_await this_.cancel[idx]->slot(); // make sure the executor is set co_await detail::fork::wired_up; // do the await - this doesn't call await-ready again if constexpr (std::is_void_v) { auto res = co_await aw; if (!this_.has_result()) { if (res.has_error()) this_.assign_error(res); this_.index = idx; } } else { auto val = co_await aw; if (!this_.has_result()) this_.index = idx; if (this_.index == idx) { if (val.has_error()) this_.assign_error(val); else this_.result.emplace(*std::move(val)); } } this_.cancel[idx] = nullptr; } else { if (!this_.has_result()) this_.index = idx; if constexpr (std::is_void_v) { auto val = aw.await_resume(); if (val.has_error()) this_.assign_error(val); } else { if (this_.index == idx) { auto val = aw.await_resume(); if (val.has_error()) this_.assign_error(val); else this_.result.emplace(*std::move(val)); } else aw.await_resume(); } this_.cancel[idx] = nullptr; } this_.cancel_all(); if constexpr (traits::interruptible) this_.working[idx] = nullptr; } BOOST_CATCH(...) { if (!this_.has_result()) this_.index = idx; if (this_.index == idx) this_.error = std::current_exception(); if constexpr (traits::interruptible) this_.working[idx] = nullptr; } BOOST_CATCH_END detail::fork last_forked; bool await_ready() { last_forked = await_impl(*this, reorder.front()); return last_forked.done(); } template auto await_suspend(std::coroutine_handle h, const boost::source_location & loc = BOOST_CURRENT_LOCATION) { this->loc = loc; this->exec = &detail::get_executor(h); last_forked.release().resume(); if (!this->outstanding_work()) // already done, resume rightaway. return false; for (auto itr = std::next(reorder.begin()); itr < reorder.end(); std::advance(itr, 1)) // we' { auto l = await_impl(*this, *itr); auto d = l.done(); l.release(); if (d) break; } if (!this->outstanding_work()) // already done, resume rightaway. return false; // arm the cancel assign_cancellation( h, [&](asio::cancellation_type ct) { for (auto & cs : cancel) if (cs) cs->emit(ct); }); this->coro.reset(h.address()); return true; } #if _MSC_VER BOOST_NOINLINE #endif auto await_resume() { if (error) std::rethrow_exception(error); if constexpr (std::is_void_v) return index; else return std::make_pair(index, *result); } auto await_resume(const as_tuple_tag &) { if constexpr (std::is_void_v) return std::make_tuple(error, index); else return std::make_tuple(error, std::make_pair(index, std::move(*result))); } auto await_resume(const as_result_tag & ) -> system::result { if (error) return {system::in_place_error, error}; if constexpr (std::is_void_v) return {system::in_place_value, index}; else return {system::in_place_value, std::make_pair(index, std::move(*result))}; } }; awaitable operator co_await() && { return awaitable{range, g}; } }; } #endif //BOOST_COBALT_DETAIL_RACE_HPP