Taskflow 是一个基于 C++ 的并行任务编程框架,其核心调度器采用了高效的工作窃取算法来实现任务的动态负载均衡。 -- 【基于AI】
工作窃取算法基本原理
工作窃取算法是一种用于任务并行调度的策略,主要特点包括:
-
每个工作线程维护自己的任务队列
-
当线程自己的队列为空时,可以从其他线程的队列"窃取"任务
-
减少了线程间的竞争,提高了并行效率
Taskflow 的工作窃取实现
1. 任务队列结构
Taskflow 为每个工作线程维护一个双端队列(deque):
class Worker {
//...
size_t _id;
size_t _vtm;
Executor* _executor {nullptr};
DefaultNotifier::Waiter* _waiter;
std::thread _thread;
std::default_random_engine _rdgen;
//std::uniform_int_distribution<size_t> _udist;
BoundedTaskQueue<Node*> _wsq;
};
3. 工作线程主循环
// Procedure: _spawn
inline void Executor::_spawn(size_t N) {
for(size_t id=0; id<N; ++id) {
_workers[id]._id = id;
_workers[id]._vtm = id;
_workers[id]._executor = this;
_workers[id]._waiter = &_notifier._waiters[id];
_workers[id]._thread = std::thread([&, &w=_workers[id]] () {
pt::this_worker = &w;
// initialize the random engine and seed for work-stealing loop
w._rdgen.seed(static_cast<std::default_random_engine::result_type>(
std::hash<std::thread::id>()(std::this_thread::get_id()))
);
// before entering the work-stealing loop, call the scheduler prologue
if(_worker_interface) {
_worker_interface->scheduler_prologue(w);
}
Node* t = nullptr;
std::exception_ptr ptr = nullptr;
// must use 1 as condition instead of !done because
// the previous worker may stop while the following workers
// are still preparing for entering the scheduling loop
#ifndef TF_DISABLE_EXCEPTION_HANDLING
try {
#endif
// worker loop
while(1) {
// drain out the local queue
_exploit_task(w, t);
// steal and wait for tasks
if(_wait_for_task(w, t) == false) {
break;
}
}
#ifndef TF_DISABLE_EXCEPTION_HANDLING
}
catch(...) {
ptr = std::current_exception();
}
#endif
// call the user-specified epilogue function
if(_worker_interface) {
_worker_interface->scheduler_epilogue(w, ptr);
}
});
}
}
// Function: _wait_for_task
inline bool Executor::_wait_for_task(Worker& w, Node*& t) {
explore_task:
if(_explore_task(w, t) == false) {
return false;
}
// Go exploit the task if we successfully steal one.
if(t) {
return true;
}
//...
}
// Function: _explore_task
inline bool Executor::_explore_task(Worker& w, Node*& t) {
//assert(!t);
const size_t MAX_STEALS = ((num_queues() + 1) << 1);
std::uniform_int_distribution<size_t> udist(0, num_queues()-1);
size_t num_steals = 0;
size_t vtm = w._vtm;
// Make the worker steal immediately from the assigned victim.
while(true) {
// If the worker's victim thread is within the worker pool, steal from the worker's queue.
// Otherwise, steal from the buffer, adjusting the victim index based on the worker pool size.
t = (vtm < _workers.size())
? _workers[vtm]._wsq.steal()
: _buffers.steal(vtm - _workers.size());
if(t) {
w._vtm = vtm;
break;
}
// Increment the steal count, and if it exceeds MAX_STEALS, yield the thread.
// If the number of *consecutive* empty steals reaches MAX_STEALS, exit the loop.
if (++num_steals > MAX_STEALS) {
std::this_thread::yield();
if(num_steals > 100 + MAX_STEALS) {
break;
}
}
#if __cplusplus >= TF_CPP20
if(w._done.test(std::memory_order_relaxed)) {
#else
if(w._done.load(std::memory_order_relaxed)) {
#endif
return false;
}
// Randomely generate a next victim.
vtm = udist(w._rdgen); //w._rdvtm();
}
return true;
}
任务调度流程
// Procedure: _schedule
inline void Executor::_schedule(Worker& worker, Node* node) {
// caller is a worker of this executor - starting at v3.5 we do not use
// any complicated notification mechanism as the experimental result
// has shown no significant advantage.
if(worker._executor == this) {
worker._wsq.push(node, [&](){ _buffers.push(node); });
_notifier.notify_one();
return;
}
// caller is not a worker of this executor - go through the centralized queue
_buffers.push(node);
_notifier.notify_one();
}
// Procedure: _schedule
inline void Executor::_schedule(Node* node) {
_buffers.push(node);
_notifier.notify_one();
}
4. 工作窃取实现
BoundedTaskQueue一个有界任务队列的实现,专为高性能工作窃取(work-stealing)调度器设计。下面我将详细分析这个类的设计和实现特点。
核心设计特点
-
固定大小环形缓冲区:使用模板参数确定固定容量(2^LogSize)
-
无锁设计:通过原子操作实现线程安全
-
双端访问:支持前端(pop/push)和后端(steal)操作
-
缓存优化:严格的缓存行对齐减少伪共享
-
边界检查:明确的满/空状态处理
template <typename T, size_t LogSize = TF_DEFAULT_BOUNDED_TASK_QUEUE_LOG_SIZE>
class BoundedTaskQueue {
static_assert(std::is_pointer_v<T>, "T must be a pointer type");
constexpr static int64_t BufferSize = int64_t{1} << LogSize;
constexpr static int64_t BufferMask = (BufferSize - 1);
static_assert((BufferSize >= 2) && ((BufferSize & (BufferSize - 1)) == 0));
alignas(2*TF_CACHELINE_SIZE) std::atomic<int64_t> _top {0};
alignas(2*TF_CACHELINE_SIZE) std::atomic<int64_t> _bottom {0};
alignas(2*TF_CACHELINE_SIZE) std::atomic<T> _buffer[BufferSize];
public:
/**
@brief constructs the queue with a given capacity
*/
BoundedTaskQueue() = default;
/**
@brief destructs the queue
*/
~BoundedTaskQueue() = default;
/**
@brief queries if the queue is empty at the time of this call
*/
bool empty() const noexcept;
/**
@brief queries the number of items at the time of this call
*/
size_t size() const noexcept;
/**
@brief queries the capacity of the queue
*/
constexpr size_t capacity() const;
/**
@brief tries to insert an item to the queue
@tparam O data type
@param item the item to perfect-forward to the queue
@return `true` if the insertion succeed or `false` (queue is full)
Only the owner thread can insert an item to the queue.
*/
template <typename O>
bool try_push(O&& item);
/**
@brief tries to insert an item to the queue or invoke the callable if fails
@tparam O data type
@tparam C callable type
@param item the item to perfect-forward to the queue
@param on_full callable to invoke when the queue is full (insertion fails)
Only the owner thread can insert an item to the queue.
*/
template <typename O, typename C>
void push(O&& item, C&& on_full);
/**
@brief pops out an item from the queue
Only the owner thread can pop out an item from the queue.
The return can be a `nullptr` if this operation failed (empty queue).
*/
T pop();
/**
@brief steals an item from the queue
Any threads can try to steal an item from the queue.
The return can be a `nullptr` if this operation failed (not necessary empty).
*/
T steal();
/**
@brief attempts to steal a task with a hint mechanism
@param num_empty_steals a reference to a counter tracking consecutive empty steal attempts
This function tries to steal a task from the queue. If the steal attempt
is successful, the stolen task is returned.
Additionally, if the queue is empty, the provided counter `num_empty_steals` is incremented;
otherwise, `num_empty_steals` is reset to zero.
*/
T steal_with_hint(size_t& num_empty_steals);
};
主要组件分析
1. 模板参数与常量
template <typename T, size_t LogSize = TF_DEFAULT_BOUNDED_TASK_QUEUE_LOG_SIZE>
class BoundedTaskQueue {
constexpr static int64_t BufferSize = int64_t{1} << LogSize;
constexpr static int64_t BufferMask = (BufferSize - 1);
static_assert((BufferSize >= 2) && ((BufferSize & (BufferSize - 1)) == 0));
};
-
容量确定:在编译时通过
LogSize
确定队列大小(必须是2的幂) -
掩码优化:使用
BufferMask
替代取模运算,提高效率
2. 核心成员变量
alignas(2*TF_CACHELINE_SIZE) std::atomic<int64_t> _top {0}; // 队首(窃取端)
alignas(2*TF_CACHELINE_SIZE) std::atomic<int64_t> _bottom {0}; // 队尾(所有者端)
alignas(2*TF_CACHELINE_SIZE) std::atomic<T> _buffer[BufferSize]; // 存储数组
-
缓存对齐:关键变量分开对齐(2倍缓存行),避免伪共享
-
原子操作:所有变量使用原子类型保证线程安全
关键操作分析
1. try_push 操作 (所有者线程)
template <typename O>
bool try_push(O&& o) {
int64_t b = _bottom.load(std::memory_order_relaxed);
int64_t t = _top.load(std::memory_order_acquire);
// 检查队列是否满
if TF_UNLIKELY((b - t) > BufferSize - 1) {
return false;
}
_buffer[b & BufferMask].store(std::forward<O>(o), std::memory_order_relaxed);
std::atomic_thread_fence(std::memory_order_release);
_bottom.store(b + 1, std::memory_order_release);
return true;
}
-
边界检查:通过
b - t > BufferSize - 1
判断队列是否满 -
内存序:
-
获取
_top
使用memory_order_acquire
-
存储
_bottom
前使用memory_order_release
栅栏 -
最终使用
memory_order_release
存储_bottom
-
2. push 操作 (带回调版本)
template <typename O, typename C>
void push(O&& o, C&& on_full) {
// 类似try_push但失败时调用回调
if TF_UNLIKELY((b - t) > BufferSize - 1) {
on_full();
return;
}
// ...其余与try_push相同
}
-
策略模式:队列满时调用用户提供的回调函数
-
灵活性:允许用户自定义队列满时的处理逻辑
3. pop 操作 (所有者线程)
T pop() {
int64_t b = _bottom.load(std::memory_order_relaxed) - 1;
_bottom.store(b, std::memory_order_relaxed);
std::atomic_thread_fence(std::memory_order_seq_cst);
int64_t t = _top.load(std::memory_order_relaxed);
T item {nullptr};
if(t <= b) {
item = _buffer[b & BufferMask].load(std::memory_order_relaxed);
if(t == b) {
// 处理与窃取线程的竞争
if(!_top.compare_exchange_strong(t, t+1,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
item = nullptr;
}
_bottom.store(b + 1, std::memory_order_relaxed);
}
}
else {
_bottom.store(b + 1, std::memory_order_relaxed);
}
return item;
}
-
竞争处理:使用CAS操作处理与窃取线程的竞争
-
内存序:关键操作使用
memory_order_seq_cst
保证顺序一致性
4. steal 操作 (工作窃取线程)
T steal() {
int64_t t = _top.load(std::memory_order_acquire);
std::atomic_thread_fence(std::memory_order_seq_cst);
int64_t b = _bottom.load(std::memory_order_acquire);
T item{nullptr};
if(t < b) {
item = _buffer[t & BufferMask].load(std::memory_order_relaxed);
if(!_top.compare_exchange_strong(t, t+1,
std::memory_order_seq_cst,
std::memory_order_relaxed)) {
return nullptr;
}
}
return item;
}
-
原子操作:使用CAS确保只有一个窃取线程能获取任务
-
内存序:使用
memory_order_acquire
和memory_order_seq_cst
5. steal_with_hint 操作
T steal_with_hint(size_t& num_empty_steals) {
// 类似steal但维护空窃取计数器
if(t < b) {
num_empty_steals = 0;
// ...窃取逻辑
}
else {
++num_empty_steals;
}
return item;
}
-
启发式策略:跟踪连续空窃取次数,可用于优化调度策略
-
性能优化:帮助调度器识别空闲队列
Taskflow 工作窃取的特点
-
分布式任务队列:
-
每个工作线程有自己的任务队列
-
减少了共享数据结构的竞争
-
-
窃取策略优化:
-
随机选择受害者线程,避免总是窃取同一线程
-
采用环形遍历,确保公平性
-
-
无锁与有锁结合:
-
自己的队列操作无锁(单线程访问)
-
窃取操作使用细粒度锁或原子操作
-
-
负载均衡:
-
空闲线程主动寻找工作
-
动态平衡各线程负载
-
-
任务优先级:
-
自己的任务优先执行(LIFO)
-
窃取的任务是较老的任务(FIFO)
-
与其他框架对比
特性 | Taskflow | TBB | OpenMP |
---|---|---|---|
队列结构 | 分布式双端队列 | 分布式队列 | 集中式队列 |
窃取策略 | 随机+环形 | 随机 | 通常无窃取 |
任务粒度 | 任意 | 中等 | 较大 |
调度开销 | 低 | 中 | 高 |
适用场景 | 细粒度任务 | 通用 | 数据并行 |
Taskflow 的工作窃取算法特别适合处理复杂任务依赖图的情况,能够高效地实现:
-
动态负载均衡
-
低调度开销
-
良好的可扩展性
这种设计使得 Taskflow 在复杂任务调度场景下表现出色,特别是对于不规则计算图和有动态任务生成的情况。