promise.hpp 9.8 KB


  1. //
  2. // Copyright (c) 2022 Klemens Morgenstern (klemens.morgenstern@gmx.net)
  3. //
  4. // Distributed under the Boost Software License, Version 1.0. (See accompanying
  5. // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. //
  7. #ifndef BOOST_COBALT_DETAIL_PROMISE_HPP
  8. #define BOOST_COBALT_DETAIL_PROMISE_HPP
  9. #include <boost/cobalt/detail/exception.hpp>
  10. #include <boost/cobalt/detail/forward_cancellation.hpp>
  11. #include <boost/cobalt/detail/wrapper.hpp>
  12. #include <boost/cobalt/detail/this_thread.hpp>
  13. #include <boost/cobalt/noop.hpp>
  14. #include <boost/cobalt/op.hpp>
  15. #include <boost/cobalt/unique_handle.hpp>
  16. #include <boost/asio/cancellation_signal.hpp>
  17. #include <boost/asio/bind_allocator.hpp>
  18. #include <boost/core/exchange.hpp>
  19. #include <coroutine>
  20. #include <optional>
  21. #include <utility>
  22. namespace boost::cobalt
  23. {
  // Tag types that select the tuple / result flavours of await_resume()
  // (see promise_receiver::awaitable); only declared here.
  struct as_tuple_tag;
  struct as_result_tag;
  // User-facing promise handle; cobalt_promise::get_return_object() builds one
  // from the coroutine's promise object.
  template<typename Return> struct promise;
  28. namespace detail
  29. {
  // Declared up front because promise_value_holder<T> downcasts to it.
  template<typename T> struct promise_receiver;
  32. template<typename T>
  33. struct promise_value_holder
  34. {
  35. std::optional<T> result;
  36. bool result_taken = false;
  37. system::result<T, std::exception_ptr> get_result_value()
  38. {
  39. result_taken = true;
  40. BOOST_ASSERT(result);
  41. return {system::in_place_value, std::move(*result)};
  42. }
  43. void return_value(T && ret)
  44. {
  45. result.emplace(std::move(ret));
  46. static_cast<promise_receiver<T>*>(this)->set_done();
  47. }
  48. void return_value(const T & ret)
  49. {
  50. result.emplace(ret);
  51. static_cast<promise_receiver<T>*>(this)->set_done();
  52. }
  53. constexpr promise_value_holder() = default;
  54. constexpr promise_value_holder(noop<T> value) noexcept(std::is_nothrow_move_constructible_v<T>) : result(std::move(value.value)) {}
  55. };
  56. template<>
  57. struct promise_value_holder<void>
  58. {
  59. bool result_taken = false;
  60. system::result<void, std::exception_ptr> get_result_value()
  61. {
  62. result_taken = true;
  63. return {system::in_place_value};
  64. }
  65. inline void return_void();
  66. constexpr promise_value_holder() = default;
  67. constexpr promise_value_holder(noop<void>) {}
  68. };
  69. template<typename T>
  70. struct promise_receiver : promise_value_holder<T>
  71. {
  72. std::exception_ptr exception;
  73. system::result<T, std::exception_ptr> get_result()
  74. {
  75. if (exception && !done) // detached error
  76. return {system::in_place_error, std::exchange(exception, nullptr)};
  77. else if (exception)
  78. {
  79. this->result_taken = true;
  80. return {system::in_place_error, exception};
  81. }
  82. return this->get_result_value();
  83. }
  84. void unhandled_exception()
  85. {
  86. exception = std::current_exception();
  87. set_done();
  88. }
  89. bool done = false;
  90. unique_handle<void> awaited_from{nullptr};
  91. void set_done()
  92. {
  93. done = true;
  94. }
  95. promise_receiver() = default;
  96. promise_receiver(noop<T> value) : promise_value_holder<T>(std::move(value)), done(true) {}
  97. promise_receiver(promise_receiver && lhs) noexcept
  98. : promise_value_holder<T>(std::move(lhs)),
  99. exception(std::move(lhs.exception)), done(lhs.done), awaited_from(std::move(lhs.awaited_from)),
  100. reference(lhs.reference), cancel_signal(lhs.cancel_signal)
  101. {
  102. if (!done && !exception)
  103. {
  104. *reference = this;
  105. lhs.exception = moved_from_exception();
  106. }
  107. lhs.done = true;
  108. }
  109. promise_receiver& operator=(promise_receiver && lhs) noexcept
  110. {
  111. if (*reference == this)
  112. {
  113. *reference = nullptr;
  114. }
  115. promise_value_holder<T>::operator=(std::move(lhs));
  116. exception = std::move(lhs.exception);
  117. done = std::move(lhs.done);
  118. awaited_from = std::move(lhs.awaited_from);
  119. reference = std::move(lhs.reference);
  120. cancel_signal = std::move(lhs.cancel_signal);
  121. if (!done && !exception)
  122. {
  123. *reference = this;
  124. lhs.exception = moved_from_exception();
  125. }
  126. return *this;
  127. }
  128. ~promise_receiver()
  129. {
  130. if (!done && *reference == this)
  131. *reference = nullptr;
  132. }
  133. promise_receiver(promise_receiver * &reference, asio::cancellation_signal & cancel_signal)
  134. : reference(&reference), cancel_signal(&cancel_signal)
  135. {
  136. reference = this;
  137. }
  138. struct awaitable
  139. {
  140. promise_receiver * self;
  141. std::exception_ptr ex;
  142. asio::cancellation_slot cl;
  143. awaitable(promise_receiver * self) : self(self)
  144. {
  145. }
  146. awaitable(awaitable && aw) : self(aw.self)
  147. {
  148. }
  149. ~awaitable ()
  150. {
  151. }
  152. bool await_ready() const { return self->done; }
  153. template<typename Promise>
  154. bool await_suspend(std::coroutine_handle<Promise> h)
  155. {
  156. if (self->done) // ok, so we're actually done already, so noop
  157. return false;
  158. if (ex)
  159. return false;
  160. if (self->awaited_from != nullptr) // we're already being awaited, that's an error!
  161. {
  162. ex = already_awaited();
  163. return false;
  164. }
  165. if constexpr (requires (Promise p) {p.get_cancellation_slot();})
  166. if ((cl = h.promise().get_cancellation_slot()).is_connected())
  167. cl.emplace<forward_cancellation>(*self->cancel_signal);
  168. self->awaited_from.reset(h.address());
  169. return true;
  170. }
  171. T await_resume(const boost::source_location & loc = BOOST_CURRENT_LOCATION)
  172. {
  173. if (cl.is_connected())
  174. cl.clear();
  175. if (ex)
  176. std::rethrow_exception(ex);
  177. return self->get_result().value(loc);
  178. }
  179. system::result<T, std::exception_ptr> await_resume(const as_result_tag &)
  180. {
  181. if (cl.is_connected())
  182. cl.clear();
  183. if (ex)
  184. return {system::in_place_error, std::move(ex)};
  185. return self->get_result();
  186. }
  187. auto await_resume(const as_tuple_tag &)
  188. {
  189. if (cl.is_connected())
  190. cl.clear();
  191. if constexpr (std::is_void_v<T>)
  192. {
  193. if (ex)
  194. return std::move(ex);
  195. return self->get_result().error();
  196. }
  197. else
  198. {
  199. if (ex)
  200. return std::make_tuple(std::move(ex), T{});
  201. auto res = self->get_result();
  202. if (res.has_error())
  203. return std::make_tuple(res.error(), T{});
  204. else
  205. return std::make_tuple(std::exception_ptr(), std::move(*res));
  206. }
  207. }
  208. void interrupt_await() &
  209. {
  210. if (!self)
  211. return ;
  212. ex = detached_exception();
  213. if (self->awaited_from)
  214. self->awaited_from.release().resume();
  215. }
  216. };
  217. promise_receiver **reference;
  218. asio::cancellation_signal * cancel_signal;
  219. awaitable get_awaitable() {return awaitable{this};}
  220. void interrupt_await() &
  221. {
  222. exception = detached_exception();
  223. awaited_from.release().resume();
  224. }
  225. };
  226. inline void promise_value_holder<void>::return_void()
  227. {
  228. static_cast<promise_receiver<void>*>(this)->set_done();
  229. }
  230. template<typename Return>
  231. struct cobalt_promise_result
  232. {
  233. promise_receiver<Return>* receiver{nullptr};
  234. void return_value(Return && ret)
  235. {
  236. if(receiver)
  237. receiver->return_value(std::move(ret));
  238. }
  239. void return_value(const Return & ret)
  240. {
  241. if(receiver)
  242. receiver->return_value(ret);
  243. }
  244. };
  245. template<>
  246. struct cobalt_promise_result<void>
  247. {
  248. promise_receiver<void>* receiver{nullptr};
  249. void return_void()
  250. {
  251. if(receiver)
  252. receiver->return_void();
  253. }
  254. };
  255. template<typename Return>
  256. struct cobalt_promise
  257. : promise_memory_resource_base,
  258. promise_cancellation_base<asio::cancellation_slot, asio::enable_total_cancellation>,
  259. promise_throw_if_cancelled_base,
  260. enable_awaitables<cobalt_promise<Return>>,
  261. enable_await_allocator<cobalt_promise<Return>>,
  262. enable_await_executor<cobalt_promise<Return>>,
  263. enable_await_deferred,
  264. cobalt_promise_result<Return>
  265. {
  266. using promise_cancellation_base<asio::cancellation_slot, asio::enable_total_cancellation>::await_transform;
  267. using promise_throw_if_cancelled_base::await_transform;
  268. using enable_awaitables<cobalt_promise<Return>>::await_transform;
  269. using enable_await_allocator<cobalt_promise<Return>>::await_transform;
  270. using enable_await_executor<cobalt_promise<Return>>::await_transform;
  271. using enable_await_deferred::await_transform;
  272. [[nodiscard]] promise<Return> get_return_object()
  273. {
  274. return promise<Return>{this};
  275. }
  276. mutable asio::cancellation_signal signal;
  277. using executor_type = executor;
  278. executor_type exec;
  279. const executor_type & get_executor() const {return exec;}
  280. template<typename ... Args>
  281. cobalt_promise(Args & ...args)
  282. :
  283. #if !defined(BOOST_COBALT_NO_PMR)
  284. promise_memory_resource_base(detail::get_memory_resource_from_args(args...)),
  285. #endif
  286. exec{detail::get_executor_from_args(args...)}
  287. {
  288. this->reset_cancellation_source(signal.slot());
  289. }
  290. std::suspend_never initial_suspend() noexcept {return {};}
  291. auto final_suspend() noexcept
  292. {
  293. return final_awaitable{this};
  294. }
  295. void unhandled_exception()
  296. {
  297. if (this->receiver)
  298. this->receiver->unhandled_exception();
  299. else
  300. throw ;
  301. }
  302. ~cobalt_promise()
  303. {
  304. if (this->receiver)
  305. {
  306. if (!this->receiver->done && !this->receiver->exception)
  307. this->receiver->exception = completed_unexpected();
  308. this->receiver->set_done();
  309. this->receiver->awaited_from.reset(nullptr);
  310. }
  311. }
  312. private:
  313. struct final_awaitable
  314. {
  315. cobalt_promise * promise;
  316. bool await_ready() const noexcept
  317. {
  318. return promise->receiver && promise->receiver->awaited_from.get() == nullptr;
  319. }
  320. std::coroutine_handle<void> await_suspend(std::coroutine_handle<cobalt_promise> h) noexcept
  321. {
  322. std::coroutine_handle<void> res = std::noop_coroutine();
  323. if (promise->receiver && promise->receiver->awaited_from.get() != nullptr)
  324. res = promise->receiver->awaited_from.release();
  325. if (auto &rec = h.promise().receiver; rec != nullptr)
  326. {
  327. if (!rec->done && !rec->exception)
  328. rec->exception = completed_unexpected();
  329. rec->set_done();
  330. rec->awaited_from.reset(nullptr);
  331. rec = nullptr;
  332. }
  333. detail::self_destroy(h);
  334. return res;
  335. }
  336. void await_resume() noexcept
  337. {
  338. }
  339. };
  340. };
  341. }
  342. }
  343. #endif //BOOST_COBALT_DETAIL_PROMISE_HPP