// Copyright (c) 2024 Klemens D. Morgenstern
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
#ifndef BOOST_COBALT_EXPERIMENTAL_CONTEXT_HPP
#define BOOST_COBALT_EXPERIMENTAL_CONTEXT_HPP

#include <boost/callable_traits/args.hpp>
#include <boost/context/fiber.hpp>
#include <boost/context/fixedsize_stack.hpp>

#include <boost/cobalt/concepts.hpp>
#include <boost/cobalt/experimental/frame.hpp>
#include <boost/cobalt/config.hpp>
#include <coroutine>
#include <new>

namespace boost::cobalt::experimental
{

// Forward declaration: detail::context_frame::get_context() below names
// context<...> before the full definition appears later in this header.
template<typename, typename ...>
struct context;

namespace detail
{

/// Frame object for a coroutine that runs on a boost::context::fiber.
///
/// Derives from frame<context_frame, Promise> and holds the two fibers that
/// are switched between: resume() enters `callee`, do_await() switches back
/// to `caller`.
template<typename Promise>
struct context_frame : frame<context_frame<Promise>,  Promise>
{
  // `callee` is the fiber entered by resume(); `caller` is the fiber that
  // do_await() switches back to when the awaited operation is not ready.
  boost::context::fiber caller, callee;

  // Trampoline and its awaitable argument, recorded by do_await(). resume()
  // runs the trampoline after the switch out of `callee` has returned, so
  // await_suspend() is invoked from the resuming side.
  void (*after_resume)(context_frame *, void *) = nullptr;
  void * after_resume_p = nullptr;

  /// Selected when Promise is constructible from the given arguments; they
  /// are forwarded to the frame base with their original value category
  /// (the same category the constraint is checked with).
  template<typename ... Args>
    requires std::constructible_from<Promise, Args...>
  context_frame(Args && ... args) : frame<context_frame, Promise>(std::forward<Args>(args)...) {}

  /// Selected when Promise cannot be built from the arguments but is default
  /// constructible; the arguments are ignored.
  template<typename ... Args>
    requires (!std::constructible_from<Promise, Args...> && std::is_default_constructible_v<Promise>)
  context_frame(Args && ...) {}

  /// Switch into the coroutine fiber. When control comes back, run the
  /// pending await_suspend trampoline, if do_await() registered one.
  void resume()
  {
    callee = std::move(callee).resume();
    // Clear the slot before calling: the trampoline may resume this frame again.
    if (auto af = std::exchange(after_resume, nullptr))
      af(this, after_resume_p);
  }

  /// Destroy the frame.
  ///
  /// The callee fiber is moved into a local first, so it is destroyed only
  /// after the frame's own destructor has finished (make_context constructs
  /// the frame inside the memory handed to that fiber).
  void destroy()
  {
    auto c = std::exchange(callee, {});
    this->~context_frame();
  }

  /// Trampoline factory for awaitables whose await_suspend returns void:
  /// just call await_suspend with a handle to this frame.
  template<typename Awaitable>
  auto do_resume(void * )
  {
    return +[](context_frame * this_, void * p)
    {
      auto aw_ = static_cast<Awaitable*>(p);
      auto h = std::coroutine_handle<Promise>::from_address(this_) ;
      aw_->await_suspend(h);
    };
  }

  /// Trampoline factory for awaitables whose await_suspend returns bool:
  /// a `false` result means "do not suspend", so the frame is resumed at once.
  template<typename Awaitable>
  auto do_resume(bool * )
  {
    return +[](context_frame * this_, void * p)
    {
      auto aw_ = static_cast<Awaitable*>(p);
      auto h = std::coroutine_handle<Promise>::from_address(this_) ;
      if (!aw_->await_suspend(h))
        h.resume();
    };
  }

  /// Trampoline factory for awaitables whose await_suspend returns another
  /// coroutine handle: that handle is resumed.
  template<typename Awaitable, typename Promise_>
  auto do_resume(std::coroutine_handle<Promise_> * )
  {
    return +[](context_frame * this_, void * p)
    {
      auto aw_ = static_cast<Awaitable*>(p);
      auto h = std::coroutine_handle<Promise>::from_address(this_) ;
      aw_->await_suspend(h).resume();
    };
  }


  /// Await `aw` from inside the coroutine fiber.
  ///
  /// If the awaitable is not ready, a trampoline matching the return type of
  /// its await_suspend is registered and control switches to `caller`;
  /// execution continues here once the frame is resumed.
  /// @return the result of aw.await_resume().
  template<typename Awaitable>
  auto do_await(Awaitable aw)
  {
    if (!aw.await_ready())
    {
      after_resume_p = & aw;
      // The null pointer only carries the await_suspend return type, which
      // selects the do_resume overload.
      after_resume = do_resume<Awaitable>(
          static_cast<decltype(aw.await_suspend(std::declval<std::coroutine_handle<Promise>>()))*>(nullptr)
          );
      caller = std::move(caller).resume();
    }
    return aw.await_resume();
  }

  /// Create a context<Handle, Args...> referring to this frame.
  template<typename Handle, typename ... Args>
  context<Handle, Args...> get_context()
  {
    return context<Handle, Args...>{this};
  }

};


template<typename Traits, typename Promise, typename ... Args>
struct stack_allocator : boost::context::fixedsize_stack {};

template<typename Traits, typename Promise, typename ... Args>
requires requires {Promise::operator new(std::size_t{});}
struct stack_allocator<Traits, Promise, Args...>
{
  boost::context::stack_context allocate()
  {
    const auto size = Traits::default_size();
    const auto p = Promise::operator new(size);

    boost::context::stack_context sctx;
    sctx.size = size;
    sctx.sp = static_cast< char * >( p) + sctx.size;
    return sctx;
  }

  void deallocate( boost::context::stack_context & sctx) noexcept
  {
    void * vp = static_cast< char * >( sctx.sp) - sctx.size;
    Promise::operator delete(vp, sctx.size);
  }
};

template<typename Traits, typename Promise, typename ... Args>
  requires requires {Promise::operator new(std::size_t{}, std::decay_t<Args&>()...);}
struct stack_allocator<Traits, Promise, Args...>
{
  std::tuple<Args&...> args;

  boost::context::stack_context allocate()
  {
    const auto size = Traits::default_size();
    const auto p = std::apply(
        [size](auto & ... args_)
        {
          return Promise::operator new(size, args_...);
        }, args);

    boost::context::stack_context sctx;
    sctx.size = size;
    sctx.sp = static_cast< char * >( p) + sctx.size;
    return sctx;
  }

  void deallocate( boost::context::stack_context & sctx) noexcept
  {
    void * vp = static_cast< char * >( sctx.sp) - sctx.size;
    Promise::operator delete(vp, sctx.size);
  }
};


// Helper base for the has_await_transform detection: it contributes a member
// named `await_transform` that is callable only with the private tag `dummy`.
// The member is declared but never defined; it is used in unevaluated
// context only.
struct await_transform_base
{
  struct dummy {};
  void await_transform(dummy);
};

// Derives from both await_transform_base and T. If T also declares a member
// named `await_transform`, looking that name up in this type finds it in two
// different bases and is therefore ambiguous.
template<typename T>
struct await_transform_impl : await_transform_base, T
{
};

// Satisfied when `p.await_transform(dummy{})` is NOT a valid expression on
// await_transform_impl<T>. The call is valid only while await_transform_base
// is the sole provider of the name, so the negation holds when T brings its
// own `await_transform` member (any signature) and the lookup is ambiguous.
template<typename T>
concept has_await_transform = ! requires (await_transform_impl<T> & p) {p.await_transform(await_transform_base::dummy{});};

/// Run the coroutine body when it yields no value.
///
/// Calls `func(ctx, args...)` with the callable and the arguments keeping
/// their original value category, then signals completion to the promise
/// through return_void().
template<typename Promise, typename Context, typename Func, typename ... Args>
void do_return(std::true_type /* is_void */, Promise& promise, Context ctx, Func && func, Args && ... args)
{
  // static_cast<T&&> is the spelled-out form of std::forward<T>.
  static_cast<Func &&>(func)(ctx, static_cast<Args &&>(args)...);

  promise.return_void();
}

/// Run the coroutine body when it produces a value.
///
/// Calls `func(ctx, args...)` with the callable and the arguments keeping
/// their original value category and hands the result straight to the
/// promise's return_value().
template<typename Promise, typename Context, typename Func, typename ... Args>
void do_return(std::false_type /* is_void */, Promise& promise, Context ctx, Func && func, Args && ... args)
{
  // The call stays nested so the result is passed on without a named copy.
  promise.return_value(
      static_cast<Func &&>(func)(ctx, static_cast<Args &&>(args)...)
  );
}


}

/// Handle given to the body of a fiber-backed coroutine.
///
/// It refers to the coroutine's detail::context_frame and offers `await` and
/// `yield` member functions that take the place of `co_await` / `co_yield`.
/// `Return` and `Args...` select the promise type through
/// std::coroutine_traits.
template<typename Return, typename ... Args>
struct context
{
  using return_type = Return;
  using promise_type = typename std::coroutine_traits<Return, Args...>::promise_type;

  /// Access the promise stored in the frame.
        promise_type & promise()       {return frame_->promise;}
  const promise_type & promise() const {return frame_->promise;}

  /// Convert to any other context specialization that uses the same
  /// promise type; both refer to the same frame.
  template<typename Return_, typename ... Args_>
    requires std::same_as<promise_type, typename context<Return_, Args_...>::promise_type>
  constexpr operator context<Return_, Args_...>() const
  {
    return {frame_};
  }

  /// Await through the promise's await_transform when the transformed value
  /// is directly awaitable.
  template<typename Awaitable>
    requires (detail::has_await_transform<promise_type> &&
        requires (promise_type & pro, Awaitable && aw)
        {
          {pro.await_transform(std::forward<Awaitable>(aw))} -> awaitable_type<promise_type>;
        })
  auto await(Awaitable && aw)
  {
    return frame_->do_await(frame_->promise.await_transform(std::forward<Awaitable>(aw)));
  }

  /// Await through await_transform when the transformed value has a member
  /// `operator co_await`. The constraint mirrors the expression in the body:
  /// await_transform is applied first, `operator co_await` to its result.
  template<typename Awaitable>
    requires (detail::has_await_transform<promise_type> &&
        requires (promise_type & pro, Awaitable && aw)
        {
          {pro.await_transform(std::forward<Awaitable>(aw)).operator co_await()} -> awaitable_type<promise_type>;
        })
  auto await(Awaitable && aw)
  {
    return frame_->do_await(frame_->promise.await_transform(std::forward<Awaitable>(aw)).operator co_await());
  }

  /// Await through await_transform when a free `operator co_await` applies
  /// to the transformed value.
  template<typename Awaitable>
    requires (detail::has_await_transform<promise_type> &&
        requires (promise_type & pro, Awaitable && aw)
        {
          {operator co_await(pro.await_transform(std::forward<Awaitable>(aw)))} -> awaitable_type<promise_type>;
        })
  auto await(Awaitable && aw)
  {
    return frame_->do_await(operator co_await(frame_->promise.await_transform(std::forward<Awaitable>(aw))));
  }

  /// Await a value that is directly awaitable; the promise has no
  /// await_transform.
  template<awaitable_type<promise_type> Awaitable>
    requires (!detail::has_await_transform<promise_type> )
  auto await(Awaitable && aw)
  {
    return frame_->do_await(std::forward<Awaitable>(aw));
  }

  /// Await via a free `operator co_await`; the promise has no
  /// await_transform.
  template<typename Awaitable>
    requires (!detail::has_await_transform<promise_type>
          && requires (Awaitable && aw) {{operator co_await(std::forward<Awaitable>(aw))} -> awaitable_type<promise_type>;})
  auto await(Awaitable && aw)
  {
    return frame_->do_await(operator co_await(std::forward<Awaitable>(aw)));
  }

  /// Await via a member `operator co_await`; the promise has no
  /// await_transform.
  template<typename Awaitable>
    requires (!detail::has_await_transform<promise_type>
            && requires (Awaitable && aw) {{std::forward<Awaitable>(aw).operator co_await()} -> awaitable_type<promise_type>;})
  auto await(Awaitable && aw)
  {
    return frame_->do_await(std::forward<Awaitable>(aw).operator co_await());
  }

  /// Equivalent of `co_yield value`: awaits promise.yield_value(value).
  /// @return the result of the awaitable's await_resume(), i.e. the value a
  ///         `co_yield` expression would have.
  template<typename Yield>
    requires requires (promise_type & pro, Yield && value) {{pro.yield_value(std::forward<Yield>(value))} -> awaitable_type<promise_type>;}
  auto yield(Yield && value)
  {
    return frame_->do_await(frame_->promise.yield_value(std::forward<Yield>(value)));
  }

 private:

  // Only context_frame::get_context() and other context specializations
  // (for the conversion operator) create contexts.
  context(detail::context_frame<promise_type> * frame) : frame_(frame) {}
  template<typename, typename ...>
  friend struct context;

  friend struct detail::context_frame<promise_type>;

  // Non-owning pointer to the coroutine's frame.
  detail::context_frame<promise_type> * frame_;
};

/// Create a fiber-backed coroutine that runs `func(ctx, args...)`.
///
/// A stack is obtained from `salloc`, a detail::context_frame for the promise
/// type of std::coroutine_traits<Return, Args...> is constructed at the top
/// of that stack, and the remaining stack is handed to a
/// boost::context::fiber whose entry point invokes `func`.
///
/// @param func    callable invoked as func(context<Return, Args...>, args...)
/// @param salloc  stack allocator providing allocate(); also passed on to the fiber
/// @param args    arguments offered to the promise constructor and passed to `func`
/// @return the object produced by promise.get_return_object()
template<typename Return, typename ... Args, std::invocable<context<Return, Args...>, Args...> Func, typename StackAlloc>
auto make_context(Func && func, std::allocator_arg_t, StackAlloc  && salloc, Args && ... args)
{
  auto sctx_ = salloc.allocate();

  using promise_type = typename std::coroutine_traits<Return, Args...>::promise_type;
  // Reserve room for the frame directly below the stack pointer reported by
  // the allocator; `sz` is what remains for the fiber itself.
  void * p = static_cast<char*>(sctx_.sp) - sizeof(detail::context_frame<promise_type>);
  auto sz = sctx_.size - sizeof(detail::context_frame<promise_type>);

  // Move the frame address down to the next address that satisfies the
  // frame's alignment.
  if (auto diff = reinterpret_cast<std::uintptr_t>(p) % alignof(detail::context_frame<promise_type>); diff != 0u)
  {
    p = static_cast<char*>(p) - diff;
    sz -= diff;
  }

  boost::context::preallocated psc{p, sz, sctx_};
  // The arguments are passed as lvalues here; they are forwarded into the
  // invoker further down.
  // NOTE(review): nothing in this function releases the stack or the frame if
  // a later step (get_return_object, fiber construction) throws - confirm.
  auto f = new (p) detail::context_frame<promise_type>(args...);

  auto res = f->promise.get_return_object();

  // True when initial_suspend() returns exactly std::suspend_always and is
  // noexcept; in that case the fiber is not entered before the first resume.
  constexpr auto is_always_lazy =
        requires (promise_type & pro) {{pro.initial_suspend()} -> std::same_as<std::suspend_always>;}
      && noexcept(f->promise.initial_suspend());

  // Entry point of the fiber: stores the callable and its arguments and runs
  // the coroutine protocol around the call.
  struct invoker
  {
    detail::context_frame<promise_type> * frame;
    mutable Func func;
    mutable std::tuple<Args...> args;

    invoker(detail::context_frame<promise_type> * frame, Func && func, Args && ...  args)
        : frame(frame), func(std::forward<Func>(func)), args(std::forward<Args>(args)...)
    {
    }

    boost::context::fiber operator()(boost::context::fiber && f) const
    {
      auto & promise = frame->promise;
      // Remember who entered the fiber so do_await() can switch back.
      frame->caller = std::move(f);

      try
      {
        // In the always-lazy case make_context() has already called
        // initial_suspend() itself.
        if (!is_always_lazy)
          frame->do_await(promise.initial_suspend());

        std::apply(
            [&](auto && ... args_)
            {
              auto ctx = frame->template get_context<Return, Args...>();
              using return_type = decltype(std::forward<Func>(func)(ctx, std::forward<Args>(args_)...));

              // Dispatch to return_void() or return_value() depending on
              // what the callable returns.
              detail::do_return(std::is_void<return_type>{}, frame->promise, ctx,
                                std::forward<Func>(func), std::forward<Args>(args_)...);
            },
            std::move(args));
      }
      // forced_unwind is rethrown unchanged instead of being reported to the
      // promise.
      catch (boost::context::detail::forced_unwind &) { throw; }
      catch (...) {promise.unhandled_exception();}

      static_assert(noexcept(promise.final_suspend()));
      // NOTE(review): if final_suspend() is ready immediately, the fiber
      // function returns without frame->destroy() having run - confirm the
      // frame/promise lifetime in that case.
      frame->do_await(promise.final_suspend());
      return std::move(frame->caller);
    }

  };

  f->callee = boost::context::fiber{
    std::allocator_arg, psc, std::forward<StackAlloc>(salloc),
    invoker(f, std::forward<Func>(func), std::forward<Args>(args)...)};

  // Lazy start: only evaluate initial_suspend(). Otherwise enter the fiber
  // now, which awaits initial_suspend() from inside the invoker.
  if constexpr (is_always_lazy)
    f->promise.initial_suspend();
  else
    f->resume();

  return res;
}

/// Convenience overload without an allocator argument: delegates to the
/// allocator-taking make_context with a default-constructed
/// boost::context::fixedsize_stack.
template<typename Return, typename ... Args, std::invocable<context<Return, Args...>, Args...> Func>
auto make_context(Func && func, Args && ... args)
{
  boost::context::fixedsize_stack default_stack;

  // Handed over as an rvalue, so the allocator type is deduced as before.
  return make_context<Return>(
      static_cast<Func &&>(func),
      std::allocator_arg,
      std::move(default_stack),
      static_cast<Args &&>(args)...);
}


template<typename ... Args, typename Func, typename StackAlloc>
auto make_context(Func && func, std::allocator_arg_t, StackAlloc  && salloc, Args && ... args)
{
  return make_context<typename std::tuple_element_t<0u, callable_traits::args_t<Func>>::return_type>(
      std::forward<Func>(func), std::allocator_arg, std::forward<StackAlloc>(salloc), std::forward<Args>(args)...
      );
}


template<typename ... Args, typename Func>
auto make_context(Func && func, Args && ... args)
{
  return make_context<typename std::tuple_element_t<0u, callable_traits::args_t<Func>>::return_type>(
      std::forward<Func>(func), std::forward<Args>(args)...
  );
}

}

#endif //BOOST_COBALT_EXPERIMENTAL_CONTEXT_HPP