在学习线程池之前,先看下线程池的两个基本组件:任务队列与线程组。
任务队列TaskQueue
提供了添加任务、获取任务、清空任务队列接口,通过semaphore控制添加与获取之间的同步。
semaphore
封装条件变量condition_variable_any的使用,
class semaphore {
private:
size_t _count;
std::recursive_mutex _mutex;
std::condition_variable_any _condition;
}
提供了post和wait接口
void post(size_t n = 1) {
std::unique_lock<std::recursive_mutex> lock(_mutex);
_count += n;
if (n == 1) {
_condition.notify_one();
} else {
_condition.notify_all();
}
}
void wait() {
std::unique_lock<std::recursive_mutex> lock(_mutex);
while (_count == 0) { //之所以while是避免虚假唤醒
_condition.wait(lock);
}
--_count;
}
_count成员记录希望wait接口返回的次数。默认情况下,每调用一次post,_count加1,通过notify_one,所有等待在_condition.wait上的线程其中一个被唤醒。如果n大于1,将通过notify_all唤醒所有等待在_condition.wait上的线程。
对于生产者消费者模型,count记录产品数量,post每次生产一个或者多个产品,wait每次消费一个产品,wait被唤醒后,会一直执行,直到产品都被消耗完。
TaskQueue
思考:如果我们要写一个任务队列,用来存储任务,它需要些什么东西。
- 首先,它应该有一个容器,用来存放任务,这里我们用list来存储
- 然后,因为我们不知道任务是什么类型的,所以我们用到了,模板
template<typename T>
class TaskQueue {
private:
std::list<T> _queue;
};
然后,它应该提供如下几个接口:
- push_task:将任务放入任务队列,push(需要放入的任务)
- 参数也应该是模板,因为不知道具体类型,即 push(T)
- 为保证原模原样,xxxx应该是一个右值,即push(T &&)
- get_task:如果任务队列不为空,那么就取出一个任务
- 它可以返回两个东西,任务&&是否成功获取到了任务,所以这个函数可以长这样: bool get_task(T &)
- size:当前任务队列有多长,所以为 int size()
- clear:清空所有任务
为什么不直接使用void push_task(T &&task_func)而要用模板呢?
template<typename C> void push_task(C &&task_func);
void push_task(T &&task_func);
void push_task(T &task_func);
void push_task(const T &task_func);
push_task(T &&task_func)只能接受右值引用,push_task(T &task_func)不能接受右值引用,push_task(const T &task_func)接受右值时会产生临时变量。模板因为引用折叠,既能接受左值,也能接受右值。
template<typename T>
class TaskQueue {
public:
template<typename C>
void push_task(C &&task_func) {
_queue.emplace_back(std::forward<C>(task_func));
}
template<typename C>
void push_task_first(C &&task_func) {
_queue.emplace_front(std::forward<C>(task_func));
}
//从列队获取一个任务,由执行线程执行
bool get_task(T &tsk){
if (_queue.empty()) {
return false;
}
tsk = std::move(_queue.front());
_queue.pop_front();
return true;
}
size_t size() const {
return _queue.size();
}
void clear(){
_queue.clear();
}
private:
std::list<T> _queue;
};
上面是线程不安全的,而任务队列通常需要在多个线程中用到,所以需要一个std::mutex成员。又在size_t size() const
也需要用到锁,所以这个成员应该是mutable
size_t size() const {
_mutex.lock();
int n = _queue.size();
_mutex.unlock();
return n;
}
private:
mutable std::mutex _mutex;
当然,有更时尚的写法:
size_t size() const {
std::lock_guard<decltype(_mutex)> guard(_mutex);
return _queue.size();
}
private:
mutable std::mutex _mutex;
std::list<T> _queue;
完整的写法是这样的:
template<typename T>
class TaskQueue {
public:
template<typename C>
void push_task(C &&task_func) {
std::lock_guard<decltype(_mutex)> guard(_mutex);
_queue.emplace_back(std::forward<C>(task_func));
}
template<typename C>
void push_task_first(C &&task_func) {
std::lock_guard<decltype(_mutex)> guard(_mutex);
_queue.emplace_front(std::forward<C>(task_func));
}
//从列队获取一个任务,由执行线程执行
bool get_task(T &tsk){
std::lock_guard<decltype(_mutex)> guard(_mutex);
if (_queue.empty()) {
return false;
}
tsk = std::move(_queue.front());
_queue.pop_front();
return true;
}
size_t size() const {
std::lock_guard<decltype(_mutex)> guard(_mutex);
return _queue.size();
}
void clear(){
std::lock_guard<decltype(_mutex)> guard(_mutex);
_queue.clear();
}
private:
mutable std::mutex _mutex;
std::list<T> _queue;
};
等一等,如果我们在获取任务是要求没有任务时就等待,知道获取到任务,而不是直接返回false。应该怎么做呢?用信号量(条件变量)来做同步
template<typename C>
void push_task(C &&task_func) {
{
std::lock_guard<decltype(_mutex)> guard(_mutex);
_queue.emplace_back(std::forward<C>(task_func));
}
_sem.post();
}
template<typename C>
void push_task_first(C &&task_func) {
{
std::lock_guard<decltype(_mutex)> guard(_mutex);
_queue.emplace_front(std::forward<C>(task_func));
}
_sem.post();
}
//从列队获取一个任务,由执行线程执行
bool get_task(T &tsk){
_sem.wait(); //等待任务
std::lock_guard<decltype(_mutex)> guard(_mutex);
if (_queue.empty()) {
return false;
}
tsk = std::move(_queue.front());
_queue.pop_front();
return true;
}
private:
mutable std::mutex _mutex;
std::list<T> _queue;
semaphore _sem;
而clear()时需要改一下,应为如果不改的话,可能造成在clear()之前就get_task()的某些调用地方死等。
void clear(size_t n){
_sem.post(n);
}
这么改之后get_task接口返回false。
注意,在代码中clear–》push_exit
void push_exit(size_t n) {
_sem.post(n);
}
线程池中利用get_task返回false来达到退出线程的目的。如启动了4个工作线程来获取任务,每个线程都等待在get_task上,主线程中调用push_exit(4),get_task就会有四次返回false,以此让工作线程退出。
线程组thread_group
管理一组线程。首先它ban掉了拷贝构造和拷贝复制
class thread_group {
private:
thread_group(thread_group const &);
thread_group &operator=(thread_group const &);
然后它提供了如下功能:
//判断当前线程是否在线程组中
bool is_this_thread_in()
//判断指定线程是否在线程组中
bool is_thread_in(std::thread *thrd)
// 添加一个线程到线程组中
template<typename F>
std::thread *create_thread(F &&threadfunc)
// 移除指定线程
void remove_thread(std::thread *thrd)
// 阻塞等待线程组中所有线程退出
void join_all()
// 当前现场组管理着的线程数量
size_t size()
它使用了一个map来管理线程组中的所有线程,key为线程ID,value就是相应线程:
private:
std::unordered_map<std::thread::id, std::shared_ptr<std::thread>> _threads;