基于 C++20 coroutine 实现对称协程

这是我很久以前研究 C++20 coroutine 时的记录。如今回看已有些生疏,故整理成文,用于记录和复习。

直接上代码

#include <coroutine>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
using namespace std;


class symmetricTask {
public:
	class promise_type;
	// 代表这个协程函数的handle
	std::coroutine_handle<promise_type> handle;
	class promise_type {
	public:
		// 调用该协程的父协程的handle
		std::coroutine_handle<promise_type> father;
		class initCallAwait;
		class finalAwait;
		class normalCallAwait;
		promise_type()
		{
		}
		~promise_type()
		{
		}
		void return_void() {}
		auto get_return_object()
		{
			return symmetricTask{ coroutine_handle<promise_type>::from_promise(*this) };
		}
		auto initial_suspend()
		{
			return suspend_always{};
		}
		finalAwait final_suspend() noexcept {
			return {};
		}
		void unhandled_exception()
		{
			std::exit(1);
		}
		initCallAwait await_transform(symmetricTask&& t) {
			return initCallAwait(t.handle);
		}
		normalCallAwait await_transform(int i) {
			cout << "transform int" << endl;
			return normalCallAwait{};
		}
		class initCallAwait {
		public:
			friend promise_type;
			std::coroutine_handle<promise_type> handle;
			explicit initCallAwait(std::coroutine_handle<promise_type> h) noexcept : handle(h) {}
			bool await_ready() {
				if (handle) {
					return false;
				}
				return true;
			}
			void await_resume() noexcept {
				cout << "init resume" << endl;
			}
			std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) noexcept {
				cout << "init suspend h:" << h.address() << " this->handle:" << handle.address() << endl;
				handle.promise().father = h;
				return handle;
			}
		};
		class normalCallAwait {
		public:
			friend promise_type;
			bool await_ready() {
				return false;
			}
			void await_resume() noexcept {
				cout << "normal resume" << endl;
			}
			void await_suspend(std::coroutine_handle<promise_type> h) noexcept;
		};
		class finalAwait {
		public:
			friend promise_type;
			std::coroutine_handle<promise_type> handle;
			bool await_ready() noexcept {
				return false;
			}
			void await_resume() noexcept {
				cout << "final resume" << endl;
			}
			std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) noexcept {
				if (h.promise().father) {
					cout << "final suspend. prepare resume" << h.promise().father.address() << endl;
					return h.promise().father;
				}
				cout << "final suspend. without father" << endl;
				return std::noop_coroutine();
			}
		};
	};
};

vector<std::coroutine_handle<symmetricTask::promise_type>> mgr;

// Suspension hook for `co_await <int>`: logs the suspending coroutine and parks
// its handle at the back of the global `mgr` queue. Returning void leaves the
// coroutine suspended, so control goes back to whoever resumed it.
void symmetricTask::promise_type::normalCallAwait::await_suspend(std::coroutine_handle<promise_type> h) noexcept {
	void* const frameAddress = h.address();
	std::cout << "normal suspend push_back h:" << frameAddress << std::endl;
	mgr.emplace_back(h);
}

// Innermost coroutine: logs, parks itself in `mgr` through `co_await <int>`,
// and finishes once main() resumes it from the queue.
symmetricTask b() noexcept {
	std::cout << "function b" << std::endl;
	constexpr int parkToken = 1;
	co_await parkToken;
	std::cout << "function b end" << std::endl;
}

// Middle coroutine: awaits b(), which links a() as b()'s father and jumps into
// b(); a() continues only after b() has run to completion.
symmetricTask a() noexcept
{
	std::cout << "function a" << std::endl;
	co_await b();
	std::cout << "function a end" << std::endl;
}

// Root coroutine of the example: it has no father, so when it finishes, its
// final awaiter returns noop_coroutine() and control goes back to main().
symmetricTask asyncCall() noexcept
{
	std::cout << "asyncCall" << std::endl;
	co_await a();
	std::cout << "asyncCall end" << std::endl;
}

// Entry point: start the root coroutine, then drain the `mgr` run queue.
//
// The queue is walked by index rather than with a range-for: a resumed
// coroutine that reaches another `co_await <int>` push_back()s into `mgr`,
// which would invalidate range-for iterators. Entries appended during the walk
// are resumed in the same pass, in FIFO order.
int main() {
	std::cout << "start" << std::endl;
	auto root = asyncCall();
	// initial_suspend is suspend_always, so nothing runs before this resume().
	root.handle.resume();
	std::cout << "start resume" << std::endl;
	for (std::size_t i = 0; i < mgr.size(); ++i) {
		// Copy the handle first: resume() may reallocate `mgr`.
		const auto parked = mgr[i];
		parked.resume();
	}
}

运行结果如下(输出中的地址每次运行都可能不同):

start
asyncCall
init suspend h:000002700D478740 this->handle:000002700D478880
function a
init suspend h:000002700D478880 this->handle:000002700D4789C0
function b
transform int
normal suspend push_back h:000002700D4789C0
start resume
normal resume
function b end
final suspend. prepare resume000002700D478880
init resume
function a end
final suspend. prepare resume000002700D478740
init resume
asyncCall end
final suspend. without father

下面我们把代码展开,看看编译器具体是如何实现的。为了更清楚地观察实现细节,我修改了上面 a、b、asyncCall 三个函数的代码:

// Variant of b() used for the expansion below: the locals i, j, k, y and x are
// only there to show where each one ends up in the generated frame struct.
symmetricTask b() noexcept {
	int i = 1;
  	if (1)
	{
		int j = 2;
		// Suspension point: await_transform(int) -> normalCallAwait, which
		// pushes this coroutine's handle into `mgr`.
		co_await 1;
		int k = 3;
	} else {
    	int y = 5;
    }
	int x = 4;
}

// Variant of a() without logging: it only awaits b() (via initCallAwait).
symmetricTask a() noexcept {
	co_await b();
}

// Variant of the root coroutine without logging: it only awaits a().
symmetricTask asyncCall() noexcept {
	co_await a();
}

// Drives the example: start the root coroutine, then resume everything that
// was parked in `mgr`.
int main() {
	auto t = asyncCall();
	// The body is lazy (initial_suspend), so it starts here and runs until b()
	// parks itself at `co_await 1`.
	t.handle.resume();
	// NOTE(review): resume() push_back()s into `mgr` whenever a coroutine hits
	// another `co_await <int>`, which would invalidate this range-for's
	// iterators. It is safe here only because each coroutine awaits an int at
	// most once.
	for (auto t : mgr) {
		t.resume();
	}
}

来看看它的内部实现吧~(以下代码由 C++ Insights 生成,只是编译器协程变换的近似展开,并非可直接编译的代码)

/*************************************************************************************
 * NOTE: The coroutine transformation you've enabled is a hand coded transformation! *
 *       Most of it is _not_ present in the AST. What you see is an approximation.   *
 *************************************************************************************/
#include <experimental/coroutine>
#include <vector>
#include <iostream>
#include <optional>
#include <memory>
using namespace std;


// C++ Insights' rendering of symmetricTask: the same class as in the first
// listing, written against std::experimental, with implicit conversions and
// member accesses spelled out. Most of the logging statements from the first
// listing are absent here.
class symmetricTask
{
  
  public: 
  class promise_type;
  // Handle of the coroutine frame this task refers to.
  std::experimental::coroutine_handle<promise_type> handle;
  class promise_type
  {
    
    public: 
    // Handle of the coroutine that co_awaited this one (null for the root).
    std::experimental::coroutine_handle<promise_type> father;
    class initCallAwait;
    class finalAwait;
    class normalCallAwait;
    inline promise_type()
    : father{std::experimental::coroutine_handle<promise_type>()}
    {
    }
    
    inline ~promise_type() noexcept
    {
    }
    
    inline void return_void()
    {
    }
    
    inline symmetricTask get_return_object()
    {
      return symmetricTask{std::experimental::coroutine_handle<promise_type>::from_promise(*this)};
    }
    
    // Lazy start: the body does not run until the handle is resumed.
    inline std::experimental::suspend_always initial_suspend()
    {
      return std::experimental::suspend_always{};
    }
    
    // `return {}` in the source becomes aggregate initialisation with a null handle.
    inline finalAwait final_suspend() noexcept
    {
      return {std::experimental::coroutine_handle<promise_type>{}};
    }
    
    inline void unhandled_exception()
    {
      exit(1);
    }
    
    // `co_await task` -> initCallAwait (run the child, then come back here).
    inline initCallAwait await_transform(symmetricTask && t)
    {
      return initCallAwait(std::experimental::coroutine_handle<promise_type>(t.handle));
    }
    
    // `co_await <int>` -> normalCallAwait (its await_suspend parks the handle in `mgr`).
    inline normalCallAwait await_transform(int i)
    {
      return normalCallAwait{};
    }
    
    class initCallAwait
    {
      
      public: 
      friend symmetricTask::promise_type;
      // Child coroutine to start.
      std::experimental::coroutine_handle<symmetricTask::promise_type> handle;
      inline explicit initCallAwait(std::experimental::coroutine_handle<symmetricTask::promise_type> h) noexcept
      : handle{std::experimental::coroutine_handle<symmetricTask::promise_type>(h)}
      {
      }
      
      // Suspend only when there is a child to run.
      inline bool await_ready()
      {
        if(static_cast<bool>(static_cast<const std::experimental::coroutine_handle<void>&>(this->handle).operator bool())) {
          return false;
        } 
        
        return true;
      }
      
      inline void await_resume() noexcept
      {
        std::operator<<(std::cout, "init resume").operator<<(std::endl);
      }
      
      // Store the awaiting coroutine as the child's father, then return the
      // child's handle so that it is resumed next.
      inline std::experimental::coroutine_handle<void> await_suspend(std::experimental::coroutine_handle<symmetricTask::promise_type> h) noexcept
      {
        this->handle.promise().father.operator=(h);
        return std::experimental::coroutine_handle<void>(static_cast<const std::experimental::coroutine_handle<void>&>(this->handle));
      }
      
    };
    
    class normalCallAwait
    {
      
      public: 
      friend symmetricTask::promise_type;
      inline bool await_ready()
      {
        return false;
      }
      
      inline void await_resume() noexcept
      {
      }
      
      // Defined out of line after `mgr`.
      void await_suspend(std::experimental::coroutine_handle<symmetricTask::promise_type> h) noexcept;
      
    };
    
    class finalAwait
    {
      
      public: 
      friend symmetricTask::promise_type;
      std::experimental::coroutine_handle<symmetricTask::promise_type> handle;
      // Always suspends, so a finished frame stays alive at its final suspend point.
      inline bool await_ready() noexcept
      {
        return false;
      }
      
      inline void await_resume() noexcept
      {
      }
      
      // Resume the father if there is one; otherwise noop_coroutine(), so control
      // returns to whoever resumed this coroutine.
      inline std::experimental::coroutine_handle<void> await_suspend(std::experimental::coroutine_handle<symmetricTask::promise_type> h) noexcept
      {
        if(static_cast<bool>(static_cast<const std::experimental::coroutine_handle<void>&>(h.promise().father).operator bool())) {
          return std::experimental::coroutine_handle<void>(static_cast<const std::experimental::coroutine_handle<void>&>(h.promise().father));
        } 
        
        return std::experimental::coroutine_handle<void>(static_cast<std::experimental::coroutine_handle<void> &&>(std::experimental::noop_coroutine()));
      }
      
    };
    
  };
  
  // inline constexpr symmetricTask(symmetricTask &&) noexcept = default;
};



// Global queue of coroutines parked by `co_await <int>` (the tool spells out the
// allocator argument and the value-initialisation explicitly).
std::vector<std::experimental::coroutine_handle<symmetricTask::promise_type>, std::allocator<std::experimental::coroutine_handle<symmetricTask::promise_type> > > mgr = std::vector<std::experimental::coroutine_handle<symmetricTask::promise_type>, std::allocator<std::experimental::coroutine_handle<symmetricTask::promise_type> > >();


// await_suspend for `co_await <int>`: append the suspending coroutine's handle
// to `mgr`. Returning void leaves it suspended.
void symmetricTask::promise_type::normalCallAwait::await_suspend(std::experimental::coroutine_handle<promise_type> h) noexcept {
	mgr.push_back(h);
}

// Coroutine frame generated for b(). It holds:
// - the resume and destroy entry points,
// - the promise,
// - the bookkeeping for the resume switch,
// - every local variable of b() (i, j, k, y and x live here rather than on the
//   native stack in this rendering),
// - one awaiter object per suspension point.
struct __bFrame
{
  void (*resume_fn)(__bFrame *);
  void (*destroy_fn)(__bFrame *);
  std::experimental::__coroutine_traits_sfinae<symmetricTask>::promise_type __promise;
  // Resume point __bResume jumps to on the next resume (0 = not started yet).
  int __suspend_index;
  // Set once initial_suspend has suspended; exceptions thrown before that are
  // rethrown to the caller instead of reaching unhandled_exception().
  bool __initial_await_suspend_called;
  int i;
  int j;
  int k;
  int y;
  int x;
  // Awaiter returned by initial_suspend().
  std::experimental::suspend_always __suspend_103_15;
  // Awaiter for `co_await 1` (produced by await_transform(int)).
  symmetricTask::promise_type::normalCallAwait __suspend_108_3;
  // Awaiter returned by final_suspend().
  symmetricTask::promise_type::finalAwait __suspend_103_15_1;
};

// Ramp function for b(): what a caller of b() actually executes. It:
// - allocates and initialises the frame,
// - constructs the promise and the returned symmetricTask,
// - runs the body up to initial_suspend (which always suspends),
// - hands the task back to the caller.
symmetricTask b() noexcept
{
  /* Allocate the frame including the promise */
  __bFrame * __f = reinterpret_cast<__bFrame *>(operator new(__builtin_coro_size()));
  __f->__suspend_index = 0;
  __f->__initial_await_suspend_called = false;
  
  /* Construct the promise. */
  new (&__f->__promise)std::experimental::__coroutine_traits_sfinae<symmetricTask>::promise_type{};
  
  symmetricTask __coro_gro = __f->__promise.get_return_object() /* NRVO variable */;
  
  /* Forward declare the resume and destroy function. */
  void __bResume(__bFrame * __f);
  void __bDestroy(__bFrame * __f);
  
  /* Assign the resume and destroy function pointers. */
  __f->resume_fn = &__bResume;
  __f->destroy_fn = &__bDestroy;
  
  /* Call the made up function with the coroutine body for initial suspend.
     This function will be called subsequently by coroutine_handle<>::resume()
     which calls __builtin_coro_resume(__handle_) */
  __bResume(__f);
  
  
  // The body has already suspended at initial_suspend (index 1), so the caller
  // receives a task whose body has not run yet.
  return __coro_gro;
}

/* This function invoked by coroutine_handle<>::resume() */
// Resume function for b(). It switches on __suspend_index to continue where
// the coroutine last stopped:
//   0 -> start of the body (reached once, from the ramp function b())
//   1 -> just after initial_suspend
//   2 -> just after `co_await 1`
void __bResume(__bFrame * __f)
{
  try 
  {
    /* Create a switch to get to the correct resume point */
    switch(__f->__suspend_index) {
      case 0: break;
      case 1: goto __resume_b_1;
      case 2: goto __resume_b_2;
    }
    
    /* co_await insights.cpp:103 */
    // initial_suspend() returns suspend_always: record index 1 and return.
    __f->__suspend_103_15 = __f->__promise.initial_suspend();
    if(!__f->__suspend_103_15.await_ready()) {
      __f->__suspend_103_15.await_suspend(std::experimental::coroutine_handle<void>(std::experimental::coroutine_handle<symmetricTask::promise_type>::from_address(static_cast<void *>(__f))));
      __f->__suspend_index = 1;
      __f->__initial_await_suspend_called = true;
      return;
    } 
    
    __resume_b_1:
    __f->__suspend_103_15.await_resume();
    __f->i = 1;
    if(1) {
      __f->j = 2;
      
      /* co_await insights.cpp:108 */
      // normalCallAwait::await_suspend pushes this frame's handle into `mgr` and
      // returns void, so b() just returns to whoever resumed it.
      __f->__suspend_108_3 = __f->__promise.await_transform(1);
      if(!__f->__suspend_108_3.await_ready()) {
        __f->__suspend_108_3.await_suspend(std::experimental::coroutine_handle<symmetricTask::promise_type>::from_address(static_cast<void *>(__f)));
        __f->__suspend_index = 2;
        return;
      } 
      
      __resume_b_2:
      __f->__suspend_108_3.await_resume();
      __f->k = 3;
    } else {
      __f->y = 5;
    } 
    
    __f->x = 4;
    goto __final_suspend;
  } catch(...) {
    if(!__f->__initial_await_suspend_called) {
      throw ;
    } 
    
    __f->__promise.unhandled_exception();
  }
  
  __final_suspend:
  
  /* co_await insights.cpp:103 */
  // Resume whatever finalAwait::await_suspend returns: the father coroutine,
  // or noop_coroutine() when b() has no father.
  __f->__suspend_103_15_1 = __f->__promise.final_suspend();
  if(!__f->__suspend_103_15_1.await_ready()) {
    __builtin_coro_resume(__f->__suspend_103_15_1.await_suspend(std::experimental::coroutine_handle<symmetricTask::promise_type>::from_address(static_cast<void *>(__f))).address());
  } 
  
  ;
}

/* This function invoked by coroutine_handle<>::destroy() */
// Destroys the frame's members and frees its storage.
// NOTE(review): nothing in this listing ever calls destroy() on a handle, so
// the frames of b(), a() and asyncCall() are never freed.
void __bDestroy(__bFrame * __f)
{
  /* destroy all variables with dtors */
  __f->~__bFrame();
  /* Deallocating the coroutine frame */
  operator delete(__builtin_coro_free(static_cast<void *>(__f)));
}



// Coroutine frame generated for a(): entry points, promise, resume bookkeeping,
// and one awaiter per suspension point (a() has no locals of its own).
struct __aFrame
{
  void (*resume_fn)(__aFrame *);
  void (*destroy_fn)(__aFrame *);
  std::experimental::__coroutine_traits_sfinae<symmetricTask>::promise_type __promise;
  int __suspend_index;
  bool __initial_await_suspend_called;
  // Awaiter returned by initial_suspend().
  std::experimental::suspend_always __suspend_116_15;
  // Awaiter for `co_await b()` (produced by await_transform(symmetricTask&&)).
  symmetricTask::promise_type::initCallAwait __suspend_117_2;
  // Awaiter returned by final_suspend().
  symmetricTask::promise_type::finalAwait __suspend_116_15_1;
};

// Ramp function for a(): same shape as b()'s. It allocates the frame, builds
// the promise and the returned task, runs up to initial_suspend, and returns.
symmetricTask a() noexcept
{
  /* Allocate the frame including the promise */
  __aFrame * __f = reinterpret_cast<__aFrame *>(operator new(__builtin_coro_size()));
  __f->__suspend_index = 0;
  __f->__initial_await_suspend_called = false;
  
  /* Construct the promise. */
  new (&__f->__promise)std::experimental::__coroutine_traits_sfinae<symmetricTask>::promise_type{};
  
  symmetricTask __coro_gro = __f->__promise.get_return_object() /* NRVO variable */;
  
  /* Forward declare the resume and destroy function. */
  void __aResume(__aFrame * __f);
  void __aDestroy(__aFrame * __f);
  
  /* Assign the resume and destroy function pointers. */
  __f->resume_fn = &__aResume;
  __f->destroy_fn = &__aDestroy;
  
  /* Call the made up function with the coroutine body for initial suspend.
     This function will be called subsequently by coroutine_handle<>::resume()
     which calls __builtin_coro_resume(__handle_) */
  __aResume(__f);
  
  
  return __coro_gro;
}

/* This function invoked by coroutine_handle<>::resume() */
// Resume function for a(). Resume points: 0 = start, 1 = after initial_suspend,
// 2 = after `co_await b()`.
void __aResume(__aFrame * __f)
{
  try 
  {
    /* Create a switch to get to the correct resume point */
    switch(__f->__suspend_index) {
      case 0: break;
      case 1: goto __resume_a_1;
      case 2: goto __resume_a_2;
    }
    
    /* co_await insights.cpp:116 */
    __f->__suspend_116_15 = __f->__promise.initial_suspend();
    if(!__f->__suspend_116_15.await_ready()) {
      __f->__suspend_116_15.await_suspend(std::experimental::coroutine_handle<void>(std::experimental::coroutine_handle<symmetricTask::promise_type>::from_address(static_cast<void *>(__f))));
      __f->__suspend_index = 1;
      __f->__initial_await_suspend_called = true;
      return;
    } 
    
    __resume_a_1:
    __f->__suspend_116_15.await_resume();
    
    /* co_await insights.cpp:117 */
    // initCallAwait::await_suspend stores a()'s handle as b()'s father and
    // returns b()'s handle, which is resumed right away.
    __f->__suspend_117_2 = __f->__promise.await_transform(b());
    if(!__f->__suspend_117_2.await_ready()) {
      // NOTE(review): this approximation resumes b() *before* recording
      // __suspend_index = 2. That only works here because b() suspends at
      // `co_await 1` before finishing. If b() ran to completion first, its final
      // awaiter would resume a() while the index was still 1, jumping back to
      // __resume_a_1. Per the standard, the coroutine is already suspended when
      // await_suspend runs, so real code generation records the resume point
      // first, and symmetric transfer is meant to avoid this nested call.
      __builtin_coro_resume(__f->__suspend_117_2.await_suspend(std::experimental::coroutine_handle<symmetricTask::promise_type>::from_address(static_cast<void *>(__f))).address());
      __f->__suspend_index = 2;
      return;
    } 
    
    __resume_a_2:
    __f->__suspend_117_2.await_resume();
    goto __final_suspend;
  } catch(...) {
    if(!__f->__initial_await_suspend_called) {
      throw ;
    } 
    
    __f->__promise.unhandled_exception();
  }
  
  __final_suspend:
  
  /* co_await insights.cpp:116 */
  // Hand control to the father (asyncCall) saved in the promise, if any.
  __f->__suspend_116_15_1 = __f->__promise.final_suspend();
  if(!__f->__suspend_116_15_1.await_ready()) {
    __builtin_coro_resume(__f->__suspend_116_15_1.await_suspend(std::experimental::coroutine_handle<symmetricTask::promise_type>::from_address(static_cast<void *>(__f))).address());
  } 
  
  ;
}

/* This function invoked by coroutine_handle<>::destroy() */
// Destroys a()'s frame members and frees its storage.
// NOTE(review): never reached in this listing, because no handle is destroyed.
void __aDestroy(__aFrame * __f)
{
  /* destroy all variables with dtors */
  __f->~__aFrame();
  /* Deallocating the coroutine frame */
  operator delete(__builtin_coro_free(static_cast<void *>(__f)));
}



// Coroutine frame generated for asyncCall(): entry points, promise, resume
// bookkeeping, and one awaiter per suspension point.
struct __asyncCallFrame
{
  void (*resume_fn)(__asyncCallFrame *);
  void (*destroy_fn)(__asyncCallFrame *);
  std::experimental::__coroutine_traits_sfinae<symmetricTask>::promise_type __promise;
  int __suspend_index;
  bool __initial_await_suspend_called;
  // Awaiter returned by initial_suspend().
  std::experimental::suspend_always __suspend_120_15;
  // Awaiter for `co_await a()`.
  symmetricTask::promise_type::initCallAwait __suspend_121_2;
  // Awaiter returned by final_suspend().
  symmetricTask::promise_type::finalAwait __suspend_120_15_1;
};

// Ramp function for asyncCall(), the root coroutine. It allocates the frame,
// builds the promise and the returned task, runs up to initial_suspend, and
// returns.
symmetricTask asyncCall() noexcept
{
  /* Allocate the frame including the promise */
  __asyncCallFrame * __f = reinterpret_cast<__asyncCallFrame *>(operator new(__builtin_coro_size()));
  __f->__suspend_index = 0;
  __f->__initial_await_suspend_called = false;
  
  /* Construct the promise. */
  new (&__f->__promise)std::experimental::__coroutine_traits_sfinae<symmetricTask>::promise_type{};
  
  symmetricTask __coro_gro = __f->__promise.get_return_object() /* NRVO variable */;
  
  /* Forward declare the resume and destroy function. */
  void __asyncCallResume(__asyncCallFrame * __f);
  void __asyncCallDestroy(__asyncCallFrame * __f);
  
  /* Assign the resume and destroy function pointers. */
  __f->resume_fn = &__asyncCallResume;
  __f->destroy_fn = &__asyncCallDestroy;
  
  /* Call the made up function with the coroutine body for initial suspend.
     This function will be called subsequently by coroutine_handle<>::resume()
     which calls __builtin_coro_resume(__handle_) */
  __asyncCallResume(__f);
  
  
  return __coro_gro;
}

/* This function invoked by coroutine_handle<>::resume() */
// Resume function for asyncCall(). Resume points: 0 = start,
// 1 = after initial_suspend, 2 = after `co_await a()`.
void __asyncCallResume(__asyncCallFrame * __f)
{
  try 
  {
    /* Create a switch to get to the correct resume point */
    switch(__f->__suspend_index) {
      case 0: break;
      case 1: goto __resume_asyncCall_1;
      case 2: goto __resume_asyncCall_2;
    }
    
    /* co_await insights.cpp:120 */
    __f->__suspend_120_15 = __f->__promise.initial_suspend();
    if(!__f->__suspend_120_15.await_ready()) {
      __f->__suspend_120_15.await_suspend(std::experimental::coroutine_handle<void>(std::experimental::coroutine_handle<symmetricTask::promise_type>::from_address(static_cast<void *>(__f))));
      __f->__suspend_index = 1;
      __f->__initial_await_suspend_called = true;
      return;
    } 
    
    __resume_asyncCall_1:
    __f->__suspend_120_15.await_resume();
    
    /* co_await insights.cpp:121 */
    __f->__suspend_121_2 = __f->__promise.await_transform(a());
    if(!__f->__suspend_121_2.await_ready()) {
      // NOTE(review): as rendered, a() is resumed before __suspend_index = 2 is
      // stored. This only works because the a()/b() chain suspends at
      // `co_await 1` before any final_suspend resumes this frame. A real
      // compiler records the resume point before calling await_suspend.
      __builtin_coro_resume(__f->__suspend_121_2.await_suspend(std::experimental::coroutine_handle<symmetricTask::promise_type>::from_address(static_cast<void *>(__f))).address());
      __f->__suspend_index = 2;
      return;
    } 
    
    __resume_asyncCall_2:
    __f->__suspend_121_2.await_resume();
    goto __final_suspend;
  } catch(...) {
    if(!__f->__initial_await_suspend_called) {
      throw ;
    } 
    
    __f->__promise.unhandled_exception();
  }
  
  __final_suspend:
  
  /* co_await insights.cpp:120 */
  // Root coroutine: father is null, so finalAwait returns noop_coroutine() and
  // control goes back to the resume() call in main().
  __f->__suspend_120_15_1 = __f->__promise.final_suspend();
  if(!__f->__suspend_120_15_1.await_ready()) {
    __builtin_coro_resume(__f->__suspend_120_15_1.await_suspend(std::experimental::coroutine_handle<symmetricTask::promise_type>::from_address(static_cast<void *>(__f))).address());
  } 
  
  ;
}

/* This function invoked by coroutine_handle<>::destroy() */
// Destroys asyncCall()'s frame members and frees its storage.
// NOTE(review): never reached in this listing, because `t` in main() does not
// destroy its handle.
void __asyncCallDestroy(__asyncCallFrame * __f)
{
  /* destroy all variables with dtors */
  __f->~__asyncCallFrame();
  /* Deallocating the coroutine frame */
  operator delete(__builtin_coro_free(static_cast<void *>(__f)));
}



// Expanded main(). The range-for over `mgr` is spelled out with explicit
// begin/end iterators; std::__wrap_iter is, presumably, libc++'s vector
// iterator wrapper.
int main()
{
  symmetricTask t = asyncCall();
  // First resume: runs asyncCall -> a -> b until b() parks itself in `mgr`.
  static_cast<std::experimental::coroutine_handle<void>&>(t.handle).resume();
  {
    std::vector<std::experimental::coroutine_handle<symmetricTask::promise_type>, std::allocator<std::experimental::coroutine_handle<symmetricTask::promise_type> > > & __range1 = mgr;
    // NOTE(review): __begin1/__end1 would be invalidated if a resumed coroutine
    // push_back()ed into `mgr`; each coroutine here awaits an int at most once.
    std::__wrap_iter<std::experimental::coroutine_handle<symmetricTask::promise_type> *> __begin1 = __range1.begin();
    std::__wrap_iter<std::experimental::coroutine_handle<symmetricTask::promise_type> *> __end1 = __range1.end();
    for(; std::operator!=(__begin1, __end1); __begin1.operator++()) {
      std::experimental::coroutine_handle<symmetricTask::promise_type> t = std::experimental::coroutine_handle<symmetricTask::promise_type>(__begin1.operator*());
      static_cast<std::experimental::coroutine_handle<void>&>(t).resume();
    }
    
  }
  return 0;
}

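简单总结:对于每一个 coroutine 函数,编译器都会生成一个专属的协程帧(frame)结构体,函数中的局部变量作为这个结构体的成员保存。在 C++ Insights 的近似展开里是全部局部变量;实际编译器通常只需要把跨越挂起点仍然存活的变量放进帧里。编译器再通过我们提供的 promise_type 和各个 awaiter 的接口(await_ready / await_suspend / await_resume)对这个结构体进行扩展。每个挂起点(包括 initial_suspend 和每个 co_await)都对应 switch 中的一个 case,恢复时根据 __suspend_index 跳回上次挂起的位置,这和第三方无栈协程用 switch-case 实现跳转的方式如出一辙。我在其中做的最关键的几步是:
1. 让每个协程在 initial_suspend 时先挂起;
2. co_await 子协程时,把当前(父)协程的 handle 存入子协程 promise 的 father 中,并由 await_suspend 返回子协程的 handle,直接切换到子协程执行;
3. 当最内层协程在 co_await 1 处挂起时,它的 handle 被保存到全局变量 mgr 中。由于各个父协程在切换之前就已经处于挂起状态,控制流会直接回到最外层的 resume() 调用处;
4. 当协程执行结束进入 final_suspend 时,finalAwait 通过保存的 father handle 直接恢复父协程。
这样就实现了协程之间的对称转移(symmetric transfer)。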

整个程序的执行流程如下图所示,希望可以帮助大家理解:
[流程图(原图缺失)]

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值