1、包装任务函数 |
自定义带有函数操作符的类型擦除类来包装函数【function_wrapper.hpp】:
#pragma once
#include <memory>
namespace tp
{
/// Move-only, type-erased wrapper around any callable invocable as `f()`.
///
/// Unlike std::function, the wrapped callable only has to be movable, so
/// move-only callables such as std::packaged_task can be stored.
class function_wrapper
{
public:
    /// Creates an empty wrapper. Invoking an empty wrapper is undefined
    /// behaviour (it dereferences a null pointer), so callers must assign a
    /// callable before calling operator().
    function_wrapper() = default;

    function_wrapper(function_wrapper&& other) noexcept
        : impl(std::move(other.impl))
    {}

    /// Wraps the callable `f`.
    ///
    /// Rvalues are moved into the wrapper; lvalues are copied. The stored type
    /// is always the decayed callable type, so passing an lvalue no longer
    /// instantiates impl_type with a reference type (which did not compile).
    template<typename F>
    explicit function_wrapper(F&& f)
        : impl(make_impl(std::forward<F>(f)))
    {}

    /// Invokes the wrapped callable. Precondition: the wrapper is not empty.
    void operator()() { impl->call(); }

    function_wrapper& operator=(function_wrapper&& other) noexcept
    {
        impl = std::move(other.impl);
        return *this;
    }

    function_wrapper(const function_wrapper&) = delete;
    function_wrapper(function_wrapper&) = delete;
    function_wrapper& operator=(const function_wrapper&) = delete;

private:
    /// Type-erased interface through which the stored callable is invoked.
    struct impl_base
    {
        virtual void call() = 0;
        virtual ~impl_base() = default;
    };

    /// Concrete holder for a callable of (non-reference) type F.
    template<typename F>
    struct impl_type : impl_base
    {
        F f;
        explicit impl_type(F&& f_) : f(std::move(f_)) {}
        void call() override { f(); }
    };

    /// Builds the holder. Taking `f` by value makes template deduction strip
    /// references and cv-qualifiers, so F is always an object type here.
    template<typename F>
    static std::unique_ptr<impl_base> make_impl(F f)
    {
        return std::unique_ptr<impl_base>(new impl_type<F>(std::move(f)));
    }

    std::unique_ptr<impl_base> impl;
};
}
2、RAII管理线程汇入 |
利用RAII保证离开作用域时所有线程均已汇入(join)【join_threads.hpp】:
#pragma once
#include <vector>
#include <thread>
namespace tp
{
/// Scope guard that joins every still-joinable thread of a vector when the
/// guard itself is destroyed. The vector is referenced, not owned.
class join_threads
{
public:
    /// Binds the guard to `threads_`; the vector must outlive the guard.
    explicit join_threads(std::vector<std::thread>& threads_)
        : threads(threads_)
    {}

    /// Joins all threads that are joinable at destruction time.
    ~join_threads()
    {
        for (std::thread& worker : threads)
        {
            // Default-constructed or already-joined threads are skipped.
            if (!worker.joinable())
                continue;
            worker.join();
        }
    }

private:
    std::vector<std::thread>& threads;
};
}
3、线程安全任务队列 |
可上锁和等待的线程安全队列【thread_safe_queue.hpp】:
#pragma once
#include <condition_variable>
#include <memory>
#include <mutex>
#include <utility>
namespace tp
{
/// Unbounded FIFO queue protected by two mutexes: one for the head (pop side)
/// and one for the tail (push side).
///
/// The list always contains one trailing dummy node; the queue is empty when
/// `head` points at that dummy node (i.e. `head.get() == tail`).
template<typename T>
class thread_safe_queue
{
public:
    thread_safe_queue()
        : head(new node), tail(head.get())
    {}
    thread_safe_queue(const thread_safe_queue& other) = delete;
    thread_safe_queue& operator=(const thread_safe_queue& other) = delete;

    /// Pops the front element without blocking.
    /// @return the element, or a null pointer if the queue was empty.
    std::shared_ptr<T> try_pop()
    {
        std::unique_ptr<node> const old_head = try_pop_head();
        return old_head ? old_head->data : std::shared_ptr<T>();
    }

    /// Pops the front element into `value` without blocking.
    /// @return true if an element was popped, false if the queue was empty.
    bool try_pop(T& value)
    {
        std::unique_ptr<node> const old_head = try_pop_head(value);
        return old_head != nullptr;
    }

    /// Blocks until an element is available, then pops and returns it.
    std::shared_ptr<T> wait_and_pop()
    {
        std::unique_ptr<node> const old_head = wait_pop_head();
        return old_head->data;
    }

    /// Blocks until an element is available, then move-assigns it to `value`.
    /// @return the node's storage; the pointee has already been moved from,
    ///         so callers should read the result through `value`.
    std::shared_ptr<T> wait_and_pop(T& value)
    {
        std::unique_ptr<node> const old_head = wait_pop_head(value);
        return old_head->data;
    }

    /// Appends `new_value` to the back of the queue and wakes one waiter.
    // Not a member template: re-declaring `T` here would shadow the class
    // template parameter, which is ill-formed.
    void push(T new_value)
    {
        // Allocate outside the lock to keep the critical section short.
        std::shared_ptr<T> new_data(std::make_shared<T>(std::move(new_value)));
        std::unique_ptr<node> p(new node);
        {
            std::lock_guard<std::mutex> tail_lock(tail_mutex);
            // The current dummy node receives the data; `p` becomes the new dummy.
            tail->data = new_data;
            node* const new_tail = p.get();
            tail->next = std::move(p);
            tail = new_tail;
        }
        data_cond.notify_one();
    }

    /// @return true if the queue held no elements at the time of the check.
    bool empty() const
    {
        std::lock_guard<std::mutex> head_lock(head_mutex);
        return head.get() == get_tail();
    }

private:
    struct node
    {
        std::shared_ptr<T> data;
        std::unique_ptr<node> next;
    };

    mutable std::mutex head_mutex;
    std::unique_ptr<node> head;
    mutable std::mutex tail_mutex;
    node* tail;
    std::condition_variable data_cond;

private:
    /// Reads `tail` under the tail mutex.
    node* get_tail() const
    {
        std::lock_guard<std::mutex> tail_lock(tail_mutex);
        return tail;
    }

    /// Unlinks and returns the head node. Caller must hold `head_mutex` and
    /// must have checked that the queue is not empty.
    std::unique_ptr<node> pop_head()
    {
        std::unique_ptr<node> old_head = std::move(head);
        head = std::move(old_head->next);
        return old_head;
    }

    /// Blocks until the queue is non-empty and returns the held head lock.
    std::unique_lock<std::mutex> wait_for_data()
    {
        std::unique_lock<std::mutex> head_lock(head_mutex);
        data_cond.wait(head_lock, [&] { return head.get() != get_tail(); });
        return head_lock; // a local is moved implicitly on return
    }

    std::unique_ptr<node> wait_pop_head()
    {
        std::unique_lock<std::mutex> head_lock(wait_for_data());
        return pop_head();
    }

    std::unique_ptr<node> wait_pop_head(T& value)
    {
        std::unique_lock<std::mutex> head_lock(wait_for_data());
        value = std::move(*head->data);
        return pop_head();
    }

    std::unique_ptr<node> try_pop_head()
    {
        std::lock_guard<std::mutex> head_lock(head_mutex);
        if (head.get() == get_tail())
        {
            return nullptr;
        }
        return pop_head();
    }

    std::unique_ptr<node> try_pop_head(T& value)
    {
        std::lock_guard<std::mutex> head_lock(head_mutex);
        if (head.get() == get_tail())
        {
            return nullptr;
        }
        value = std::move(*head->data);
        return pop_head();
    }
};
}
4、任务窃取 |
基于锁的任务窃取队列【work_stealing_queue.h】:
#pragma once
#include <deque>
#include <mutex>
#include "function_wrapper.hpp"
namespace tp
{
/// Mutex-protected double-ended task queue used for work stealing.
///
/// Every operation takes `the_mutex`. The owner pushes and pops at the front
/// of the deque, while `try_steal` takes from the back.
class work_stealing_queue
{
    using data_type = function_wrapper;

public:
    work_stealing_queue();
    work_stealing_queue(const work_stealing_queue&) = delete;
    work_stealing_queue& operator=(const work_stealing_queue&) = delete;

    /// Adds a task at the front of the queue.
    void push(data_type data);

    /// @return true if the queue currently holds no tasks.
    bool empty() const;

    /// Moves the front task into `res`; returns false if the queue is empty.
    bool try_pop(data_type& res);

    /// Moves the back task into `res`; returns false if the queue is empty.
    bool try_steal(data_type& res);

private:
    std::deque<data_type> the_queue;
    mutable std::mutex the_mutex; // mutable so that empty() can lock while const
};
}
应让无任务的线程从其他线程的任务队列中分担任务【work_stealing_queue.cpp】:
#include "work_stealing_queue.h"
namespace tp
{
// Members default-construct themselves; there is nothing else to set up.
work_stealing_queue::work_stealing_queue() = default;
// Adds a task at the front of the deque, under the queue mutex.
void work_stealing_queue::push(data_type data)
{
    std::lock_guard<std::mutex> guard(the_mutex);
    the_queue.emplace_front(std::move(data));
}
// Reports whether the deque holds no tasks at the moment of the call.
bool work_stealing_queue::empty() const
{
    std::lock_guard<std::mutex> guard(the_mutex);
    const bool no_tasks = the_queue.empty();
    return no_tasks;
}
// Takes the task at the front of the deque (the end that push() fills).
// Returns false and leaves `res` untouched when the deque is empty.
bool work_stealing_queue::try_pop(data_type& res)
{
    std::lock_guard<std::mutex> guard(the_mutex);
    const bool has_task = !the_queue.empty();
    if (has_task)
    {
        res = std::move(the_queue.front());
        the_queue.pop_front();
    }
    return has_task;
}
// Takes the task at the back of the deque, i.e. the end opposite to the one
// used by push() and try_pop().
// Returns false and leaves `res` untouched when the deque is empty.
bool work_stealing_queue::try_steal(data_type& res)
{
    std::lock_guard<std::mutex> guard(the_mutex);
    const bool has_task = !the_queue.empty();
    if (has_task)
    {
        res = std::move(the_queue.back());
        the_queue.pop_back();
    }
    return has_task;
}
}
5、线程池 |
最终目标线程池【thread_pool.h】:
#pragma once
#include <iostream>
#include <atomic>
#include <thread>
#include <vector>
#include <future>
#include "function_wrapper.hpp"
#include "join_threads.hpp"
#ifdef LOCK_FREE
#include "lock_free_queue.hpp"
#else
#include "thread_safe_queue.hpp"
#endif
#include "work_stealing_queue.h"
namespace tp
{
/// Thread pool with one shared queue plus one work-stealing queue per worker.
///
/// Tasks submitted from a thread that has a local queue go to that queue;
/// all other submissions go to the shared pool queue.
class thread_pool
{
    typedef function_wrapper task_type;

public:
    thread_pool();
    ~thread_pool();

    /// Runs one pending task on the calling thread if any can be found,
    /// otherwise yields.
    void run_pending_task();

    /// Queues `f` for execution.
    /// @param f callable invocable with no arguments; it is moved into the
    ///          task, so move-only callables are accepted.
    /// @return a future for the value returned by `f()`.
    template<typename fun_type>
    std::future<typename std::result_of<fun_type()>::type> submit(fun_type f)
    {
        typedef typename std::result_of<fun_type()>::type result_type;
        // `f` is already our own copy: move it instead of copying it again.
        std::packaged_task<result_type()> task(std::move(f));
        std::future<result_type> res(task.get_future());
        if (local_work_queue)
        {
            local_work_queue->push(task_type(std::move(task)));
        }
        else
        {
            pool_work_queue.push(task_type(std::move(task)));
        }
        return res;
    }

private:
    void worker_thread(unsigned my_index_);
    bool pop_task_from_local_queue(task_type& task);
    bool pop_task_from_pool_queue(task_type& task);
    bool pop_task_from_other_thread_queue(task_type& task);

private:
    // Per-thread state: the calling thread's own queue (null if it has none)
    // and its index into `queues`.
    static thread_local work_stealing_queue* local_work_queue;
    static thread_local unsigned my_index;

    std::atomic_bool done;
#ifdef LOCK_FREE
    lock_free_queue<task_type> pool_work_queue;
#else
    thread_safe_queue<task_type> pool_work_queue;
#endif
    std::vector<std::unique_ptr<work_stealing_queue> > queues;
    std::vector<std::thread> threads;
    // Declared last so it is destroyed first: the threads are joined before
    // `threads`, `queues` and the pool queue are torn down.
    join_threads joiner;
};
}
【thread_pool.cpp】:
#include "thread_pool.h"
namespace tp
{
thread_local work_stealing_queue* tp::thread_pool::local_work_queue = nullptr;
thread_local unsigned tp::thread_pool::my_index = 0;
// Creates one work-stealing queue and one worker thread per hardware thread.
//
// If queue or thread creation throws, `done` is set so that workers already
// started leave their loop (the joiner member then joins them during
// unwinding) and the exception is rethrown.
thread_pool::thread_pool()
    : done(false)
    , joiner(threads)
{
    unsigned thread_count = std::thread::hardware_concurrency();
    if (thread_count == 0)
    {
        // hardware_concurrency() may return 0 when the value is unknown;
        // a pool with no workers would never run any submitted task.
        thread_count = 2;
    }
    try
    {
        // Build every queue before starting any worker: workers read `queues`
        // as soon as they run (their own slot and, when stealing, all slots),
        // so growing the vector while they run would be a data race.
        queues.reserve(thread_count);
        for (unsigned i = 0; i < thread_count; ++i)
        {
            queues.push_back(std::unique_ptr<work_stealing_queue>(new work_stealing_queue));
        }
        threads.reserve(thread_count);
        for (unsigned i = 0; i < thread_count; ++i)
        {
            threads.push_back(std::thread(&thread_pool::worker_thread, this, i));
        }
    }
    catch (...)
    {
        done = true;
        throw;
    }
}
// Signals the workers to leave their loop; the joiner member then joins
// them when the members are destroyed.
thread_pool::~thread_pool()
{
    done.store(true);
}
void thread_pool::run_pending_task()
{
task_type task;
if (pop_task_from_local_queue(task) || // 7
pop_task_from_pool_queue(task) || // 8
pop_task_from_other_thread_queue(task)) // 9
{
task();
}
else
{
std::this_thread::yield();
}
}
// Entry point of each worker: records this thread's index and local queue,
// then keeps running pending tasks until `done` is set.
void thread_pool::worker_thread(unsigned my_index_)
{
    my_index = my_index_;
    local_work_queue = queues[my_index_].get();
    for (;;)
    {
        if (done.load())
        {
            break;
        }
        run_pending_task();
    }
}
// Tries the calling thread's own queue; threads without one always get false.
bool thread_pool::pop_task_from_local_queue(task_type& task)
{
    if (local_work_queue == nullptr)
    {
        return false;
    }
    return local_work_queue->try_pop(task);
}
// Tries the shared pool queue without blocking.
bool thread_pool::pop_task_from_pool_queue(task_type& task)
{
    const bool popped = pool_work_queue.try_pop(task);
    return popped;
}
// Tries to steal a task, visiting the queues in ring order starting at the
// one after this thread's index. The last queue visited is the one at
// `my_index` itself.
bool thread_pool::pop_task_from_other_thread_queue(task_type& task)
{
    const auto queue_count = queues.size();
    for (unsigned offset = 1; offset <= queue_count; ++offset)
    {
        unsigned const victim = (my_index + offset) % queue_count;
        if (queues[victim]->try_steal(task))
        {
            return true;
        }
    }
    return false;
}
}
6、测试代码 |
#include <iostream>
#include <vector>
#include <future>
#include <iterator>
#include <numeric>
#include "thread_pool.h"
#include "function_wrapper.hpp"
#include "thread_safe_queue.hpp"
/// Function object that sums the range [first, last) into a value of type T,
/// starting from a value-initialised T.
template<typename Iterator, typename T>
struct accumulate_block
{
    T operator()(Iterator first, Iterator last)
    {
        T const start_value = T();
        return std::accumulate(first, last, start_value);
    }
};
/// Sums [first, last) on a thread pool, in blocks of a fixed size.
///
/// All blocks except the last are submitted to the pool; the last block is
/// summed on the calling thread while the others run.
///
/// @tparam T accumulator type; every partial sum and the result use it.
/// @param init value added to the sum of the range.
/// @return init plus the sum of all elements.
template<typename Iterator, typename T>
T parallel_accumulate(Iterator first, Iterator last, T init)
{
    unsigned long const length = std::distance(first, last);
    if (!length)
        return init;
    // Blocks that are too small create too many tasks; scheduling and
    // contention then dominate the run time.
    unsigned long const block_size = 120000;
    unsigned long const num_blocks = (length + block_size - 1) / block_size;
    std::vector<std::future<T> > futures(num_blocks - 1);
    tp::thread_pool pool;
    Iterator block_start = first;
    for (unsigned long i = 0; i < (num_blocks - 1); ++i)
    {
        Iterator block_end = block_start;
        std::advance(block_end, block_size);
        futures[i] = pool.submit([=] { return accumulate_block<Iterator, T>()(block_start, block_end); });
        block_start = block_end;
    }
    // Partial sums stay in T; holding them in int would truncate wider types.
    T last_result = accumulate_block<Iterator, T>()(block_start, last);
    T result = init;
    for (unsigned long i = 0; i < (num_blocks - 1); ++i)
    {
        result += futures[i].get();
    }
    result += last_result;
    return result;
}

int main()
{
    std::vector<int> data(500000);
    std::iota(std::begin(data), std::end(data), 0);
    // Alternative baselines for manual timing comparison:
    //auto start = std::chrono::steady_clock::now();
    //long long result = 0;
    //for (auto& item : data)
    // result += item;
    //auto end = std::chrono::steady_clock::now();
    //auto duration = std::chrono::duration<double>(end - start).count();
    //auto start = std::chrono::steady_clock::now();
    //auto result = std::accumulate(std::begin(data), std::end(data), 0LL);
    //auto end = std::chrono::steady_clock::now();
    //auto duration = std::chrono::duration<double>(end - start).count();
    auto start = std::chrono::steady_clock::now();
    // The sum of 0..499999 is 124999750000, which does not fit in a 32-bit
    // int, so the accumulator type is long long.
    auto result = parallel_accumulate<std::vector<int>::iterator, long long>(std::begin(data), std::end(data), 0LL);
    auto end = std::chrono::steady_clock::now();
    auto duration = std::chrono::duration<double>(end - start).count();
    static_cast<void>(duration); // measured but not printed
    std::cout << result << '\n';
    return 0;
}