探究C++20协程(4)——协程中的调度器

协程本身并不能实现异步操作,它们需要依赖于调度器(Scheduler)的组件来实现异步操作。调度器负责管理协程的执行和调度。

调度器的抽象设计

为了实现协程的异步调度,我们需要提供调度器的一个抽象实现,来支持不同的调度逻辑。

class AbstractExecutor {
 public:
  virtual void execute(std::function<void()> &&func) = 0;
};

调度的时机

在C++20协程中,co_await表达式用于暂停当前协程的执行,并且可以进行异步操作。当异步操作完成后,协程将继续执行。这一过程的关键是所谓的Awaiter对象,它决定了何时以及如何恢复协程。

Awaiter必须实现以下三个成员函数:

  • await_ready():判断协程是否可以立即继续执行。如果返回true,则协程继续执行,不会真正挂起。
  • await_suspend(std::coroutine_handle<>):定义当协程挂起时执行的操作。这是调度器发挥作用的时机。
  • await_resume():定义当协程准备恢复执行时的返回值或行为。

通过自定义Awaiter,我们可以在await_suspend方法中集成一个调度器。这个调度器不必关心协程的具体执行内容,而是专注于何时以及在哪个上下文中恢复协程。其他两个函数都要求同步返回,是不能作为调度的时机。

实际上,按照 C++ 协程的设计,await_suspend 确实是用来提供调度支持的,由于这个时间点协程已经完全挂起,因此我们可以在任意一个线程上调用 handle.resume(),甚至不用担心线程安全的问题。

如果存在调度器,其实现的大致样子如下所示:

// 调度器的类型有多种,因此专门提供一个模板参数 Executor
template<typename Result, typename Executor>
struct TaskAwaiter {

  // 构造 TaskAwaiter 的时候传入调度器的具体实现
  explicit TaskAwaiter(AbstractExecutor *executor, Task<Result, Executor> &&task) noexcept
      : _executor(executor), task(std::move(task)) {}

  void await_suspend(std::coroutine_handle<> handle) noexcept {
    task.finally([handle, this]() {
      // 将 resume 函数的调用交给调度器执行
      _executor->execute([handle]() {
        handle.resume();
      });
    });
  }
  //...省略其余实现
 private:
  Task<Result, Executor> task;
  AbstractExecutor *_executor;
};

调度器的持有问题?(初次和挂起)

TaskAwaiter 当中的调度器实例是从外部传来的,这样设计的目的是希望把调度器的创建和绑定交给协程本身。换句话说,调度器应该属于协程。这样设计的好处就是协程内部的代码均会被调度到它对应的调度器上执行,可以确保逻辑的一致性和正确性。

调度器应该与 Task 或者 TaskPromise 绑定到一起。TaskPromise 作为协程的承诺类型(promise type),其中需要包含一个调度器实例。这个调度器决定了协程的运行方式,确保所有通过 co_await 挂起的任务在恢复时都能通过同一调度器执行。这里使用模板参数 Executor 来指定调度器的类型,使得每一个 Task 实例都绑定一个具体的 Executor 实例。

// 增加模板参数 Executor
template<typename ResultType, typename Executor>
struct TaskPromise {
  // 协程启动时也需要在恢复时实现调度
  DispatchAwaiter initial_suspend() { return DispatchAwaiter{&executor}; }

  std::suspend_always final_suspend() noexcept { return {}; }

  // Task 类型增加模板参数 Executor 可以方便创建协程时执行调度器的类型
  Task<ResultType, Executor> get_return_object() {
    return Task{std::coroutine_handle<TaskPromise>::from_promise(*this)};
  }

  // 注意模板参数
  template<typename _ResultType, typename _Executor>
  TaskAwaiter<_ResultType, _Executor> await_transform(Task<_ResultType, _Executor> &&task) {
    return TaskAwaiter<_ResultType, _Executor>(&executor, std::move(task));
  }

 private:

  Executor executor;
};

initial_suspend:在协程启动时触发,使用 DispatchAwaiter 将协程的第一次执行提交给绑定的调度器。这确保了协程即使在异步环境中也能保持执行的一致性(与co_wait 使用同一个调度器)。

await_transform:提供了对协程内部每一个 co_await 的处理。通过返回一个配置了当前调度器的 TaskAwaiter,保证了协程在每次挂起后都能通过相同的调度器恢复运行。

这里面的调度器都是由TaskPromise传入进去的指针指针,其对象实例是在TaskPromise里。

对于DispatchAwaiter实现如下,其协程内部的所有逻辑都可以顺利地调度到协程对应的调度器上。

struct DispatchAwaiter {

  explicit DispatchAwaiter(AbstractExecutor *executor) noexcept
      : _executor(executor) {}

  bool await_ready() const { return false; }

  void await_suspend(std::coroutine_handle<> handle) const {
    // 调度到协程对应的调度器上
    _executor->execute([handle]() {
      handle.resume();
    });
  }

  void await_resume() {}

 private:
  AbstractExecutor *_executor;
};

调度器的实现

接下来给出几种简单的调度器实现作为示例,了解其思想

NoopExecutor

最简单的调度器类型。它立即在当前线程中执行给定的任务,没有任何异步处理或线程切换。这种调度器适用于需要保证代码执行在同一线程上,或者在测试和调试阶段不希望引入额外的线程复杂性时使用。

class NoopExecutor : public AbstractExecutor {
 public:
  void execute(std::function<void()> &&func) override {
    func();
  }
};

对于Task 的改动不大,只是增加了模板参数 Executor:

template<typename ResultType, typename Executor = NewThreadExecutor>
struct Task {
    using promise_type = TaskPromise<ResultType, Executor>;
    ...
};

NewThreadExecutor

每次调用 execute 时都会创建一个新的线程来运行传递的任务。这种调度器适合于执行时间较长的任务,且任务之间几乎没有依赖,能够并行处理。由于每个任务都在新线程上执行(这个需要线程池来优化)

class NewThreadExecutor : public AbstractExecutor {
 public:
  void execute(std::function<void()> &&func) override {
    std::thread(func).detach();
  }
};

AsyncExecutor

使用std::async机制来调度协程。std::launch::async指定了异步执行。

class AsyncExecutor : public AbstractExecutor {
 public:
  void execute(std::function<void()> &&func) override {
    auto future = std::async(std::launch::async, func);
  }
};

LooperExecutor

是一个基于事件循环的单线程调度器,通常用于需要顺序处理任务的场景。内部使用一个线程来循环处理任务队列中的任务,通过同步机制(条件变量和互斥锁)管理任务的添加和执行。

class LooperExecutor : public AbstractExecutor {
 private:
  std::condition_variable queue_condition;
  std::mutex queue_lock;
  std::queue<std::function<void()>> executable_queue;

  // true 的时候是工作状态,如果要关闭事件循环,就置为 false
  std::atomic<bool> is_active;
  std::thread work_thread;

  // 处理事件循环
  void run_loop() {
    // 检查当前事件循环是否是工作状态,或者队列没有清空
    while (is_active.load(std::memory_order_relaxed) || !executable_queue.empty()) {
      std::unique_lock lock(queue_lock);
      if (executable_queue.empty()) {
        // 队列为空,需要等待新任务加入队列或者关闭事件循环的通知
        queue_condition.wait(lock);
        // 如果队列为空,那么说明收到的是关闭的通知
        if (executable_queue.empty()) {
          // 现有逻辑下此处用 break 也可
          // 使用 continue 可以再次检查状态和队列,方便将来扩展
          continue;
        }
      }
      // 取出第一个任务,解锁再执行。
      // 解锁非常:func 是外部逻辑,不需要锁保护;func 当中可能请求锁,导致死锁
      auto func = executable_queue.front();
      executable_queue.pop();
      lock.unlock();

      func();
    }
  }

 public:

  LooperExecutor() {
    is_active.store(true, std::memory_order_relaxed);
    work_thread = std::thread(&LooperExecutor::run_loop, this);
  }

  ~LooperExecutor() {
    shutdown(false);
    // 等待线程执行完,防止出现意外情况
    join();
  }

  void execute(std::function<void()> &&func) override {
    std::unique_lock lock(queue_lock);
    if (is_active.load(std::memory_order_relaxed)) {
      executable_queue.push(func);
      lock.unlock();
      // 通知队列,主要用于队列之前为空时调用 wait 等待的情况
      // 通知不需要加锁,否则锁会交给 wait 方导致当前线程阻塞
      queue_condition.notify_one();
    }
  }

  void shutdown(bool wait_for_complete = true) {
    // 修改后立即生效,在 run_loop 当中就能尽早(加锁前)就检测到 is_active 的变化
    is_active.store(false, std::memory_order_relaxed);
    if (!wait_for_complete) {    
      std::unique_lock lock(queue_lock);
      // 清空任务队列
      decltype(executable_queue) empty_queue;
      std::swap(executable_queue, empty_queue);
      lock.unlock();
    }

    // 通知 wait 函数,避免 Looper 线程不退出
    // 不需要加锁,否则锁会交给 wait 方导致当前线程阻塞
    queue_condition.notify_all();
  }

  void join() {
    if (work_thread.joinable()) {
      work_thread.join();
    }
  }
};
  • 当队列为空时,Looper 的线程通过 wait 来实现阻塞等待。
  • 有新任务加入时,通过 notify_one 来通知 run_loop 继续执行。

结果展示

测试结果代码如下:

// 这意味着每个恢复的位置都会通过 std::async 上执行
Task<int,AsyncExecutor> simple_task2() {
    debug("task 2 start ...");
    using namespace std::chrono_literals;
    std::this_thread::sleep_for(1s);
    debug("task 2 returns after 1s.");
    co_return 2;
}
// 这意味着每个恢复的位置都会新建一个线程来执行
Task<int,NewThreadExecutor> simple_task3() {
    debug("in task 3 start ...");
    using namespace std::chrono_literals;
    std::this_thread::sleep_for(2s);
    debug("task 3 returns after 2s.");
    co_return 3;
}
//这意味着每个恢复的位置都会在同一个线程上执行
Task<int, LooperExecutor> simple_task() {
    debug("task start ...");
    auto result2 = co_await simple_task2();
    debug("returns from task2: ", result2);
    auto result3 = co_await simple_task3();
    debug("returns from task3: ", result3);
    co_return 1 + result2 + result3;
}

int main() {
    auto simpleTask = simple_task();
    simpleTask.then([](int i) {
        debug("simple task end: ", i);
        }).catching([](std::exception& e) {
            //debug("error occurred", e.what());
        });
    try {
        auto i = simpleTask.get_result();
        debug("simple task end from get: ", i);
    }
    catch (std::exception& e) {
        //debug("error: ", e.what());
    }
    return 0;
}

结果如下:

37432 task start ...
39472 task 2 start ...
39472 task 2 returns after 1s.
37432 returns from task2:  2
33856 in task 3 start ...
33856 task 3 returns after 2s.
37432 returns from task3:  3
37432 simple task end:  6
4420 simple task end from get:  6
37432 run_loop exit.

解读如下:

  • 37432 task start,协程 simple_task 在 LooperExecutor 上开始执行的标志。
  • 39472 task 2 start,在 simple_task 中,第一个 co_await simple_task2() 被调用。因为 simple_task2 使用 AsyncExecutor,它启动了在新线程上的异步操作。
  • 39472 task 2 returns after 1s,simple_task2 完成执行,同样在 39472 线程上。表示 std::async 的操作已完成。
  • 37432 returns from task2: 2,控制返回到 simple_task 中,接收 simple_task2 的结果(返回值 2)。此操作仍然在 LooperExecutor 线程 (37432) 上执行,表明 LooperExecutor 成功地将任务调度回了其事件循环中。
  • 33856 in task 3 start,simple_task 中的第二个 co_await,即 simple_task3() 被调用。因为 simple_task3 使用 NewThreadExecutor,它在新的线程 (33856) 上开始执行。
  • 33856 task 3 returns after 2s,在新线程 (33856) 上,simple_task3 完成了执行,该线程由 NewThreadExecutor 创建,与 simple_task2 的执行线程 (39472) 不同。
  • 37432 returns from task3: 3,控制权再次回到 simple_task 中,处理从 simple_task3 返回的结果(返回值 3),依旧在 LooperExecutor 的线程 (37432) 上执行。
  • 37432 simple task end: 6,simple_task 计算最终结果并结束,所有操作均在 LooperExecutor 的线程 (37432) 上顺利完成。
  • 4420 simple task end from get: 6,最终结果 (6) 被从 simple_task.get_result() 方法中检索,在主线程上执行。
  • 37432 run_loop exit,LooperExecutor 的事件循环线程 (37432) 结束,可能是因为程序的整体结束或 LooperExecutor 的 shutdown() 方法被调用。

完整代码

存在两个等待体,都有调度器,DispatchAwaiter和TaskAwaiter。

协程的启动时(通过TaskPromise的 initial_suspend):DispatchAwaiter 确保协程的第一步执行可以被适当调度。这其中设置协程的执行环境,确保协程在特定的线程或线程池中启动。

当协程在 co_await 表达式处挂起时(通过await_transform):TaskAwaiter 会通过传递给它的调度器调度协程的恢复。这是通过调用调度器的 execute 方法来安排 std::coroutine_handle 的 resume 方法执行的。

#define __cpp_lib_coroutine
#include <coroutine>
#include <exception>
#include <iostream>
#include <thread>
#include <functional>
#include <mutex>
#include <list>
#include <optional>
#include <cassert>
#include <queue>
#include <future>
using namespace std;
void debug(const std::string& s) {
    printf("%d %s\n", std::this_thread::get_id(), s.c_str());
}

void debug(const std::string& s, int x) {
    printf("%d %s %d\n", std::this_thread::get_id(), s.c_str(), x);
}

// 调度器
class AbstractExecutor {
public:
    virtual void execute(std::function<void()>&& func) = 0;
};

class NoopExecutor : public AbstractExecutor {
public:
    void execute(std::function<void()>&& func) override {
        func();
    }
};

class NewThreadExecutor : public AbstractExecutor {
public:
    void execute(std::function<void()>&& func) override {
        std::thread(func).detach();
    }
};

class AsyncExecutor : public AbstractExecutor {
public:
    void execute(std::function<void()>&& func) override {
        auto future = std::async(func);
    }
};

class LooperExecutor : public AbstractExecutor {
private:
    std::condition_variable queue_condition;
    std::mutex queue_lock;
    std::queue<std::function<void()>> executable_queue;

    std::atomic<bool> is_active;
    std::thread work_thread;

    void run_loop() {
        while (is_active.load(std::memory_order_relaxed) || !executable_queue.empty()) {
            std::unique_lock lock(queue_lock);
            if (executable_queue.empty()) {
                queue_condition.wait(lock);
                if (executable_queue.empty()) {
                    continue;
                }
            }
            auto func = executable_queue.front();
            executable_queue.pop();
            lock.unlock();

            func();
        }
        debug("run_loop exit.");
    }

public:

    LooperExecutor() {
        is_active.store(true, std::memory_order_relaxed);
        work_thread = std::thread(&LooperExecutor::run_loop, this);
    }

    ~LooperExecutor() {
        shutdown(false);
        if (work_thread.joinable()) {
            work_thread.join();
        }
    }

    void execute(std::function<void()>&& func) override {
        std::unique_lock lock(queue_lock);
        if (is_active.load(std::memory_order_relaxed)) {
            executable_queue.push(func);
            lock.unlock();
            queue_condition.notify_one();
        }
    }

    void shutdown(bool wait_for_complete = true) {
        is_active.store(false, std::memory_order_relaxed);
        if (!wait_for_complete) {
            // clear queue.
            std::unique_lock lock(queue_lock);
            decltype(executable_queue) empty_queue;
            std::swap(executable_queue, empty_queue);
            lock.unlock();
        }

        queue_condition.notify_all();
    }
};

template<typename T>
struct Result
{
    explicit Result() = default;
    explicit Result(T&& value) : _value(value) {}
    explicit Result(std::exception_ptr&& exception_ptr) : _exception_ptr(exception_ptr) {}
    T get_or_throw() {
        if (_exception_ptr) {
            std::rethrow_exception(_exception_ptr);
        }
        return _value;
    }
private:
    T _value;
    std::exception_ptr _exception_ptr;
};
// 用于协程initial_suspend()时直接将运行逻辑切入调度器的等待体
struct DispatchAwaiter {

    explicit DispatchAwaiter(AbstractExecutor* executor) noexcept
        : _executor(executor) {}

    bool await_ready() const { return false; }

    void await_suspend(std::coroutine_handle<> handle) const {
        _executor->execute([handle]() {
            handle.resume();
            });
    }

    void await_resume() {}

private:
    AbstractExecutor* _executor;
};


// 前向声明
template<typename ResultType, typename Executor>
struct Task;

template<typename Result, typename Executor>
struct TaskAwaiter {
    explicit TaskAwaiter(AbstractExecutor* executor, Task<Result, Executor>&& task) noexcept
        : _executor(executor), task(std::move(task)) {}

    TaskAwaiter(TaskAwaiter&& completion) noexcept
        : _executor(completion._executor), task(std::exchange(completion.task, {})) {}

    TaskAwaiter(TaskAwaiter&) = delete;

    TaskAwaiter& operator=(TaskAwaiter&) = delete;

    constexpr bool await_ready() const noexcept {
        return false;
    }
    // 在这里增加了调度器的运行
    void await_suspend(std::coroutine_handle<> handle) noexcept {
        task.finally([handle, this]() {
            _executor->execute([handle]() {
                handle.resume();
                });
            });
    }

    Result await_resume() noexcept {
        return task.get_result();
    }

private:
    Task<Result, Executor> task;
    AbstractExecutor* _executor;
};

// 对应修改增加调度器的传入
template<typename ResultType,typename Executor>
struct TaskPromise {
    //此时调度器将开始调度,执行的逻辑
    DispatchAwaiter initial_suspend() { return DispatchAwaiter(&executor); }

    std::suspend_always final_suspend() noexcept { return {}; }

    Task<ResultType, Executor> get_return_object() {
        return Task{ std::coroutine_handle<TaskPromise>::from_promise(*this) };
    }
    //在这里返回等待器对象时需要将调度器的指针带上
    template<typename _ResultType, typename _Executor>
    TaskAwaiter<_ResultType, _Executor> await_transform(Task<_ResultType, _Executor>&& task) {
        return TaskAwaiter<_ResultType, _Executor>(&executor, std::move(task));
    }

    void unhandled_exception() {
        std::lock_guard lock(completion_lock);
        result = Result<ResultType>(std::current_exception());
        completion.notify_all();
        notify_callbacks();
    }

    void return_value(ResultType value) {
        std::lock_guard lock(completion_lock);
        result = Result<ResultType>(std::move(value));
        completion.notify_all();
        notify_callbacks();
    }

    ResultType get_result() {
        std::unique_lock lock(completion_lock);
        if (!result.has_value()) {
            completion.wait(lock);
        }
        return result->get_or_throw();
    }

    void on_completed(std::function<void(Result<ResultType>)>&& func) {
        std::unique_lock lock(completion_lock);
        if (result.has_value()) {
            auto value = result.value();
            lock.unlock();
            func(value);
        }
        else {
            completion_callbacks.push_back(func);
        }
    }

private:
    std::optional<Result<ResultType>> result;
    Executor executor;
    std::mutex completion_lock;
    std::condition_variable completion;

    std::list<std::function<void(Result<ResultType>)>> completion_callbacks;

    void notify_callbacks() {
        auto value = result.value();
        for (auto& callback : completion_callbacks) {
            callback(value);
        }
        completion_callbacks.clear();
    }

};

template<typename ResultType,typename Executor = NewThreadExecutor>
struct Task {

    using promise_type = TaskPromise<ResultType, Executor>;

    ResultType get_result() {
        return handle.promise().get_result();
    }

    Task& then(std::function<void(ResultType)>&& func) {
        handle.promise().on_completed([func](auto result) {
            try {
                func(result.get_or_throw());
            }
            catch (std::exception& e) {
                // ignore.
            }
            });
        return *this;
    }

    Task& catching(std::function<void(std::exception&)>&& func) {
        handle.promise().on_completed([func](auto result) {
            try {
                result.get_or_throw();
            }
            catch (std::exception& e) {
                func(e);
            }
            });
        return *this;
    }

    Task& finally(std::function<void()>&& func) {
        handle.promise().on_completed([func](auto result) { func(); });
        return *this;
    }

    explicit Task(std::coroutine_handle<promise_type> handle) noexcept : handle(handle) {}

    Task(Task&& task) noexcept : handle(std::exchange(task.handle, {})) {}

    Task(Task&) = delete;

    Task& operator=(Task&) = delete;

    ~Task() {
        if (handle) handle.destroy();
    }

private:
    std::coroutine_handle<promise_type> handle;
};
// 这意味着每个恢复的位置都会通过 std::async 上执行
Task<int,AsyncExecutor> simple_task2() {
    debug("task 2 start ...");
    using namespace std::chrono_literals;
    std::this_thread::sleep_for(1s);
    debug("task 2 returns after 1s.");
    co_return 2;
}
// 这意味着每个恢复的位置都会新建一个线程来执行
Task<int,NewThreadExecutor> simple_task3() {
    debug("in task 3 start ...");
    using namespace std::chrono_literals;
    std::this_thread::sleep_for(2s);
    debug("task 3 returns after 2s.");
    co_return 3;
}
//这意味着每个恢复的位置都会在同一个线程上执行
Task<int, LooperExecutor> simple_task() {
    debug("task start ...");
    auto result2 = co_await simple_task2();
    debug("returns from task2: ", result2);
    auto result3 = co_await simple_task3();
    debug("returns from task3: ", result3);
    co_return 1 + result2 + result3;
}

int main() {
    auto simpleTask = simple_task();
    simpleTask.then([](int i) {
        debug("simple task end: ", i);
        }).catching([](std::exception& e) {
            //debug("error occurred", e.what());
        });
    try {
        auto i = simpleTask.get_result();
        debug("simple task end from get: ", i);
    }
    catch (std::exception& e) {
        //debug("error: ", e.what());
    }
    return 0;
}
  • 23
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值