手写线程池
需求:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <thread>
#include <functional>
#include <vector>
#include <unordered_map>
#include <queue>
#include <mutex>
#include <condition_variable>
using namespace std;
bool is_prime(int x) {
for (int i = 2; i * i <= x; i++) {
if (x % i == 0) return false;
}
return true;
}
int prime_count_test(int l, int r) {
int ans = 0;
for (int i = l; i <= r; i++) {
ans += is_prime(i);
}
return ans;
}
void worker(int l, int r, int &ret) {
ret = prime_count_test(l, r);
return ;
}
int main() {
#define batch 5000000
// 创建一个线程池
ThreadPool tp(5);
int ret[10];
// 把任务丢进任务队列中,等待执行
for (int i = 0, j = 1; i < 10; i++, j += batch) {
tp.add_task(worker, j, j + batch - 1, ref(ret[i]));
}
// 等待线程池结束,并销毁线程池
tp.stop();
int ans = 0;
for (auto x : ret) ans += x;
cout << ans << endl;
return 0;
}
使这段代码可以正常运行,这段代码是暴力求素数个数的一个测试用例
实现:
我们先来分析一下需求(我们想要创建哪些类,可以满足我们使用线程池的需求):
- 首先肯定得有一个线程池的类
- 其次线程还得从任务队列中取任务,所有我们还有一个任务队列的类,但是呢?由于
STL
中有已经封装好的,所有可以直接使用。 - 允许创建具有任意函数和相应参数的任务的一个类,我们这里叫做
Task
类。
我们来依次实现一下,先实现一个简单的类(Task
):
class Task {
public:
template<typename FUN_C, typename ...ARGS>
Task(FUN_C fun, ARGS ...args) {
func = bind(fun, forward<ARGS>(args)...);
}
void run() {
func();
}
private:
function<void()> func;
};
- 由于我们得接收任意函数的一个类,所有在这里我们得使用变参模板来实现这个需求,
- 我们还得把任务给打包好,所有我们这里使用的是
bind
函数来实现这个需求; - 我们得用
function
函数来接收和调用这个任务。
由于任务队列的类,STL
中有封装好的,所有我们直接来写我们主要的线程池类:
先来实现一个(ThreadPool tp(5);
)
class ThreadPool {
public:
ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) {
this->start();
return ;
}
void worker() {
while (1) {
// 取任务
// 执行任务
}
return ;
}
void start() {
if (starting == true) return ;
for (int i = 0; i < thread_size; i++) {
threads[i] = new thread(&ThreadPool::worker, this);
}
starting = true;
return ;
}
private:
bool starting;
queue<Task *> task_que;
vector<thread *> threads;
int thread_size;
};
-
构造函数:
ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) { this->start(); return ; }
这是线程池的构造函数,接受一个整数参数n,默认值为1。它初始化了线程池的一些成员变量,然后调用了
start
函数启动线程池。 -
start函数:
void start() { if (starting == true) return ; for (int i = 0; i < thread_size; i++) { threads[i] = new thread(&ThreadPool::worker, this); } starting = true; return ; }
start
函数用于启动线程池。它首先检查starting
变量,如果线程池已经启动,就直接返回。然后,它通过循环创建了指定数量的线程,每个线程都执行worker
函数。 -
worker函数:
void worker() { while (1) { // 取任务 // 执行任务 } return ; }
worker
函数是线程的主体函数,实现了线程的具体工作。它使用一个无限循环来不断地从任务队列中取出任务并执行。这里的具体任务处理逻辑需要根据实际需求来添加。 -
成员变量:
bool starting; queue<Task *> task_que; vector<thread *> threads; int thread_size;
starting
标志表示线程池是否已经启动。task_que
是一个任务队列,用于存储需要线程池执行的任务。threads
是一个存储线程指针的向量。thread_size
表示线程池中线程的数量。
再来实现一个tp.add_task(worker, j, j + batch - 1, ref(ret[i]));
class ThreadPool {
public:
ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) {
this->start();
return ;
}
void worker() {
while (1) {
Task *t = get_task();
t->run();
delete t;
}
return ;
}
void start() {
if (starting == true) return ;
for (int i = 0; i < thread_size; i++) {
threads[i] = new thread(&ThreadPool::worker, this);
}
starting = true;
return ;
}
template<typename FUN_C, typename ...ARGS>
void add_task(FUN_C fun, ARGS ...args) {
unique_lock<mutex> lock(m_mutex);
task_que.push(new Task(fun, forward<ARGS>(args)...));
m_cond.notify_one();
return ;
}
private:
Task *get_task() {
unique_lock<mutex> lock(m_mutex);
while (task_que.empty()) m_cond.wait(lock);
Task *t = task_que.front();
task_que.pop();
return t;
}
std::mutex m_mutex;
std::condition_variable m_cond;
bool starting;
queue<Task *> task_que;
vector<thread *> threads;
int thread_size;
};
-
worker函数:
void worker() { while (1) { Task *t = get_task(); t->run(); delete t; } return ; }
worker
函数是线程的主体函数,实现了线程的具体工作。它通过get_task
函数获取任务,执行任务的run
函数,然后释放任务的内存。 -
add_task函数:
template<typename FUN_C, typename ...ARGS> void add_task(FUN_C fun, ARGS ...args) { unique_lock<mutex> lock(m_mutex); task_que.push(new Task(fun, forward<ARGS>(args)...)); m_cond.notify_one(); return ; }
add_task
函数用于向任务队列中添加任务。它使用了可变参数模板,允许传递任意类型和数量的参数。添加任务后通过条件变量通知等待中的线程。 -
get_task函数:
Task *get_task() { unique_lock<mutex> lock(m_mutex); while (task_que.empty()) m_cond.wait(lock); Task *t = task_que.front(); task_que.pop(); return t; }
get_task
函数从任务队列中获取任务,如果队列为空,则线程等待条件变量的通知。获取到任务后,从队列中移除,并返回任务指针。 -
成员变量:
std::mutex m_mutex; std::condition_variable m_cond; bool starting; queue<Task *> task_que; vector<thread *> threads; int thread_size; unordered_map<decltype(this_thread::get_id()), bool> running;
m_mutex
是互斥锁,用于保护对任务队列的访问。m_cond
是条件变量,用于线程之间的同步。
实现tp.stop();
class ThreadPool {
public:
ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) {
this->start();
return ;
}
void worker() {
auto id = this_thread::get_id();
running[id] = true;
while (running[id]) {
Task *t = get_task();
t->run();
delete t;
}
return ;
}
void start() {
if (starting == true) return ;
for (int i = 0; i < thread_size; i++) {
threads[i] = new thread(&ThreadPool::worker, this);
}
starting = true;
return ;
}
template<typename FUN_C, typename ...ARGS>
void add_task(FUN_C fun, ARGS ...args) {
unique_lock<mutex> lock(m_mutex);
task_que.push(new Task(fun, forward<ARGS>(args)...));
m_cond.notify_one();
return ;
}
void stop() {
if (starting == false) return ;
for (int i = 0; i < thread_size; i++) {
task_que.push(new Task(&ThreadPool::stop_thread, this));
}
for (int i = 0; i < thread_size; i++) {
threads[i]->join();
}
for (int i = 0; i < thread_size; i++) {
delete threads[i];
threads[i] = nullptr;
}
starting = false;
return ;
}
~ThreadPool() {
this->stop();
while (!task_que.empty()) {
delete task_que.front();
task_que.pop();
}
return ;
}
private:
void stop_thread() {
auto id = this_thread::get_id();
running[id] = false;
return ;
}
Task *get_task() {
unique_lock<mutex> lock(m_mutex);
while (task_que.empty()) m_cond.wait(lock);
Task *t = task_que.front();
task_que.pop();
return t;
}
std::mutex m_mutex;
std::condition_variable m_cond;
bool starting;
queue<Task *> task_que;
vector<thread *> threads;
int thread_size;
unordered_map<decltype(this_thread::get_id()), bool> running;
};
-
worker函数:
void worker() { auto id = this_thread::get_id(); running[id] = true; while (running[id]) { Task *t = get_task(); t->run(); delete t; } return ; }
worker
函数是线程的主体函数,实现了线程的具体工作。在循环中,它通过get_task
函数获取任务,执行任务的run
函数,然后释放任务的内存。此外,每个线程都有一个运行状态,保存在running
映射中。 -
add_task函数:
template<typename FUN_C, typename ...ARGS> void add_task(FUN_C fun, ARGS ...args) { unique_lock<mutex> lock(m_mutex); task_que.push(new Task(fun, forward<ARGS>(args)...)); m_cond.notify_one(); return ; }
add_task
函数用于向任务队列中添加任务。它使用了可变参数模板,允许传递任意类型和数量的参数。添加任务后通过条件变量通知等待中的线程。 -
stop函数:
void stop() { if (starting == false) return ; for (int i = 0; i < thread_size; i++) { task_que.push(new Task(&ThreadPool::stop_thread, this)); } for (int i = 0; i < thread_size; i++) { threads[i]->join(); } for (int i = 0; i < thread_size; i++) { delete threads[i]; threads[i] = nullptr; } starting = false; return ; }
stop
函数用于停止线程池。它向任务队列中添加了一个特殊的任务,该任务调用stop_thread
函数,使所有线程退出循环。接着,等待每个线程完成,然后释放线程资源。 -
stop_thread函数:
void stop_thread() { auto id = this_thread::get_id(); running[id] = false; return ; }
stop_thread
函数用于在每个线程中调用,设置相应线程的运行状态为false,从而退出循环。 -
析构函数:
~ThreadPool() { this->stop(); while (!task_que.empty()) { delete task_que.front(); task_que.pop(); } return ; }
析构函数在对象销毁时会调用,这里调用了
stop
函数以确保线程池停止,并释放任务队列中的任务。running
是一个unordered_map,用于跟踪线程的运行状态。
完整代码:
/*************************************************************************
> File Name: thread_pool.cpp
> Author:Xiao Yuheng
> Mail:3312638794@qq.com
> Created Time: Sat Nov 18 18:18:18 2023
************************************************************************/
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <thread>
#include <functional>
#include <vector>
#include <unordered_map>
#include <queue>
#include <mutex>
#include <condition_variable>
using namespace std;
class Task {
public:
template<typename FUN_C, typename ...ARGS>
Task(FUN_C fun, ARGS ...args) {
func = bind(fun, forward<ARGS>(args)...);
}
void run() {
func();
}
private:
function<void()> func;
};
class ThreadPool {
public:
ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) {
this->start();
return ;
}
void worker() {
auto id = this_thread::get_id();
running[id] = true;
while (running[id]) {
Task *t = get_task();
t->run();
delete t;
}
return ;
}
void start() {
if (starting == true) return ;
for (int i = 0; i < thread_size; i++) {
threads[i] = new thread(&ThreadPool::worker, this);
}
starting = true;
return ;
}
template<typename FUN_C, typename ...ARGS>
void add_task(FUN_C fun, ARGS ...args) {
unique_lock<mutex> lock(m_mutex);
task_que.push(new Task(fun, forward<ARGS>(args)...));
m_cond.notify_one();
return ;
}
void stop() {
if (starting == false) return ;
for (int i = 0; i < thread_size; i++) {
task_que.push(new Task(&ThreadPool::stop_thread, this));
}
for (int i = 0; i < thread_size; i++) {
threads[i]->join();
}
for (int i = 0; i < thread_size; i++) {
delete threads[i];
threads[i] = nullptr;
}
starting = false;
return ;
}
~ThreadPool() {
this->stop();
while (!task_que.empty()) {
delete task_que.front();
task_que.pop();
}
return ;
}
private:
void stop_thread() {
auto id = this_thread::get_id();
running[id] = false;
return ;
}
Task *get_task() {
unique_lock<mutex> lock(m_mutex);
while (task_que.empty()) m_cond.wait(lock);
Task *t = task_que.front();
task_que.pop();
return t;
}
std::mutex m_mutex;
std::condition_variable m_cond;
bool starting;
queue<Task *> task_que;
vector<thread *> threads;
int thread_size;
unordered_map<decltype(this_thread::get_id()), bool> running;
};
bool is_prime(int x) {
for (int i = 2; i * i <= x; i++) {
if (x % i == 0) return false;
}
return true;
}
int prime_count_test(int l, int r) {
int ans = 0;
for (int i = l; i <= r; i++) {
ans += is_prime(i);
}
return ans;
}
void worker(int l, int r, int &ret) {
ret = prime_count_test(l, r);
return ;
}
int main() {
#define batch 5000000
// 创建一个线程池
ThreadPool tp(5);
int ret[10];
// 把任务丢进任务队列中,等待执行
for (int i = 0, j = 1; i < 10; i++, j += batch) {
tp.add_task(worker, j, j + batch - 1, ref(ret[i]));
}
// 等待线程池结束,并销毁线程池
tp.stop();
int ans = 0;
for (auto x : ret) ans += x;
cout << ans << endl;
return 0;
}