19.手写线程池

手写线程池

需求:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <thread>
#include <functional>
#include <vector>
#include <unordered_map>
#include <queue>
#include <mutex>
#include <condition_variable>

using namespace std;

bool is_prime(int x) {
    for (int i = 2; i * i <= x; i++) {
        if (x % i == 0) return false;
    }
    return true;
}

int prime_count_test(int l, int r) {
    int ans = 0;
    for (int i = l; i <= r; i++) {
        ans += is_prime(i);
    }
    return ans;
}

void worker(int l, int r, int &ret) {
    ret = prime_count_test(l, r);
    return ;
}

int main() {
    #define batch 5000000
    // 创建一个线程池
    ThreadPool tp(5);
    int ret[10];
    // 把任务丢进任务队列中,等待执行
    for (int i = 0, j = 1; i < 10; i++, j += batch) {
        tp.add_task(worker, j, j + batch - 1, ref(ret[i]));
    }
    // 等待线程池结束,并销毁线程池
    tp.stop();
    int ans = 0;
    for (auto x : ret) ans += x;
    cout << ans << endl;
    return 0;
}

使这段代码可以正常运行,这段代码是暴力求素数个数的一个测试用例

实现:

​ 我们先来分析一下需求(我们想要创建哪些类,可以满足我们使用线程池的需求):

  • 首先肯定得有一个线程池的类
  • 其次线程还得从任务队列中取任务,所有我们还有一个任务队列的类,但是呢?由于STL中有已经封装好的,所有可以直接使用。
  • 允许创建具有任意函数和相应参数的任务的一个类,我们这里叫做Task类。

我们来依次实现一下,先实现一个简单的类(Task):

class Task {
public:
    template<typename FUN_C, typename ...ARGS>
    Task(FUN_C fun, ARGS ...args) {
        func = bind(fun, forward<ARGS>(args)...);
    }
    void run() {
        func();
    }

private:
    function<void()> func;
};
  • 由于我们得接收任意函数的一个类,所有在这里我们得使用变参模板来实现这个需求,
  • 我们还得把任务给打包好,所有我们这里使用的是bind函数来实现这个需求;
  • 我们得用function函数来接收和调用这个任务。

由于任务队列的类,STL中有封装好的,所有我们直接来写我们主要的线程池类:

先来实现一个(ThreadPool tp(5);
class ThreadPool {
public:
    ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) {
        this->start();
        return ;
    }
    void worker() {
        while (1) {
			// 取任务
            // 执行任务
        }
        return ;
    }
    void start() {
        if (starting == true) return ;
        for (int i = 0; i < thread_size; i++) {
            threads[i] = new thread(&ThreadPool::worker, this);
        }
        starting = true;
        return ;
    }

private:
    
    bool starting;
    queue<Task *> task_que;
    vector<thread *> threads;
    int thread_size;
};
  1. 构造函数:

    ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) {
        this->start();
        return ;
    }
    

    这是线程池的构造函数,接受一个整数参数n,默认值为1。它初始化了线程池的一些成员变量,然后调用了start函数启动线程池。

  2. start函数:

    void start() {
        if (starting == true) return ;
        for (int i = 0; i < thread_size; i++) {
            threads[i] = new thread(&ThreadPool::worker, this);
        }
        starting = true;
        return ;
    }
    

    start函数用于启动线程池。它首先检查starting变量,如果线程池已经启动,就直接返回。然后,它通过循环创建了指定数量的线程,每个线程都执行worker函数。

  3. worker函数:

    void worker() {
        while (1) {
            // 取任务
            // 执行任务
        }
        return ;
    }
    

    worker函数是线程的主体函数,实现了线程的具体工作。它使用一个无限循环来不断地从任务队列中取出任务并执行。这里的具体任务处理逻辑需要根据实际需求来添加。

  4. 成员变量:

    bool starting;
    queue<Task *> task_que;
    vector<thread *> threads;
    int thread_size;
    
    • starting标志表示线程池是否已经启动。
    • task_que是一个任务队列,用于存储需要线程池执行的任务。
    • threads是一个存储线程指针的向量。
    • thread_size表示线程池中线程的数量。
再来实现一个tp.add_task(worker, j, j + batch - 1, ref(ret[i]));
class ThreadPool {
public:
    ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) {
        this->start();
        return ;
    }
    void worker() {
        while (1) {
            Task *t = get_task();
            t->run();
            delete t;
        }
        return ;
    }
    void start() {
        if (starting == true) return ;
        for (int i = 0; i < thread_size; i++) {
            threads[i] = new thread(&ThreadPool::worker, this);
        }
        starting = true;
        return ;
    }
    template<typename FUN_C, typename ...ARGS>
    void add_task(FUN_C fun, ARGS ...args) {
        unique_lock<mutex> lock(m_mutex);
        task_que.push(new Task(fun, forward<ARGS>(args)...));
        m_cond.notify_one();
        return ;
    }

private:
    Task *get_task() {
        unique_lock<mutex> lock(m_mutex);
        while (task_que.empty()) m_cond.wait(lock);
        Task *t = task_que.front();
        task_que.pop();
        return t;
    }

    std::mutex m_mutex;
    std::condition_variable m_cond;
    bool starting;
    queue<Task *> task_que;
    vector<thread *> threads;
    int thread_size;
};
  1. worker函数:

    void worker() {
        while (1) {
            Task *t = get_task();
            t->run();
            delete t;
        }
        return ;
    }
    

    worker函数是线程的主体函数,实现了线程的具体工作。它通过get_task函数获取任务,执行任务的run函数,然后释放任务的内存。

  2. add_task函数:

    template<typename FUN_C, typename ...ARGS>
    void add_task(FUN_C fun, ARGS ...args) {
        unique_lock<mutex> lock(m_mutex);
        task_que.push(new Task(fun, forward<ARGS>(args)...));
        m_cond.notify_one();
        return ;
    }
    

    add_task函数用于向任务队列中添加任务。它使用了可变参数模板,允许传递任意类型和数量的参数。添加任务后通过条件变量通知等待中的线程。

  3. get_task函数:

    Task *get_task() {
        unique_lock<mutex> lock(m_mutex);
        while (task_que.empty()) m_cond.wait(lock);
        Task *t = task_que.front();
        task_que.pop();
        return t;
    }
    

    get_task函数从任务队列中获取任务,如果队列为空,则线程等待条件变量的通知。获取到任务后,从队列中移除,并返回任务指针。

  4. 成员变量:

    std::mutex m_mutex;
    std::condition_variable m_cond;
    bool starting;
    queue<Task *> task_que;
    vector<thread *> threads;
    int thread_size;
    unordered_map<decltype(this_thread::get_id()), bool> running;
    
    • m_mutex是互斥锁,用于保护对任务队列的访问。
    • m_cond是条件变量,用于线程之间的同步。
实现tp.stop();
class ThreadPool {
public:
    ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) {
        this->start();
        return ;
    }
    void worker() {
        auto id = this_thread::get_id();
        running[id] = true;
        while (running[id]) {
            Task *t = get_task();
            t->run();
            delete t;
        }
        return ;
    }
    void start() {
        if (starting == true) return ;
        for (int i = 0; i < thread_size; i++) {
            threads[i] = new thread(&ThreadPool::worker, this);
        }
        starting = true;
        return ;
    }
    template<typename FUN_C, typename ...ARGS>
    void add_task(FUN_C fun, ARGS ...args) {
        unique_lock<mutex> lock(m_mutex);
        task_que.push(new Task(fun, forward<ARGS>(args)...));
        m_cond.notify_one();
        return ;
    }
    void stop() {
        if (starting == false) return ;
        for (int i = 0; i < thread_size; i++) {
            task_que.push(new Task(&ThreadPool::stop_thread, this));
        }
        for (int i = 0; i < thread_size; i++) {
            threads[i]->join();
        }
        for (int i = 0; i < thread_size; i++) {
            delete threads[i];
            threads[i] = nullptr;
        }
        starting = false;
        return ;
    }
    ~ThreadPool() {
        this->stop();
        while (!task_que.empty()) {
            delete task_que.front();
            task_que.pop();
        }
        return ;
    }

private:
    void stop_thread() {
        auto id = this_thread::get_id();
        running[id] = false;
        return ;
    }
    Task *get_task() {
        unique_lock<mutex> lock(m_mutex);
        while (task_que.empty()) m_cond.wait(lock);
        Task *t = task_que.front();
        task_que.pop();
        return t;
    }

    std::mutex m_mutex;
    std::condition_variable m_cond;
    bool starting;
    queue<Task *> task_que;
    vector<thread *> threads;
    int thread_size;
    unordered_map<decltype(this_thread::get_id()), bool> running;
};
  1. worker函数:

    void worker() {
        auto id = this_thread::get_id();
        running[id] = true;
        while (running[id]) {
            Task *t = get_task();
            t->run();
            delete t;
        }
        return ;
    }
    

    worker函数是线程的主体函数,实现了线程的具体工作。在循环中,它通过get_task函数获取任务,执行任务的run函数,然后释放任务的内存。此外,每个线程都有一个运行状态,保存在running映射中。

  2. add_task函数:

    template<typename FUN_C, typename ...ARGS>
    void add_task(FUN_C fun, ARGS ...args) {
        unique_lock<mutex> lock(m_mutex);
        task_que.push(new Task(fun, forward<ARGS>(args)...));
        m_cond.notify_one();
        return ;
    }
    

    add_task函数用于向任务队列中添加任务。它使用了可变参数模板,允许传递任意类型和数量的参数。添加任务后通过条件变量通知等待中的线程。

  3. stop函数:

    void stop() {
        if (starting == false) return ;
        for (int i = 0; i < thread_size; i++) {
            task_que.push(new Task(&ThreadPool::stop_thread, this));
        }
        for (int i = 0; i < thread_size; i++) {
            threads[i]->join();
        }
        for (int i = 0; i < thread_size; i++) {
            delete threads[i];
            threads[i] = nullptr;
        }
        starting = false;
        return ;
    }
    

    stop函数用于停止线程池。它向任务队列中添加了一个特殊的任务,该任务调用stop_thread函数,使所有线程退出循环。接着,等待每个线程完成,然后释放线程资源。

  4. stop_thread函数:

    void stop_thread() {
        auto id = this_thread::get_id();
        running[id] = false;
        return ;
    }
    

    stop_thread函数用于在每个线程中调用,设置相应线程的运行状态为false,从而退出循环。

  5. 析构函数:

    ~ThreadPool() {
        this->stop();
        while (!task_que.empty()) {
            delete task_que.front();
            task_que.pop();
        }
        return ;
    }
    

    析构函数在对象销毁时会调用,这里调用了stop函数以确保线程池停止,并释放任务队列中的任务。

    running是一个unordered_map,用于跟踪线程的运行状态。

完整代码:

/*************************************************************************
        > File Name: thread_pool.cpp
        > Author:Xiao Yuheng
        > Mail:3312638794@qq.com
        > Created Time: Sat Nov 18 18:18:18 2023
 ************************************************************************/

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <thread>
#include <functional>
#include <vector>
#include <unordered_map>
#include <queue>
#include <mutex>
#include <condition_variable>

using namespace std;

class Task {
public:
    template<typename FUN_C, typename ...ARGS>
    Task(FUN_C fun, ARGS ...args) {
        func = bind(fun, forward<ARGS>(args)...);
    }
    void run() {
        func();
    }

private:
    function<void()> func;
};

class ThreadPool {
public:
    ThreadPool(int n = 1) : threads(n), thread_size(n), starting(false) {
        this->start();
        return ;
    }
    void worker() {
        auto id = this_thread::get_id();
        running[id] = true;
        while (running[id]) {
            Task *t = get_task();
            t->run();
            delete t;
        }
        return ;
    }
    void start() {
        if (starting == true) return ;
        for (int i = 0; i < thread_size; i++) {
            threads[i] = new thread(&ThreadPool::worker, this);
        }
        starting = true;
        return ;
    }
    template<typename FUN_C, typename ...ARGS>
    void add_task(FUN_C fun, ARGS ...args) {
        unique_lock<mutex> lock(m_mutex);
        task_que.push(new Task(fun, forward<ARGS>(args)...));
        m_cond.notify_one();
        return ;
    }
    void stop() {
        if (starting == false) return ;
        for (int i = 0; i < thread_size; i++) {
            task_que.push(new Task(&ThreadPool::stop_thread, this));
        }
        for (int i = 0; i < thread_size; i++) {
            threads[i]->join();
        }
        for (int i = 0; i < thread_size; i++) {
            delete threads[i];
            threads[i] = nullptr;
        }
        starting = false;
        return ;
    }
    ~ThreadPool() {
        this->stop();
        while (!task_que.empty()) {
            delete task_que.front();
            task_que.pop();
        }
        return ;
    }

private:
    void stop_thread() {
        auto id = this_thread::get_id();
        running[id] = false;
        return ;
    }
    Task *get_task() {
        unique_lock<mutex> lock(m_mutex);
        while (task_que.empty()) m_cond.wait(lock);
        Task *t = task_que.front();
        task_que.pop();
        return t;
    }

    std::mutex m_mutex;
    std::condition_variable m_cond;
    bool starting;
    queue<Task *> task_que;
    vector<thread *> threads;
    int thread_size;
    unordered_map<decltype(this_thread::get_id()), bool> running;
};

bool is_prime(int x) {
    for (int i = 2; i * i <= x; i++) {
        if (x % i == 0) return false;
    }
    return true;
}

int prime_count_test(int l, int r) {
    int ans = 0;
    for (int i = l; i <= r; i++) {
        ans += is_prime(i);
    }
    return ans;
}

void worker(int l, int r, int &ret) {
    ret = prime_count_test(l, r);
    return ;
}

int main() {
    #define batch 5000000
    // 创建一个线程池
    ThreadPool tp(5);
    int ret[10];
    // 把任务丢进任务队列中,等待执行
    for (int i = 0, j = 1; i < 10; i++, j += batch) {
        tp.add_task(worker, j, j + batch - 1, ref(ret[i]));
    }
    // 等待线程池结束,并销毁线程池
    tp.stop();
    int ans = 0;
    for (auto x : ret) ans += x;
    cout << ans << endl;
    return 0;
}
  • 32
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值