如题所示,周末无聊,手写一份线程池代码;代码比较简单,大家看看就好。
前言
- 声明:这只是一份示例代码,演示线程池的一般写法,基于“单生产者-多消费者”模型。
- 代码仅供参考;如果要实际使用,还有很多需要优化的地方(例如工作线程没有退出机制、缺少错误处理)。
- 这个模型应该适用于很多地方,例如开发TCP服务器模块。
一、主要干什么事
- 用一张图简单表示一下。任务是“计算2个数字之和”;这两个数字来自TCP客户端发送的数据(每次2个字节,每个字节是一个加数)。
二、UML
- 大概是这样,用一张图表示。
三、贴个代码
//CTaskcalculate.h
#pragma once
#include <iostream>
#include "TaskBase.h"
class CTaskcalculate :
public CTaskBase
{
public:
CTaskcalculate(int a, int b);
protected:
virtual bool DoAction() override;
private:
int m_par1;
int m_par2;
};
//CTaskcalculate.cpp
#include "stdafx.h"
#include "CTaskcalculate.h"
CTaskcalculate::CTaskcalculate(int a, int b)
: m_par1(a),m_par2(b)
{
}
/// Adds the two stored operands and writes "<a>+<b>=<sum>" to std::cout.
///
/// @return Always true.
bool CTaskcalculate::DoAction()
{
    std::ostream& out = std::cout;
    out << m_par1 << "+" << m_par2;
    // std::endl is kept so the line is flushed as soon as it is written.
    out << "=" << (m_par1 + m_par2) << std::endl;
    return true;
}
//CTaskThread.h
#pragma once
#include <iostream>
#include "CTaskThreadBase.h"
#include "TaskQeue.h"
class CTaskThread :
public CTaskThreadBase
{
public:
CTaskThread();
~CTaskThread() {}
protected:
virtual void Run() override;
};
//CTaskThread.cpp
#include "stdafx.h"
#include "CTaskThread.h"
/// Nothing to set up here; the thread itself is only created by Start().
CTaskThread::CTaskThread() = default;
void CTaskThread::Run()
{
while (true)
{
auto it = CTaskQeue::Instance().GetTask();
if (it != nullptr)
{
std::cout << "id = " << std::this_thread::get_id()<< " get task ";
it->DoAction();
}
}
}
//CTaskThreadBase.h
#pragma once
#include <thread>
#include <memory>
#include <iostream>
/// Base class for objects that run their Run() method on a std::thread.
///
/// Derived classes implement Run(); Start() launches it on a new thread whose
/// handle is kept in m_thread.
///
/// Notes for users:
///  - Start() passes `this` to the new thread, so the object must stay at the
///    same address (and stay alive) for as long as the thread is running.
///  - Copies share the same std::thread handle through the shared_ptr; the
///    running thread still refers to the object Start() was called on.
///  - If the last owner of m_thread is destroyed while the thread is still
///    joinable, std::thread's destructor calls std::terminate(); call Join()
///    or Detach() first.
class CTaskThreadBase
{
public:
    CTaskThreadBase();

    // Virtual because this is a polymorphic base: destroying a derived object
    // through a CTaskThreadBase pointer is undefined behavior otherwise.
    virtual ~CTaskThreadBase() {}

    /// Launches Run() on a new thread.
    void Start();
    /// Placeholder; the current implementation does nothing.
    void Exit();
    /// Detaches the thread created by Start().
    void Detach();
    /// Waits for the thread created by Start() to finish.
    void Join();

protected:
    /// Thread body supplied by the derived class.
    virtual void Run() = 0;

    /// Handle of the thread created by Start(); null until Start() is called.
    std::shared_ptr<std::thread> m_thread;
};
//CTaskThreadBase.cpp
#include "stdafx.h"
#include "CTaskThreadBase.h"
/// Leaves m_thread empty; no thread exists until Start() is called.
CTaskThreadBase::CTaskThreadBase() = default;
void CTaskThreadBase::Start()
{
m_thread = std::make_shared<std::thread>(&CTaskThreadBase::Run, this);
}
// Currently a no-op: the body is empty, so calling this has no effect on the
// thread referenced by m_thread.
void CTaskThreadBase::Exit()
{
}
/// Detaches the thread created by Start(), letting it run independently.
///
/// Does nothing if Start() has not been called yet (m_thread is null) or if
/// the thread is no longer joinable (already joined or detached). This mirrors
/// the guard in Join() and avoids both a null dereference and the
/// std::system_error that std::thread::detach() throws on a non-joinable
/// thread.
void CTaskThreadBase::Detach()
{
    if (m_thread && m_thread->joinable())
    {
        m_thread->detach();
    }
}
/// Blocks until the thread created by Start() has finished.
///
/// Does nothing if Start() has not been called yet (m_thread is null) or if
/// the thread is not joinable (already joined or detached).
void CTaskThreadBase::Join()
{
    // The null check protects callers that Join() before (or without) Start().
    if (m_thread && m_thread->joinable())
    {
        m_thread->join();
    }
}
//TaskBase.h
#pragma once
#include <iostream>
/// Abstract interface for a unit of work that can be placed in the task queue.
class CTaskBase
{
public:
    CTaskBase();
    virtual ~CTaskBase();

    /// Performs the task's work.
    /// @return Result flag defined by the concrete task.
    virtual bool DoAction() = 0;
};
//TaskBase.cpp
#include "StdAfx.h"
#include "TaskBase.h"
/// The base class has no state to initialize.
CTaskBase::CTaskBase() = default;
/// Writes a trace line to std::cout each time a task object is destroyed.
CTaskBase::~CTaskBase()
{
    const char* const message = "CTaskBase out\n";
    std::cout << message;
}
//TaskQeue.h
#pragma once
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include "TaskBase.h"
/// Process-wide task queue shared between the code that submits tasks and the
/// threads that execute them.
///
/// The only instance is the private static member returned by Instance(); the
/// constructor is private so no other instance can be created from outside.
/// The member functions defined for this class take m_mutex before touching
/// m_queue.
class CTaskQeue
{
public:
    /// Returns the single shared queue object.
    static CTaskQeue& Instance();

    ~CTaskQeue();

    /// Appends a task and wakes one thread waiting in GetTask().
    bool PutTask(std::shared_ptr<CTaskBase> ptr);

    /// Removes and returns the oldest task, blocking while the queue is empty.
    std::shared_ptr<CTaskBase> GetTask();

    /// Reports whether the queue currently holds no tasks.
    bool Empty();

private:
    CTaskQeue();

    static CTaskQeue m_taskIns;  ///< The one and only instance.

    std::mutex m_mutex;                              ///< Guards m_queue.
    std::queue<std::shared_ptr<CTaskBase>> m_queue;  ///< Pending tasks, FIFO.
    std::condition_variable m_cond;                  ///< Signalled by PutTask().
};
//TaskQeue.cpp
#include "StdAfx.h"
#include "TaskQeue.h"
// Definition of the single CTaskQeue object that Instance() returns.
CTaskQeue CTaskQeue::m_taskIns;
/// All members are default-constructed; the queue starts out empty.
CTaskQeue::CTaskQeue() = default;
/// No explicit cleanup; the members release whatever they hold.
CTaskQeue::~CTaskQeue() = default;
/// Gives access to the shared queue object (the static member m_taskIns).
CTaskQeue& CTaskQeue::Instance()
{
    CTaskQeue& instance = m_taskIns;
    return instance;
}
/// Appends a task to the back of the queue and wakes one waiting consumer.
///
/// @param ptr Task to enqueue.
/// @return Always true.
bool CTaskQeue::PutTask(std::shared_ptr<CTaskBase> ptr)
{
    {
        // Hold the mutex only for the push itself.
        std::lock_guard<std::mutex> guard(m_mutex);
        m_queue.push(ptr);
    }

    // Notify after the mutex is released, as the push is already visible.
    m_cond.notify_one();
    return true;
}
/// Removes and returns the task at the front of the queue.
///
/// Blocks on the condition variable until the queue is non-empty.
///
/// @return The oldest queued task.
std::shared_ptr<CTaskBase> CTaskQeue::GetTask()
{
    std::unique_lock<std::mutex> guard(m_mutex);

    // Re-check after every wake-up: wait() may return spuriously, and another
    // consumer may already have taken the task.
    while (m_queue.empty())
    {
        m_cond.wait(guard);
    }

    std::shared_ptr<CTaskBase> next = m_queue.front();
    m_queue.pop();
    return next;
}
/// Reports whether the queue holds no tasks at the moment of the call.
///
/// The check is made while holding the queue mutex.
///
/// @return true if the queue was empty, false otherwise.
bool CTaskQeue::Empty()
{
    std::lock_guard<std::mutex> guard(m_mutex);
    return m_queue.empty();
}
// threadpool.cpp : 定义控制台应用程序的入口点。
#include "stdafx.h"
#include <winsock2.h>
#include <vector>
#include "TaskQeue.h"
#include "CTaskThread.h"
#include "CTaskcalculate.h"
#pragma comment(lib,"ws2_32.lib")
int RecieveThread(int clientfd)
{
if (clientfd)
{
unsigned char buf[2] = { 0 };
while (1)
{
int nRet = recv(clientfd, (char*)buf, 2, 0);
if (nRet > 0)
{
auto it = std::make_shared<CTaskcalculate>(buf[0], buf[1]);
CTaskQeue::Instance().PutTask(it);
std::cout << "recv param:" << int(buf[0]) << "," << int(buf[1]) << std::endl;
}
}
}
return 0;
}
int _tmain(int argc, _TCHAR* argv[])
{
WSADATA a;
WORD version = MAKEWORD(2, 2);
WSAStartup(version, &a);
int sock_fd = 0;
struct sockaddr_in serv_addr;
struct sockaddr_in client_addr;
int addr_len = sizeof(struct sockaddr_in);
sock_fd = socket(AF_INET, SOCK_STREAM, 0);
memset(&serv_addr, 0, sizeof(struct sockaddr_in));
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = 0;
serv_addr.sin_port = htons(10001);
int ret = bind(sock_fd, (struct sockaddr*)&serv_addr, sizeof(struct sockaddr_in));
if (ret == 0) {
printf("bind ok\n");
}
else {
printf("bind failed\n");
closesocket(sock_fd);
return 0;
}
ret = listen(sock_fd, 10);
if(ret == 0)
{
printf("listen ok\n");
}
else {
printf("listen failed\n");
closesocket(sock_fd);
return 0;
}
int clientfd = accept(sock_fd, (sockaddr*)&client_addr, &addr_len);
std::vector<CTaskThread> m_vec;
for (int i = 0; i < 30; i++)
{
CTaskThread aaa;
aaa.Start();
m_vec.push_back(aaa);
}
if (clientfd)
{
std::thread m_temp(RecieveThread, clientfd);
m_temp.join();
}
for (int i = 0; i < 30; i++)
{
m_vec[i].Join();
}
//实际不会走到这步
closesocket(clientfd);
closesocket(sock_fd);
WSACleanup();
return 0;
}