利用 C++11 的 std::thread 实现生产者-消费者模型（线程池）。
/*任务接口*/
#ifndef _ITASK_HEAD_
#define _ITASK_HEAD_
/**
 * Abstract unit of work that a ThreadPool worker executes and then deletes
 * through an ITask pointer.
 *
 * Carries one untyped argument pointer. ITask itself never frees it: the
 * pointee type and its ownership are defined by each subclass (for example,
 * a subclass may cast it back and delete it inside process()).
 */
class ITask
{
public:
    /// @param arg optional argument for process(); not owned by ITask.
    /// explicit: a bare void* must not silently convert into a task.
    explicit ITask(void *arg = nullptr) : m_arg(arg)
    {
    }

    /// Virtual so that deleting a task through an ITask* runs the subclass destructor.
    virtual ~ITask() = default;

    /// Replaces the argument pointer. Any previously stored pointer is NOT freed.
    void SetArg(void *arg)
    {
        m_arg = arg;
    }

    /// Performs the task's work.
    virtual void process() = 0;

protected:
    void *m_arg;  // untyped argument; meaning and ownership defined by subclasses
};
#endif
/*任务*/
#ifndef _MYTASK_HEAD_
#define _MYTASK_HEAD_
#include "ITask.h"
class MyTask :public ITask
{
public:
MyTask();
virtual ~MyTask();
virtual void process();
};
#endif
#include "MyTask.h"
#include <iostream>
using namespace std;
// Creates a task with no argument yet; callers supply it later via SetArg().
MyTask::MyTask()
    : ITask(nullptr)
{
}
// Nothing to release here: the argument is freed by process(), not by the destructor.
MyTask::~MyTask() = default;
/**
 * Prints the task's integer argument and releases it.
 *
 * m_arg is expected to point to an int allocated with `new int(...)`.
 * It is cast back to int* BEFORE deleting: a delete-expression through
 * void* is undefined behavior, because the compiler knows neither the
 * object's type nor its size.
 * A task without an argument does nothing, instead of dereferencing null.
 */
void MyTask::process()
{
    int *value = static_cast<int*>(m_arg);
    if (value == nullptr)
    {
        return;
    }
    cout << "Process Task:" << *value << endl;
    delete value;
    m_arg = nullptr;  // guard against a second process() double-freeing
}
/*线程池*/
#ifndef _THREADPOOL_HEAD_
#define _THREADPOOL_HEAD_
#include "ITask.h"
#include <mutex>
#include <condition_variable>
#include <vector>
#include <map>
#include <thread>
/**
 * Thread pool that runs ITask objects taken from a FIFO task queue.
 *
 * The constructor starts coreThreadCount workers (clamped to maxThreadCount).
 * When a task is added while no worker is idle, the pool may start more
 * workers, up to maxThreadCount. Tasks handed to AddTask() are deleted by the
 * pool after their process() call.
 *
 * Non-copyable: it owns a mutex, a condition variable and worker threads that
 * keep a pointer to this object.
 */
class ThreadPool
{
public:
    /**
     * @param maxThreadCount  upper bound on the number of worker threads
     * @param coreThreadCount workers started by the constructor
     *                        (clamped to maxThreadCount)
     * explicit: an int must not silently convert into a whole thread pool.
     */
    explicit ThreadPool(int maxThreadCount = 100, int coreThreadCount = 3);
    ~ThreadPool();

    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    /// Queues a task for execution; the pool deletes it after process().
    void AddTask(ITask* task);
    /// Waits for the worker threads to finish.
    void join();
private:
    /// Clamps the core size to the maximum and starts the core workers.
    void Init();
    /// Loop executed by every worker thread.
    void ThreadFunc();
private:
    int m_maxThreadCount;      // upper bound on live workers
    int m_coreThreadCount;     // the pool never shrinks below this many workers by idling
    int m_currentThreadCount;  // live workers
    int m_idleThreadCount;     // workers currently waiting for a task
    std::mutex m_mutex;        // held while the counters, m_threads and m_tasks are modified
    std::condition_variable m_cond;                      // wakes idle workers when a task is queued
    std::map<std::thread::id, std::thread*> m_threads;  // heap-allocated worker threads keyed by id
    std::vector<ITask*> m_tasks;                         // pending tasks, oldest first
    // A non-core worker exits when more than this fraction of the workers are idle.
    const double m_threshold = 0.5;
};
#endif
/*线程池实现*/
#include "ThreadPool.h"
#include <thread>
using namespace std;
ThreadPool::ThreadPool(int maxThreadCount, int coreThreadCount)
:m_maxThreadCount(maxThreadCount),
m_coreThreadCount(coreThreadCount),
m_currentThreadCount(0),
m_idleThreadCount(0)
{
Init();
}
ThreadPool::~ThreadPool()
{
}
void ThreadPool::Init()
{
unique_lock<mutex> lock(m_mutex);
m_coreThreadCount = m_maxThreadCount < m_coreThreadCount ? m_maxThreadCount : m_coreThreadCount;
for (int i = 0; i < m_coreThreadCount; i++)
{
thread *th = new thread(&ThreadPool::ThreadFunc, this);
m_threads.insert(make_pair(th->get_id(),th));
m_currentThreadCount++;
}
}
void ThreadPool::AddTask(ITask *task)
{
unique_lock<mutex> lock(m_mutex);
m_tasks.push_back(task);
if (m_idleThreadCount > 0)
{
m_cond.notify_one();
}
else if (m_currentThreadCount < m_maxThreadCount)
{
thread *th = new thread(&ThreadPool::ThreadFunc, this);
m_threads.insert(make_pair(th->get_id(),th));
m_currentThreadCount++;
}
}
void ThreadPool::join()
{
for (auto th : m_threads)
{
th.second->join();
}
}
void ThreadPool::ThreadFunc()
{
while (1)
{
unique_lock<mutex> lock(m_mutex);
while (m_tasks.empty())
{
m_idleThreadCount++;
if (m_idleThreadCount > m_threshold * m_currentThreadCount
&& m_currentThreadCount > m_coreThreadCount)
{
//auto iter = m_threads.find(std::this_thread::get_id());
//delete iter->second; 怎么释放thread对象呢?
//m_threads.erase(iter);
m_idleThreadCount--;
m_currentThreadCount--;
return;
}
m_cond.wait(lock);
m_idleThreadCount--;
}
ITask *task = *(m_tasks.begin());
m_tasks.erase(m_tasks.begin());
lock.unlock();
task->process();
delete task;
task = nullptr;
}
}
/*测试*/
#include "ThreadPool.h"
#include "MyTask.h"
using namespace std;
/**
 * Demo: queues 20 MyTask objects, each carrying a heap-allocated int, and
 * waits for the pool to finish them.
 *
 * Uses the standard `main` entry point. The original used the MSVC-only
 * `_tmain`/`_TCHAR`, which need <tchar.h> and do not compile with other
 * toolchains.
 */
int main()
{
    ThreadPool pool;
    for (int i = 0; i < 20; i++)
    {
        // The pool deletes the task after process(); MyTask::process() frees the int.
        MyTask *task = new MyTask();
        task->SetArg(new int(i));
        pool.AddTask(task);
    }
    pool.join();
    return 0;
}