C++20尝鲜-协程(二)

C++20尝鲜-协程(二)

上文中我们简单介绍了c++20中协程的用法,本文参考举几个应用例子来加深下c++20中协程的使用。

编译器版本如下:

gcc version 11.1.0 (Ubuntu 11.1.0-1ubuntu1~20.04)

实现类似python生成器功能

简单介绍下python的生成器,在python中协程通过yield实现,当调用存在yield的函数时,函数中的逻辑实际上并没有马上执行,而是只返回了一个生成器对象。而后每次使用这个生成器的时候,函数才被实际执行。

模仿python生成器的逻辑,用c++20去实现并用来生成斐波那契数列来作为例子,代码如下:

#include <coroutine>
#include <iostream>

/**
 * @brief   `std::suspend_always`(initial) + `std::suspend_always`(final)
 * @ingroup Return
 */
class promise_aa {
public:
    /**
     * @brief suspend after invoke
     * @return std::suspend_always
     */
    std::suspend_always initial_suspend() noexcept {
        return {};
    }
    /**
     * @brief suspend after return
     * @return std::suspend_always
     */
    std::suspend_always final_suspend() noexcept {
        return {};
    }
};

template <typename T>
class enumerable {
public:
    class promise_type;
    class iterator;

    using value_type = T;
    using reference = value_type&;
    using pointer = value_type*;

private:
    std::coroutine_handle<promise_type> coro{};

public:
    enumerable(const enumerable&) = delete;
    enumerable& operator=(const enumerable&) = delete;
    enumerable(enumerable&& rhs) noexcept : coro{ rhs.coro } {
        rhs.coro = nullptr;
    }
    enumerable& operator=(enumerable&& rhs) noexcept {
        std::swap(coro, rhs.coro);
        return *this;
    }
    enumerable() noexcept = default;
    explicit enumerable(std::coroutine_handle<promise_type> rh) noexcept : coro{ rh } {
    }
    /**
     * @brief   The type will destroy the frame in destructor
     *          So promise/iterator are free from those ownership control
     */
    ~enumerable() noexcept {
        if (coro)
            coro.destroy();
    }

public:
    iterator begin() noexcept(false) {
        if (coro) // resumeable?
        {
            coro.resume();
            if (coro.done()) // finished?
                return iterator{ nullptr };
        }
        return iterator{ coro };
    }
    iterator end() noexcept {
        return iterator{ nullptr };
    }

public:
    class promise_type final : public promise_aa {
        friend class iterator;
        friend class enumerable;

        pointer current = nullptr;

    public:
        /**
         * @brief create coroutine handle from current promise's address
         */
        enumerable get_return_object() noexcept {
            return enumerable{
                std::coroutine_handle<promise_type>::from_promise(*this) };
        }
        void unhandled_exception() noexcept(false) {
            throw;
        }
        /// @brief  `co_yield` expression. for reference
        auto yield_value(reference ref) noexcept {
            current = std::addressof(ref);
            return std::suspend_always{};
        }
        /// @brief  `co_yield` expression. for r-value
        auto yield_value(value_type&& v) noexcept {
            return yield_value(v);
        }
        /// @brief

        /**
         * @brief `co_return` expression. There should be no more access to the value.
         */
        void return_void() noexcept {
            current = nullptr;
        }
    };

    class iterator final {
    public:
        using iterator_category = std::forward_iterator_tag;
        using difference_type = ptrdiff_t;
        using value_type = T;
        using reference = value_type&;
        using pointer = value_type*;

    public:
        std::coroutine_handle<promise_type> coro;

    public:
        /// @see enumerable::end()
        explicit iterator(std::nullptr_t) noexcept : coro{ nullptr } {
        }
        /// @see enumerable::begin()
        explicit iterator(std::coroutine_handle<promise_type> handle) noexcept
            : coro{ handle } {
        }

    public:
        /// @brief post increment is prohibited
        iterator& operator++(int) = delete;
        iterator& operator++() noexcept(false) {
            coro.resume();
            if (coro.done())    // enumerable will destroy
                coro = nullptr; // the frame later...
            return *this;
        }

        pointer operator->() noexcept {
            pointer ptr = coro.promise().current;
            return ptr;
        }
        reference operator*() noexcept {
            return *(this->operator->());
        }

        bool operator==(const iterator& rhs) const noexcept {
            return this->coro == rhs.coro;
        }
        bool operator!=(const iterator& rhs) const noexcept {
            return !(*this == rhs);
        }
    };
};

auto Fibonacci(int32_t max)->enumerable<int32_t>
{
	int32_t a = 0;
	int32_t b = 1;
	for (int32_t fib = b; fib < max; fib = a + b, a = b, b = fib)
	{
		co_yield fib;
	}
}

int main()
{
	auto g = Fibonacci(100);
	for (auto iter = g.begin(); iter != g.end(); ++iter)
	{
		std::cout << *iter << " ";
	}
	std::cout << std::endl;
}

结果打印如下:

root@DESKTOK:/mnt/d/Code/os# ./a.out
1 1 2 3 5 8 13 21 34 55 89

Fibonacci函数用来打印小于max的所有斐波那契序列,函数返回一个enumerable<int32_t>对象,我们知道它实际上是一个协程句柄。函数的实际执行流程如下:

  1. 调用Fibonacci(100)得到enumerable<int32_t>对象,因为initial_suspend返回为std::suspend_always,所以函数实际为暂停状态,还未开始运行
  2. 遍历enumerable<int32_t>对象,调用begin()获取迭代器对象,并在begin()函数中恢复Fibonacci函数的执行
  3. Fibonacci函数中恢复执行,继续运行``co_yield fib,调用promise_typeauto yield_value(reference ref)将缓存fib的地址至current对象,然后暂停并返回到main`函数
  4. main函数继续遍历,通过迭代器的operator ++重载控制Fibonacci函数的切入恢复,继续又继续运行暂停
  5. 直到遍历完成,函数返回,coro.done()trueiter == g.end()成立,main函数循环也结束

eventfd事件通知

一般epoll在代码中的调用流程为

  1. 创建epoll
  2. 注册关心的事件至epoll
  3. 调用epoll_wait等待关心事件的发生
  4. 关心事件发生后的处理

我们可以用协程的方式实现上述的几个流程,以eventfd的写事件作为关心的事件,代码实现如下:

#include <coroutine>
#include <iostream>
#include <array>
#include <vector>
#include <assert.h>
#include <sys/epoll.h>
#include <unistd.h>
#include <sys/eventfd.h>
#include <string.h>

class promise_na {
public:
    /**
     * @brief no suspend after invoke
     * @return suspend_never
     */
    std::suspend_never initial_suspend() noexcept {
        return {};
    }
    /**
     * @brief suspend after return
     * @return suspend_always
     */
    std::suspend_always final_suspend() noexcept {
        return {};
    }
};

void notify_event(int64_t efd) noexcept(false) {
	write(efd, &efd, sizeof(efd));
}

void consume_event(int64_t efd) noexcept(false) {
	int32_t ret = 0;
    while (ret = read(efd, &efd, sizeof(efd)) > 0)
    {
        std::cout << "read size : " << ret << std::endl;
    }
}

class epoll_owner final {
    int64_t epfd;

public:
    /**
     * @brief create a fd with `epoll`. Throw if the function fails.
     * @see kqeueue
     * @throw system_error
     */
	epoll_owner() noexcept(false) :epfd{ epoll_create1(EPOLL_CLOEXEC) }
	{

	}
    /**
     * @brief close the current epoll file descriptor
     */
    ~epoll_owner() noexcept
    {
        close(epfd);
    }
    epoll_owner(const epoll_owner&) = delete;
    epoll_owner(epoll_owner&&) = delete;
    epoll_owner& operator=(const epoll_owner&) = delete;
    epoll_owner& operator=(epoll_owner&&) = delete;

public:
	void try_add(uint64_t fd, epoll_event& req) noexcept(false)
	{
		int op = EPOLL_CTL_ADD;
		epoll_ctl(epfd, op, fd, &req);
	}

    void reset(uint64_t fd, epoll_event& req)
    {
        int op = EPOLL_CTL_MOD;
        epoll_ctl(epfd, op, fd, &req);
    }

    ptrdiff_t wait(uint32_t wait_ms,
        std::vector<epoll_event>& output) noexcept(false)
    {
        auto count = epoll_wait(epfd, output.data(), output.size(), wait_ms);
        return count;
    }
};

class event final {
    int32_t m_fd;
    bool m_isSet;

public:
	event() noexcept(false) :m_isSet(false)
	{
		m_fd = ::eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC);
	}
	~event() noexcept
	{
		close(m_fd);
		m_fd = 0;
	}
    event(const event&) = delete;
    event(event&&) = delete;
    event& operator=(const event&) = delete;
    event& operator=(event&&) = delete;

    uint64_t fd() const noexcept
    {
        return m_fd;
    }
    bool is_set() const noexcept
    {
        return m_isSet;
    }
    void set() noexcept(false)
    {
        if (m_isSet)
        {
            return;
        }

        notify_event(m_fd);                          
        m_isSet = true;
    }
	void reset() noexcept(false)
	{
        if (m_isSet)
        {
            consume_event(m_fd);
        }
		m_isSet = false;
	}
};

auto wait_in(epoll_owner& ep, event& efd) {
    class awaiter : epoll_event {
        epoll_owner& ep;
        event& efd;

    public:
        /**
         * @brief Prepares one-time registration
         */
        awaiter(epoll_owner& _ep, event& _efd) noexcept
            : epoll_event{}, ep{ _ep }, efd{ _efd } {
            this->events = EPOLLET | EPOLLIN | EPOLLONESHOT;
        }

        bool await_ready() const noexcept {
            return efd.is_set();
        }
        /**
         * @brief Wait for `write` to given `eventfd`
         */
        void await_suspend(std::coroutine_handle<void> coro) noexcept(false) {
            this->data.ptr = coro.address();
            static bool added = false;
            if (!added)
            {
                added = true;
                return ep.try_add(efd.fd(), *this);
            }
            else
            {
                return ep.reset(efd.fd(), *this);
            }
        }
        /**
         * @brief Reset the given event object when resumed
         */
        void await_resume() noexcept {
            return efd.reset();
        }
    };
    return awaiter{ ep, efd };
}

class frame_t : public std::coroutine_handle<void> {
public:
    /**
     * @brief Acquire `coroutine_handle<void>` from current object and expose it through `get_return_object`
     */
    class promise_type : public promise_na {
    public:
        /**
         * @brief The `frame_t` will do nothing for exception handling
         */
        void unhandled_exception() noexcept(false) {
            throw;
        }
        void return_void() noexcept {
        }
        /**
         * @brief Acquire `coroutine_handle<void>` from current promise and return it
         * @return frame_t
         */
        frame_t get_return_object() noexcept {
            return frame_t{ coroutine_handle<promise_type>::from_promise(*this) };
        }
    };

public:
    explicit frame_t(coroutine_handle<void> frame = nullptr) noexcept
        : std::coroutine_handle<void>{ frame } {
    }
};

auto wait_for_multiple_times(epoll_owner& ep, event& efd, //
    uint32_t counter) -> frame_t {
    while (counter--)
        co_await wait_in(ep, efd);
}

void resume_signaled_tasks(epoll_owner& ep) {
    std::vector<epoll_event> events{10};
    auto count = ep.wait(1000, events); // wait for 1 sec
	if (count == 0)
	{
		std::cout << "errno : " << errno << " , msg : " << strerror(errno) << std::endl;
		return;
	}

	std::for_each(events.begin(), events.begin() + count, [](epoll_event& e) {
		auto coro = std::coroutine_handle<void>::from_address(e.data.ptr);
		coro.resume();
		});
}

int main()
{
    epoll_owner ep{};
    event e1 {};

    wait_for_multiple_times(ep, e1, 6); // 6 await
    auto repeat = 8u;                   // + 2 timeout
    while (repeat--) {
        e1.set();
        // resume if there is available event-waiting coroutines
        resume_signaled_tasks(ep);
    };
    return 0;
}
  1. 首先调用协程函数wait_for_multiple_times,它返回frame_t协程句柄,函数执行暂停在co_await wait_in(ep, efd);,并执行await_suspendeventfd事件注册至epoll
  2. 回到主函数,执行while循环,并触发eventfd的写事件,接着调用resume_signaled_tasks函数,并在函数中通过ep.wait执行epoll_wait,获知eventfd写事件发生,继而恢复wait_for_multiple_times的执行
  3. 以此类推,wait_for_multiple_times共执行6次

参考资料

https://github.com/luncliff/coroutine

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值