context.hpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. // Copyright (c) 2024 Klemens D. Morgenstern
  2. //
  3. // Distributed under the Boost Software License, Version 1.0. (See accompanying
  4. // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  5. #ifndef BOOST_COBALT_EXPERIMENTAL_CONTEXT_HPP
  6. #define BOOST_COBALT_EXPERIMENTAL_CONTEXT_HPP
  7. #include <boost/callable_traits/args.hpp>
  8. #include <boost/context/fiber.hpp>
  9. #include <boost/context/fixedsize_stack.hpp>
  10. #include <boost/cobalt/concepts.hpp>
  11. #include <boost/cobalt/experimental/frame.hpp>
  12. #include <boost/cobalt/config.hpp>
  13. #include <coroutine>
  14. #include <new>
  15. namespace boost::cobalt::experimental
  16. {
  17. template<typename, typename ...>
  18. struct context;
  19. namespace detail
  20. {
  21. template<typename Promise>
  22. struct context_frame : frame<context_frame<Promise>, Promise>
  23. {
  24. boost::context::fiber caller, callee;
  25. void (*after_resume)(context_frame *, void *) = nullptr;
  26. void * after_resume_p;
  27. template<typename ... Args>
  28. requires std::constructible_from<Promise, Args...>
  29. context_frame(Args && ... args) : frame<context_frame, Promise>(args...) {}
  30. template<typename ... Args>
  31. requires (!std::constructible_from<Promise, Args...> && std::is_default_constructible_v<Promise>)
  32. context_frame(Args && ...) {}
  33. void resume()
  34. {
  35. callee = std::move(callee).resume();
  36. if (auto af = std::exchange(after_resume, nullptr))
  37. af(this, after_resume_p);
  38. }
  39. void destroy()
  40. {
  41. auto c = std::exchange(callee, {});
  42. this->~context_frame();
  43. }
  44. template<typename Awaitable>
  45. auto do_resume(void * )
  46. {
  47. return +[](context_frame * this_, void * p)
  48. {
  49. auto aw_ = static_cast<Awaitable*>(p);
  50. auto h = std::coroutine_handle<Promise>::from_address(this_) ;
  51. aw_->await_suspend(h);
  52. };
  53. }
  54. template<typename Awaitable>
  55. auto do_resume(bool * )
  56. {
  57. return +[](context_frame * this_, void * p)
  58. {
  59. auto aw_ = static_cast<Awaitable*>(p);
  60. auto h = std::coroutine_handle<Promise>::from_address(this_) ;
  61. if (!aw_->await_suspend(h))
  62. h.resume();
  63. };
  64. }
  65. template<typename Awaitable, typename Promise_>
  66. auto do_resume(std::coroutine_handle<Promise_> * )
  67. {
  68. return +[](context_frame * this_, void * p)
  69. {
  70. auto aw_ = static_cast<Awaitable*>(p);
  71. auto h = std::coroutine_handle<Promise>::from_address(this_) ;
  72. aw_->await_suspend(h).resume();
  73. };
  74. }
  75. template<typename Awaitable>
  76. auto do_await(Awaitable aw)
  77. {
  78. if (!aw.await_ready())
  79. {
  80. after_resume_p = & aw;
  81. after_resume = do_resume<Awaitable>(
  82. static_cast<decltype(aw.await_suspend(std::declval<std::coroutine_handle<Promise>>()))*>(nullptr)
  83. );
  84. caller = std::move(caller).resume();
  85. }
  86. return aw.await_resume();
  87. }
  88. template<typename Handle, typename ... Args>
  89. context<Handle, Args...> get_context()
  90. {
  91. return context<Handle, Args...>{this};
  92. }
  93. };
  94. template<typename Traits, typename Promise, typename ... Args>
  95. struct stack_allocator : boost::context::fixedsize_stack {};
  96. template<typename Traits, typename Promise, typename ... Args>
  97. requires requires {Promise::operator new(std::size_t{});}
  98. struct stack_allocator<Traits, Promise, Args...>
  99. {
  100. boost::context::stack_context allocate()
  101. {
  102. const auto size = Traits::default_size();
  103. const auto p = Promise::operator new(size);
  104. boost::context::stack_context sctx;
  105. sctx.size = size;
  106. sctx.sp = static_cast< char * >( p) + sctx.size;
  107. return sctx;
  108. }
  109. void deallocate( boost::context::stack_context & sctx) noexcept
  110. {
  111. void * vp = static_cast< char * >( sctx.sp) - sctx.size;
  112. Promise::operator delete(vp, sctx.size);
  113. }
  114. };
  115. template<typename Traits, typename Promise, typename ... Args>
  116. requires requires {Promise::operator new(std::size_t{}, std::decay_t<Args&>()...);}
  117. struct stack_allocator<Traits, Promise, Args...>
  118. {
  119. std::tuple<Args&...> args;
  120. boost::context::stack_context allocate()
  121. {
  122. const auto size = Traits::default_size();
  123. const auto p = std::apply(
  124. [size](auto & ... args_)
  125. {
  126. return Promise::operator new(size, args_...);
  127. }, args);
  128. boost::context::stack_context sctx;
  129. sctx.size = size;
  130. sctx.sp = static_cast< char * >( p) + sctx.size;
  131. return sctx;
  132. }
  133. void deallocate( boost::context::stack_context & sctx) noexcept
  134. {
  135. void * vp = static_cast< char * >( sctx.sp) - sctx.size;
  136. Promise::operator delete(vp, sctx.size);
  137. }
  138. };
  139. struct await_transform_base
  140. {
  141. struct dummy {};
  142. void await_transform(dummy);
  143. };
  144. template<typename T>
  145. struct await_transform_impl : await_transform_base, T
  146. {
  147. };
  148. template<typename T>
  149. concept has_await_transform = ! requires (await_transform_impl<T> & p) {p.await_transform(await_transform_base::dummy{});};
  /// Run a body whose result type is void, then report completion via return_void().
  template<typename Promise, typename Context, typename Func, typename ... Args>
  void do_return(std::true_type /* is_void */, Promise& target, Context ctx, Func && body, Args && ... body_args)
  {
    std::forward<Func>(body)(ctx, std::forward<Args>(body_args)...);
    target.return_void();
  }

  /// Run a value-returning body and hand its result directly to return_value().
  template<typename Promise, typename Context, typename Func, typename ... Args>
  void do_return(std::false_type /* is_void */, Promise& target, Context ctx, Func && body, Args && ... body_args)
  {
    target.return_value(
        std::forward<Func>(body)(ctx, std::forward<Args>(body_args)...));
  }
  161. }
  162. template<typename Return, typename ... Args>
  163. struct context
  164. {
  165. using return_type = Return;
  166. using promise_type = typename std::coroutine_traits<Return, Args...>::promise_type;
  167. promise_type & promise() {return frame_->promise;}
  168. const promise_type & promise() const {return frame_->promise;}
  169. template<typename Return_, typename ... Args_>
  170. requires std::same_as<promise_type, typename context<Return_, Args_...>::promise_type>
  171. constexpr operator context<Return_, Args_...>() const
  172. {
  173. return {frame_};
  174. }
  175. template<typename Awaitable>
  176. requires (detail::has_await_transform<promise_type> &&
  177. requires (promise_type & pro, Awaitable && aw)
  178. {
  179. {pro.await_transform(std::forward<Awaitable>(aw))} -> awaitable_type<promise_type>;
  180. })
  181. auto await(Awaitable && aw)
  182. {
  183. return frame_->do_await(frame_->promise.await_transform(std::forward<Awaitable>(aw)));
  184. }
  185. template<typename Awaitable>
  186. requires (detail::has_await_transform<promise_type> &&
  187. requires (promise_type & pro, Awaitable && aw)
  188. {
  189. {pro.await_transform(std::forward<Awaitable>(aw).operator co_await())} -> awaitable_type<promise_type>;
  190. })
  191. auto await(Awaitable && aw)
  192. {
  193. return frame_->do_await(frame_->promise.await_transform(std::forward<Awaitable>(aw)).operator co_await());
  194. }
  195. template<typename Awaitable>
  196. requires (detail::has_await_transform<promise_type> &&
  197. requires (promise_type & pro, Awaitable && aw)
  198. {
  199. {operator co_await(pro.await_transform(std::forward<Awaitable>(aw)))} -> awaitable_type<promise_type>;
  200. })
  201. auto await(Awaitable && aw)
  202. {
  203. return frame_->do_await(operator co_await(frame_->promise.await_transform(std::forward<Awaitable>(aw))));
  204. }
  205. template<awaitable_type<promise_type> Awaitable>
  206. requires (!detail::has_await_transform<promise_type> )
  207. auto await(Awaitable && aw)
  208. {
  209. return frame_->do_await(std::forward<Awaitable>(aw));
  210. }
  211. template<typename Awaitable>
  212. requires (!detail::has_await_transform<promise_type>
  213. && requires (Awaitable && aw) {{operator co_await(std::forward<Awaitable>(aw))} -> awaitable_type<promise_type>;})
  214. auto await(Awaitable && aw)
  215. {
  216. return frame_->do_await(operator co_await(std::forward<Awaitable>(aw)));
  217. }
  218. template<typename Awaitable>
  219. requires (!detail::has_await_transform<promise_type>
  220. && requires (Awaitable && aw) {{std::forward<Awaitable>(aw).operator co_await()} -> awaitable_type<promise_type>;})
  221. auto await(Awaitable && aw)
  222. {
  223. return frame_->do_await(std::forward<Awaitable>(aw).operator co_await());
  224. }
  225. template<typename Yield>
  226. requires requires (promise_type & pro, Yield && value) {{pro.yield_value(std::forward<Yield>(value))} -> awaitable_type<promise_type>;}
  227. auto yield(Yield && value)
  228. {
  229. frame_->do_await(frame_->promise.yield_value(std::forward<Yield>(value)));
  230. }
  231. private:
  232. context(detail::context_frame<promise_type> * frame) : frame_(frame) {}
  233. template<typename, typename ...>
  234. friend struct context;
  235. //template<typename >
  236. friend struct detail::context_frame<promise_type>;
  237. detail::context_frame<promise_type> * frame_;
  238. };
  239. template<typename Return, typename ... Args, std::invocable<context<Return, Args...>, Args...> Func, typename StackAlloc>
  240. auto make_context(Func && func, std::allocator_arg_t, StackAlloc && salloc, Args && ... args)
  241. {
  242. auto sctx_ = salloc.allocate();
  243. using promise_type = typename std::coroutine_traits<Return, Args...>::promise_type;
  244. void * p = static_cast<char*>(sctx_.sp) - sizeof(detail::context_frame<promise_type>);
  245. auto sz = sctx_.size - sizeof(detail::context_frame<promise_type>);
  246. if (auto diff = reinterpret_cast<std::uintptr_t>(p) % alignof(detail::context_frame<promise_type>); diff != 0u)
  247. {
  248. p = static_cast<char*>(p) - diff;
  249. sz -= diff;
  250. }
  251. boost::context::preallocated psc{p, sz, sctx_};
  252. auto f = new (p) detail::context_frame<promise_type>(args...);
  253. auto res = f->promise.get_return_object();
  254. constexpr auto is_always_lazy =
  255. requires (promise_type & pro) {{pro.initial_suspend()} -> std::same_as<std::suspend_always>;}
  256. && noexcept(f->promise.initial_suspend());
  257. struct invoker
  258. {
  259. detail::context_frame<promise_type> * frame;
  260. mutable Func func;
  261. mutable std::tuple<Args...> args;
  262. invoker(detail::context_frame<promise_type> * frame, Func && func, Args && ... args)
  263. : frame(frame), func(std::forward<Func>(func)), args(std::forward<Args>(args)...)
  264. {
  265. }
  266. boost::context::fiber operator()(boost::context::fiber && f) const
  267. {
  268. auto & promise = frame->promise;
  269. frame->caller = std::move(f);
  270. try
  271. {
  272. if (!is_always_lazy)
  273. frame->do_await(promise.initial_suspend());
  274. std::apply(
  275. [&](auto && ... args_)
  276. {
  277. auto ctx = frame->template get_context<Return, Args...>();
  278. using return_type = decltype(std::forward<Func>(func)(ctx, std::forward<Args>(args_)...));
  279. detail::do_return(std::is_void<return_type>{}, frame->promise, ctx,
  280. std::forward<Func>(func), std::forward<Args>(args_)...);
  281. },
  282. std::move(args));
  283. }
  284. catch (boost::context::detail::forced_unwind &) { throw; }
  285. catch (...) {promise.unhandled_exception();}
  286. static_assert(noexcept(promise.final_suspend()));
  287. frame->do_await(promise.final_suspend());
  288. return std::move(frame->caller);
  289. }
  290. };
  291. f->callee = boost::context::fiber{
  292. std::allocator_arg, psc, std::forward<StackAlloc>(salloc),
  293. invoker(f, std::forward<Func>(func), std::forward<Args>(args)...)};
  294. if constexpr (is_always_lazy)
  295. f->promise.initial_suspend();
  296. else
  297. f->resume();
  298. return res;
  299. }
  300. template<typename Return, typename ... Args, std::invocable<context<Return, Args...>, Args...> Func>
  301. auto make_context(Func && func, Args && ... args)
  302. {
  303. return make_context<Return>(std::forward<Func>(func), std::allocator_arg,
  304. boost::context::fixedsize_stack(), std::forward<Args>(args)...);
  305. }
  306. template<typename ... Args, typename Func, typename StackAlloc>
  307. auto make_context(Func && func, std::allocator_arg_t, StackAlloc && salloc, Args && ... args)
  308. {
  309. return make_context<typename std::tuple_element_t<0u, callable_traits::args_t<Func>>::return_type>(
  310. std::forward<Func>(func), std::allocator_arg, std::forward<StackAlloc>(salloc), std::forward<Args>(args)...
  311. );
  312. }
  313. template<typename ... Args, typename Func>
  314. auto make_context(Func && func, Args && ... args)
  315. {
  316. return make_context<typename std::tuple_element_t<0u, callable_traits::args_t<Func>>::return_type>(
  317. std::forward<Func>(func), std::forward<Args>(args)...
  318. );
  319. }
  320. }
  321. #endif //BOOST_COBALT_EXPERIMENTAL_CONTEXT_HPP