说明
- CountDownLatch 是一次性的屏障:必须累计调用 count 次 countDown()(这些调用可以来自不同线程,也可以来自同一线程),才会放行所有在 await() 上等待的线程;计数归零后无法重新使用;
- CyclicBarrier 则是可循环利用的屏障:每当凑齐 parties 个线程在 await() 上等待时,才会将这 parties 个线程一起放行,随后屏障自动进入下一轮;(例外:屏障被 reset() 主动打破时,正在等待的线程会立即返回 -1)
CountDownLatch实现源码
#pragma once
#include<assert.h>
#include<mutex>
#include<atomic>
#include<condition_variable>
/// One-shot synchronization latch.
///
/// The latch starts at a non-negative count. Each countDown() call lowers it
/// by one (never below zero). When it reaches zero, every thread blocked in
/// await() is released, and later await() calls return immediately.
/// The count cannot be raised again.
class CountDownLatch {
public:
    /// @param count number of countDown() calls required before await()
    ///              returns; 0 yields a latch that is already open.
    CountDownLatch(int count):count(count){
        assert(count >= 0);
    }

    /// Blocks the calling thread until the count reaches zero.
    void await() {
        // Fast path: the latch is already open, no need to touch the mutex.
        if (count.load() == 0) return;
        std::unique_lock<std::mutex> lock(mx);
        // The predicate makes the wait immune to spurious wake-ups.
        cond.wait(lock, [this]() { return count.load() == 0; });
    }

    /// Decrements the count by one. The call that moves it from 1 to 0 wakes
    /// all waiters; calls made once the count is already zero have no effect.
    void countDown() {
        int old_c = count.load();
        while (old_c > 0) {
            // On failure compare_exchange_strong stores the current value
            // into old_c, so the loop simply retries with the fresh value.
            if (count.compare_exchange_strong(old_c, old_c - 1)) {
                if (old_c == 1) {
                    // Take the mutex before notifying. A waiter that has seen
                    // a non-zero count still holds the mutex until it blocks,
                    // so the notification cannot slip in between and be lost.
                    std::lock_guard<std::mutex> lock(mx);
                    cond.notify_all();
                }
                return;
            }
        }
    }

    /// @return a snapshot of the current count (may change concurrently).
    int getCount() const {
        return count.load();
    }
private:
    std::atomic<int> count;
    std::mutex mx;
    std::condition_variable cond;
};
CyclicBarrier 实现源码
#pragma once
#include<assert.h>
#include<condition_variable>
#include<memory>
#include<mutex>
#include<utility>
/// Reusable barrier for a fixed number of parties.
///
/// Each call to await() blocks until `parties` threads have called it. The
/// last arriving thread runs the completion function (when it is set), then
/// releases all waiters and starts a new round ("generation").
///
/// reset() breaks the current round and starts a fresh one. Threads waiting
/// in the broken round return -1.
///
/// If the completion function throws, the exception propagates to the last
/// thread and the round is marked broken. The other waiters, and any later
/// await() call, then return -1 until reset() is called.
///
/// @tparam CompleteFunc callable invoked as `completeFunc()` and contextually
///         convertible to bool (e.g. std::function<void()> or a function
///         pointer); it is skipped when it converts to false.
template<typename CompleteFunc>
class CyclicBarrier {
private:
    /// One round of the barrier. Waiters keep a shared_ptr to the round they
    /// joined. That lets them tell "my round tripped" (a newer round is
    /// current) apart from "my round was broken".
    struct Generation {
        unsigned int id;
        bool isBroken = false;
        explicit Generation(unsigned int id = 0) : id(id) {}
        std::shared_ptr<Generation> next() const {
            return std::make_shared<Generation>(id + 1);
        }
    };
public:
    /// @param parties      number of threads that must call await() to trip
    ///                     the barrier; must be positive.
    /// @param completeFunc action run by the last arriving thread, while the
    ///                     barrier's lock is held, before the others resume.
    CyclicBarrier(int parties, CompleteFunc completeFunc)
        : parties(parties), count(parties),
          // forward (rather than move) so that a reference CompleteFunc
          // still binds as an lvalue.
          completeFunc(std::forward<CompleteFunc>(completeFunc)),
          generation(std::make_shared<Generation>()) {
        assert(parties > 0);
    }

    /// Waits until all parties of the current round have called await().
    /// @return the caller's arrival index: parties-1 for the first arrival,
    ///         0 for the last; -1 if the round was broken.
    /// @throws whatever completeFunc throws; the round is then broken.
    int await() {
        std::unique_lock<std::mutex> lock(mx);
        const std::shared_ptr<Generation> g = generation;
        if (g->isBroken) return -1;
        const int index = --count;
        if (index == 0) {
            // Last arrival: run the completion action, then trip the barrier.
            if (completeFunc) {
                try {
                    completeFunc();
                } catch (...) {
                    // Otherwise count would stay at 0 with no new round, and
                    // every other waiter would block forever.
                    breakBarrier();
                    throw;
                }
            }
            nextGeneration();
            return index;
        }
        // The predicate guards against spurious wake-ups. Only return once
        // this round has tripped (a newer generation exists) or was broken.
        cond.wait(lock, [&]() { return g->isBroken || g != generation; });
        if (g->isBroken)
            return -1;
        assert(g->id != generation->id);
        return index;
    }

    /// Breaks the current round (its waiters return -1) and starts a new one.
    void reset() {
        std::lock_guard<std::mutex> lock(mx);
        breakBarrier();
        nextGeneration();
    }

    int getParties() const {
        return parties;
    }

    /// @return how many threads are currently blocked in await() this round.
    int getNumberWaiting() const {
        // count is written under mx by await(); reading it unlocked is a race.
        std::lock_guard<std::mutex> lock(mx);
        return parties - count;
    }
private:
    // Both helpers must be called with mx held.

    /// Starts a new round and wakes everyone waiting in the old one.
    void nextGeneration() {
        generation = generation->next();
        count = parties;
        cond.notify_all();
    }

    /// Marks the current round broken and wakes its waiters.
    void breakBarrier() {
        generation->isBroken = true;
        count = parties;
        cond.notify_all();
    }

    const int parties;
    int count;  // arrivals still missing in the current round; guarded by mx
    CompleteFunc completeFunc;
    std::shared_ptr<Generation> generation;  // current round; guarded by mx
    mutable std::mutex mx;
    std::condition_variable cond;
};