Implementing a Thread Pool in C++

This post presents a thread pool built with the C++11 standard library, consisting of a thread wrapper class, a blocking queue class, and a thread factory class. The pool grows and shrinks its set of worker threads according to the task load, and supports submitting tasks, executing them, and reclaiming idle threads. The test code at the end shows how to create a pool, submit tasks, and observe how the pool behaves.

Build environment: C++11
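
Before the full listing, here is a minimal sketch of the intended usage, based on the ThreadPool alias and the execute() method defined below (the file name ThreadPool.h, the lambda task, and the final sleep are placeholders; the sleep only gives the detached workers time to run):

#include <iostream> // std::cout
#include <chrono>   // std::chrono
#include <thread>   // std::this_thread::sleep_for
#include "ThreadPool.h" // assuming the header below is saved under this name

int main()
{
    // 2 core worker threads, at most 4 workers in total, 1000 ms keep-alive for non-core workers
    ThreadPool pool(2, 4, 1000);

    for(int i = 0; i < 4; ++i)
    {
        // execute() takes a Runnable (std::function<void()>) by rvalue reference
        pool.execute([i]{ std::cout << "task " << i << " done" << std::endl; });
    }

    // Give the detached worker threads time to drain the queue before main() returns
    std::this_thread::sleep_for(std::chrono::seconds(2));
    return 0;
}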

Thread pool source code:

#ifndef THREAD_POOL_H
#define THREAD_POOL_H

#include <thread> // std::thread
#include <functional> // std::function
#include <mutex> // std::mutex, std::unique_lock
#include <chrono> // std::chrono
#include <atomic> // std::atomic
#include <memory> // std::shared_ptr
#include <utility> // std::swap, std::forward
#include <iostream> // std::cout

class Thread
{
public:
    template<typename _Callable, typename... _Args>
    Thread(_Callable&& __f, _Args&&... __args)
     : _M_handler(std::forward<_Callable>(__f), std::forward<_Args>(__args)...)
    { }

    Thread(const Thread&) = delete;
    Thread(Thread &&__t)
     : _M_handler(std::move(__t._M_handler))
    { }

    // Detach so the thread runs asynchronously
    void detach()
    { _M_handler.detach(); }

private:
    std::thread _M_handler;
};

template<typename _Tp>
class BlockingQueue
{
public:
    using ValueType = _Tp;

private:
    struct _Node
    {
        ValueType data;
        _Node *prev;
        _Node *next;
    };

    using NodeType = _Node;

public:
    BlockingQueue()
     : _M_head(nullptr), _M_tail(nullptr), _M_lock(), _M_count(0)
    { }

    BlockingQueue(const BlockingQueue &__q)
     : BlockingQueue()
    {
        NodeType *__node = __q._M_head;
        while(__node)
        {
            _M_push_node(new NodeType{__node->data, nullptr, nullptr});
            __node = __node->next;
        }
    }

    BlockingQueue(BlockingQueue&& __q)
     : BlockingQueue()
    {
        std::swap(_M_head, __q._M_head);
        std::swap(_M_tail, __q._M_tail);
        std::swap(_M_count, __q._M_count);
    }

    ~BlockingQueue()
    {
        if(_M_head)
        {
            NodeType *__node = _M_head, *__tmp;
            while(__node)
            {
                __tmp = __node;
                __node = __node->next;
                delete __tmp;
            }
            _M_head = _M_tail = nullptr;
        }
    }

    void push(const ValueType &__v)
    {
        std::unique_lock<std::mutex> __lock(_M_lock);
        // Copy-construct the stored element; forwarding a const lvalue here would not compile
        _M_push_node(new NodeType{__v, nullptr, nullptr});
    }

    void push(ValueType&& __v)
    {
        std::unique_lock<std::mutex> __lock(_M_lock);
        _M_push_node(new NodeType{std::move(__v), nullptr, nullptr});
    }

    // Note: despite the class name, pop() does not block. When the queue is
    // empty it returns a default-constructed ValueType (an empty std::function
    // in this pool), so callers must check the result.
    ValueType pop()
    {
        std::unique_lock<std::mutex> __lock(_M_lock);
        ValueType __ret;
        if(_M_head)
        {
            __ret = std::move(_M_head->data);

            NodeType *__node = _M_head;
            _M_head = _M_head->next;
            delete __node;
            if(!_M_head)
            {
                _M_tail = nullptr;
            }
            else 
            {
                _M_head->prev = nullptr;
            }
            --_M_count;
        }
        return __ret;
    }

    bool empty() const 
    { return size() == 0; }

    // Note: _M_count is read without holding the lock, so the value may be slightly stale
    unsigned long long size() const 
    { return _M_count; }

private:
    void _M_push_node(NodeType *__node)
    {
        if(!_M_head)
        {
            _M_head = _M_tail = __node;
        }
        else 
        {
            __node->prev = _M_tail;
            _M_tail->next = __node;
            _M_tail = __node;
        }
        ++_M_count;
    }

private:    
    NodeType *_M_head;
    NodeType *_M_tail;
    std::mutex _M_lock;
    unsigned long long _M_count;
};

template<typename _ThreadType>
class ThreadFactory
{
public:
    using ValueType = _ThreadType;

public:
    ThreadFactory(unsigned int __pool_size, unsigned int __max_size)
     : _M_cur_size(0), _M_pool_size(__pool_size), _M_max_size(__max_size), _M_free_size(0)
    { }

    ThreadFactory(const ThreadFactory&) = delete;
    ThreadFactory(ThreadFactory&&) = delete;

    // Pool capacity (maximum number of core worker threads)
    unsigned int capacity() const 
    { return _M_pool_size; }

    // Maximum total number of worker threads
    unsigned int max_size() const 
    { return _M_max_size; }

    // Number of threads currently running in the pool
    unsigned int size() const 
    { return _M_cur_size; }

    // Create and start a new worker thread
    template<typename _Callable, typename... _Args>
    ValueType alloc(_Callable&& __f, _Args&&... __args)
    {
        ++_M_cur_size;
        ++_M_free_size;
        return ValueType(std::forward<_Callable>(__f), std::forward<_Args>(__args)...);
    }

    // Callback invoked when a worker becomes idle
    void free_callback()
    { ++_M_free_size; }

    // Callback invoked when a worker picks up a task
    void work_callback()
    { --_M_free_size; }

    // Number of idle worker threads
    unsigned int free_size() const 
    { return _M_free_size; }

    // Callback invoked when a worker thread exits
    void finish()
    {
        // The exiting worker was counted as idle, so release that slot as well
        if(_M_free_size > 0)
        {
            --_M_free_size;
        }
        if(_M_cur_size-- <= 0)
        {
            _M_cur_size = 0;
        }
    }

private:
    std::atomic_int _M_cur_size;
    std::atomic_uint _M_pool_size;
    std::atomic_uint _M_max_size;
    std::atomic_uint _M_free_size;
};

template<typename _ThreadFactory>
class RejectedStrategy
{
public:
    // Decide whether a new task should be rejected
    bool operator()(const _ThreadFactory &__tf)
    {
        return __tf.size() >= __tf.max_size();
    }
};

using Runnable = std::function<void()>;

template<typename _ThreadFactory = ThreadFactory<Thread>,
         typename _BlockingQueue = BlockingQueue<Runnable>,
         typename _RejectedStrategy = RejectedStrategy<_ThreadFactory>>
class ThreadPoolImpl
{
public:
    using FactoryType = _ThreadFactory;
    using BlockingQueueType = _BlockingQueue;
    using RejectedStrategyType = _RejectedStrategy;

    using FactoryTypePtr = std::shared_ptr<FactoryType>;
    using BlockingQueueTypePtr = std::shared_ptr<BlockingQueueType>;
    using RejectedStrategyTypePtr = std::shared_ptr<RejectedStrategyType>;
    using RunFlagTypePtr = std::shared_ptr<std::atomic_bool>;

    using ThreadType = typename FactoryType::ValueType;

public:
    /*
     * @__pool_size: maximum number of core worker threads
     * @__max_size: maximum total number of worker threads
     * @__keep_alive_time: how long a non-core worker may stay idle before exiting, in milliseconds
     */
    ThreadPoolImpl(unsigned int __pool_size, unsigned int __max_size, long long __keep_alive_time)
     : _M_factory(std::make_shared<FactoryType>(__pool_size, __max_size)), 
     _M_queue(std::make_shared<BlockingQueueType>()), 
     _M_rejected_strategy(std::make_shared<RejectedStrategyType>()), 
     _M_run_flag(std::make_shared<std::atomic_bool>(true)), 
     _M_keep_alive_time(__keep_alive_time)
    { }

    ThreadPoolImpl(const ThreadPoolImpl&) = delete;
    ThreadPoolImpl(ThreadPoolImpl&&) = delete;

    ~ThreadPoolImpl()
    {
        // Signal the detached workers to stop. They keep the factory, queue and
        // run flag alive through shared_ptr, so destroying the pool object itself is safe.
        *_M_run_flag = false;
    }

    bool execute(Runnable&& __r)
    {
        if((*_M_rejected_strategy)(*_M_factory))
        {
            return false;
        }
        
        _M_queue->push(std::move(__r));

        // If an idle worker already exists, do not start a new thread
        if(_M_factory->free_size() > 0)
        { }
        // If the number of running threads is below the pool capacity, start a core worker
        else if(_M_factory->size() < _M_factory->capacity())
        {
            _M_factory->alloc([](FactoryTypePtr __f, BlockingQueueTypePtr __q, RunFlagTypePtr __flag)
            {
                while(*__flag)
                {
                    Runnable __func = __q->pop();
                    if(__func)
                    {
                        __f->work_callback();
                        __func();
                        __f->free_callback();
                    }
                }
                __f->finish();
                std::cout << "core thread finish" << std::endl;
            }, _M_factory, _M_queue, _M_run_flag).detach();
        }
        // The pool capacity has been reached, so start a non-core worker with a keep-alive timeout
        else 
        {
            _M_factory->alloc([]
            (FactoryTypePtr __f, BlockingQueueTypePtr __q, RunFlagTypePtr __flag, long long __t)
            {
                // Initialize the idle timer to "now" so a worker that never receives
                // a task still times out after __t milliseconds.
                long long __last = std::chrono::duration_cast<std::chrono::milliseconds>(
                    std::chrono::system_clock::now().time_since_epoch()).count();
                long long __now;
                while(*__flag)
                {
                    Runnable __func = __q->pop();
                    if(__func)
                    {
                        __f->work_callback();
                        __func();
                        __f->free_callback();
                        auto __clock = std::chrono::system_clock::now().time_since_epoch();
                        __last = std::chrono::duration_cast<std::chrono::milliseconds>(__clock).count();
                    }
                    auto __clock = std::chrono::system_clock::now().time_since_epoch();
                    __now = std::chrono::duration_cast<std::chrono::milliseconds>(__clock).count();

                    if(__now - __last >= __t)
                    {
                        break;
                    }
                }
                __f->finish();
                std::cout << "finish now thread mum: " << __f->size() << std::endl;
            }, _M_factory, _M_queue, _M_run_flag, _M_keep_alive_time).detach();
        }
        return true;
    }

    const FactoryTypePtr factory() const 
    { return _M_factory; }

    const BlockingQueueTypePtr blocking_queue() const 
    { return _M_queue; }

    const RejectedStrategyTypePtr rejected_strategy() const 
    { return _M_rejected_strategy; }

private:
    FactoryTypePtr _M_factory;
    BlockingQueueTypePtr _M_queue;
    RejectedStrategyTypePtr _M_rejected_strategy;
    RunFlagTypePtr _M_run_flag;
    long long _M_keep_alive_time;
};

using ThreadPool = ThreadPoolImpl<ThreadFactory<Thread>, 
                                   BlockingQueue<Runnable>, 
                                   RejectedStrategy<ThreadFactory<Thread>>>;

#endif // THREAD_POOL_H
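
One design note before the test code: despite its name, BlockingQueue::pop() does not block. When the queue is empty it returns an empty std::function, so idle workers keep polling the queue in a loop and consume CPU while there is no work. A common alternative, not used in the code above, is a pop() that sleeps on a std::condition_variable until work arrives or the pool shuts down. A minimal sketch of that idea (the class and member names here are hypothetical, not part of the pool above):

#include <condition_variable> // std::condition_variable
#include <mutex>              // std::mutex, std::lock_guard, std::unique_lock
#include <queue>              // std::queue

// Sketch only: pop() waits until an element arrives or stop() is called.
template<typename T>
class WaitingQueue
{
public:
    void push(T v)
    {
        {
            std::lock_guard<std::mutex> lock(_mutex);
            _queue.push(std::move(v));
        }
        _cond.notify_one(); // wake one waiting consumer
    }

    // Returns false once stop() has been called and the queue is drained.
    bool pop(T &out)
    {
        std::unique_lock<std::mutex> lock(_mutex);
        _cond.wait(lock, [this]{ return _stopped || !_queue.empty(); });
        if(_queue.empty())
        {
            return false;
        }
        out = std::move(_queue.front());
        _queue.pop();
        return true;
    }

    void stop()
    {
        {
            std::lock_guard<std::mutex> lock(_mutex);
            _stopped = true;
        }
        _cond.notify_all(); // release every waiting consumer
    }

private:
    std::queue<T> _queue;
    std::mutex _mutex;
    std::condition_variable _cond;
    bool _stopped = false;
};

With a queue like this, the non-core workers' keep-alive logic would use wait_for with the keep-alive duration instead of comparing timestamps after every poll, and the pool's destructor would call stop() instead of only flipping the run flag.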

Test code:

#include <iostream> // std::cout
#include <chrono> // std::chrono
#include <thread> // std::this_thread::sleep_for
#include <cstdlib> // std::system
#include "ThreadPool.h" // ThreadPool

static void func()
{
    std::this_thread::sleep_for(std::chrono::milliseconds{5000});
    std::cout << "my func finish" << std::endl;
}

static void f()
{
    ThreadPool pool(4, 8, 1000);
    const ThreadPool::FactoryTypePtr factory = pool.factory();

    for(int i = 0; i < 10; ++i)
    {
        pool.execute(func);
        std::cout << "insert now thread num: " << factory->size() << std::endl;
        std::this_thread::sleep_for(std::chrono::milliseconds{1000});
    }
}

int main()
{
    f(); // The pool is created inside a separate function so that, after it is destroyed, we can observe how the detached worker threads behave
#ifdef _WIN32
    std::system("pause");
#endif
    return 0;
}

Test results:
