转载:醍醐灌顶全方位击破C++线程池及异步处理 - 知乎 (zhihu.com)
重点:
转载的代码有点乱,一共有二种方法。
第一种:
重点:
1.采用嵌套类:为了不被外部访问。
SafeQueue.h
#pragma once
#include <mutex>
#include <queue>
using namespace std;
template<typename T>
class SafeQueue
{
public:
SafeQueue() = default;
~SafeQueue() = default;
public:
bool Empty()
{
unique_lock<mutex> lock(mu);
return m_queue.empty();
}
int Size()
{
unique_lock<mutex> lock(mu);
return m_queue.size();
}
//向队列添加元素
void Enqueue(const T& t)
{
unique_lock<mutex> lock(mu);
m_queue.push(t);
}
//向队列取出元素
bool Dequeue(T& t)
{
unique_lock<mutex> lock(mu);
if (m_queue.empty()) return false;
t = move(m_queue.front());
m_queue.pop();
return true;
}
private:
queue<T> m_queue;
mutex mu;
};
ThreadPool.h
#pragma once
#include <functional>
#include <future>
#include "SafeQueue.h"
class ThreadPool
{
private:
class ThreadWorker //内置线程工作类
{
private:
int m_id; //工作线程
ThreadPool* m_pool; //所属线程池
public:
ThreadWorker(ThreadPool* pool, const int id):m_pool(pool),m_id(id)
{
}
public:
void operator()() //重载操作符,变成函数对象,ThreadWorker 是一个函数对象
{
function<void()> func; //定义基础函数类
bool m_dequeued;
while (!m_pool->m_shutdown)
{
unique_lock<mutex> lock(m_pool->mu);
if (m_pool->m_queue.Empty())
{
m_pool->m_cv.wait(lock);
}
m_dequeued = m_pool->m_queue.Dequeue(func);
if (m_dequeued)
{
func();
}
}
}
};
public:
ThreadPool(const int num) :m_threads(vector<thread>(num)), m_shutdown(false)
{
}
ThreadPool(const ThreadPool&) = delete;
ThreadPool(ThreadPool&&) = delete;
ThreadPool& operator=(const ThreadPool&) = delete;
ThreadPool& operator=(ThreadPool&&) = delete;
public:
void Init()
{
for (int i=0;i<m_threads.size();++i)
{
//thread初始构造函数,ThreadWorker 是函数对象,这一步巧妙
m_threads[i] = thread(ThreadWorker(this, i));
}
}
void ShutDown()
{
m_shutdown = true;
m_cv.notify_all();
for (int i = 0; i < m_threads.size(); ++i)
{
if (m_threads[i].joinable())
{
m_threads[i].join();
}
}
}
template<typename F, typename...Args>
auto submit(F&& f, Args&&... args) -> std::future<decltype(f(args...))> {
std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
//封装获取任务对象,方便另外一个线程查看结果
auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
// Wrap packaged task into void function
//利用正则表达式,返回一个函数对象
std::function<void()> wrapper_func = [task_ptr]() {
(*task_ptr)(); //这一步调用内部类的操作符了
};
// 队列通用安全封包函数,并压入安全队列
m_queue.Enqueue(wrapper_func);
// 唤醒一个等待中的线程
m_cv.notify_one();
// 返回先前注册的任务指针
return task_ptr->get_future();
}
private:
atomic<bool> m_shutdown; //线程池是否关闭
SafeQueue<function<void()>> m_queue;//执行函数安全队列,即任务队列
vector<thread> m_threads; //工作线程队列
mutex mu;
condition_variable m_cv;
};
Test.cpp
#include "ThreadPool.h"
#include <random>
using namespace std;
random_device rd;
mt19937 mt(rd());
std::uniform_int_distribution<int> dist(-1000, 1000);//生成-1000到1000之间的离散均匀分部数
auto rnd = std::bind(dist, mt);
//设置线程睡眠时间
void simulate_hard_computation() {
std::this_thread::sleep_for(std::chrono::milliseconds(2000 + rnd()));
}
// 添加两个数字的简单函数并打印结果
void multiply(const int a, const int b) {
simulate_hard_computation();
const int res = a * b;
std::cout << a << " * " << b << " = " << res << std::endl;
}
//添加并输出结果
void multiply_output(int& out, const int a, const int b) {
simulate_hard_computation();
out = a * b;
std::cout << a << " * " << b << " = " << out << std::endl;
}
// 结果返回
int multiply_return(const int a, const int b) {
simulate_hard_computation();
const int res = a * b;
std::cout << a << " * " << b << " = " << res << std::endl;
return res;
}
void example() {
// 创建3个线程的线程池
ThreadPool pool(3);
// 初始化线程池
pool.Init();
// 提交乘法操作,总共30个
for (int i = 1; i < 2; ++i) {
for (int j = 1; j < 10; ++j) {
pool.submit(multiply, i, j);
}
}
// 使用ref传递的输出参数提交函数
int output_ref;
auto future1 = pool.submit(multiply_output, std::ref(output_ref), 5, 6);
// 等待乘法输出完成
future1.get();
std::cout << "Last operation result is equals to " << output_ref << std::endl;
// 使用return参数提交函数
auto future2 = pool.submit(multiply_return, 5, 3);
// 等待乘法输出完成
int res = future2.get();
std::cout << "Last operation result is equals to " << res << std::endl;
//关闭线程池
pool.ShutDown();
}
int main()
{
example();
return 0;
}
第二种:
Threadpool.h
#pragma once
#include <functional>
#include <thread>
#include <queue>
#include <condition_variable>
#include <future>
using namespace std;
using Task = function<void()>;
class ThreadPool
{
public:
ThreadPool(size_t size = 4);
~ThreadPool();
public:
template<typename T, typename...Args>
auto Commit(T&& t, Args&&...args)->future<decltype(t(args...))>
{
if (m_stop.load())
{
throw runtime_error("task has closed commit");
}
using ResType = decltype(t(args...));
auto task = make_shared<packaged_task<ResType()>>(
bind(forward<T>(t), forward<Args>(args)...));
unique_lock<mutex> lock(mu);
m_tasks.emplace([task]() {
(*task)();
});
m_cv.notify_all(); //唤醒等待线程
future<ResType> fu = task->get_future();
return fu;
}
public:
void ShutDown(); //停止任务提交
void Restart(); //重启任务提交
private:
Task GetOneTask();//获取一个待执行的task
void Schedual(); //任务调度
private:
vector<thread> m_pool;
mutex mu;
queue<Task> m_tasks;
condition_variable m_cv;
atomic<bool> m_stop;
};
ThreadPool.cpp
#include "ThreadPool.h"
#include <future>
ThreadPool::ThreadPool(size_t size) :m_stop{false}
{
size = size < 1 ? 1 : size;
for (size_t i=0;i<size;++i)
{
m_pool.emplace_back(&ThreadPool::Schedual, this);
}
}
ThreadPool::~ThreadPool()
{
for (auto&t:m_pool)
{
t.detach(); //让线程自身自灭
//t.join(); //等任务结束,前提:线程一定会执行完
}
}
void ThreadPool::ShutDown()
{
m_stop.store(true);//对内存进行访问memory_order_seq_cst,采用store
}
void ThreadPool::Restart()
{
m_stop.store(false);//对内存进行访问memory_order_seq_cst,采用store
}
Task ThreadPool::GetOneTask()
{
unique_lock<mutex> lock(mu);
m_cv.wait(lock, [this] {return !m_tasks.empty(); });
Task task(move(m_tasks.front()));
m_tasks.pop();
return task;
}
void ThreadPool::Schedual()
{
while (true)
{
if (Task task =GetOneTask())
{
task();
}
else
{
return; //结束
}
}
}
Test.cpp
// Test.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//
#include <iostream>
#include <future>
#include "ThreadPool.h"
using namespace std;
void fun()
{
for (int i = 0; i < 100000; ++i)
{
cout << "hello"<<i << endl;
}
}
struct Gan
{
int operator()() {
cout << "hello,gan" << endl;
return 42;
}
};
int main() {
try
{
ThreadPool task(10);
future<void> ff = task.Commit(fun);
future<int> fg = task.Commit(Gan());
future<string> fs = task.Commit([]()->string {
return "hello,fs";
});
task.ShutDown();
ff.get();
cout << "fg.get : " << fg.get ()<< endl;
this_thread::sleep_for(chrono::seconds(5));
task.Restart(); //重启任务
cout << "end " << endl;
return 0;
}
catch (const std::exception& e)
{
cout << "soming is wrong "<< e.what() << endl;
}
return 0;
}