#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <functional>
#include <future>
#include <iostream>
#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
// RAII guard: when destroyed, joins every joinable thread in the referenced
// vector. Only a reference is stored, so the vector must outlive the guard.
class join_threads {
    std::vector<std::thread>& pool_threads;

public:
    explicit join_threads(std::vector<std::thread>& threads_)
        : pool_threads(threads_) {}

    ~join_threads() {
        for (std::size_t i = 0; i != pool_threads.size(); ++i) {
            std::thread& worker = pool_threads[i];
            if (!worker.joinable()) {
                continue;  // default-constructed or already joined/detached
            }
            worker.join();
        }
    }
};
// FIFO queue protected by a mutex and condition variable. Elements are stored
// as shared_ptr<T>, so the pointer-returning pops hand out an already-allocated
// pointer instead of copying T while the element is being removed.
template<class T>
class threadsafe_queue {
private:
    mutable std::mutex m;                       // guards data_queue; mutable so empty() can be const
    std::queue<std::shared_ptr<T>> data_queue;  // elements held by pointer
    std::condition_variable data_cond;          // notified once per push
public:
    threadsafe_queue() {}
    // Copies other's contents while holding other's lock. The element pointers
    // are shared with `other`, not deep-copied.
    threadsafe_queue(const threadsafe_queue& other) {
        std::lock_guard<std::mutex> lock(other.m);
        data_queue = other.data_queue;
    }
    threadsafe_queue& operator=(const threadsafe_queue&) = delete;

    // Appends new_value and wakes one thread blocked in wait_and_pop().
    void push(T new_value) {
        // Allocate outside the lock to keep the critical section short, but
        // touch data_queue only while holding m: other threads read and pop it
        // concurrently under the same lock.
        std::shared_ptr<T> data(std::make_shared<T>(std::move(new_value)));
        std::lock_guard<std::mutex> lock(m);
        data_queue.push(std::move(data));
        data_cond.notify_one();
    }
    // Blocks until an element is available, then moves it into value.
    void wait_and_pop(T& value) {
        std::unique_lock<std::mutex> lock(m);
        data_cond.wait(lock, [this] { return !data_queue.empty(); });
        value = std::move(*data_queue.front());
        data_queue.pop();
    }
    // Blocks until an element is available and returns it.
    std::shared_ptr<T> wait_and_pop() {
        std::unique_lock<std::mutex> lock(m);
        data_cond.wait(lock, [this] { return !data_queue.empty(); });
        std::shared_ptr<T> res = data_queue.front();
        data_queue.pop();
        return res;
    }
    // Non-blocking: moves the front element into value. Returns false if empty.
    bool try_pop(T& value) {
        std::lock_guard<std::mutex> lock(m);
        if (data_queue.empty()) return false;
        value = std::move(*data_queue.front());
        data_queue.pop();
        return true;
    }
    // Non-blocking: returns the front element, or nullptr if empty.
    std::shared_ptr<T> try_pop() {
        std::lock_guard<std::mutex> lock(m);
        if (data_queue.empty()) return nullptr;
        std::shared_ptr<T> res = data_queue.front();
        data_queue.pop();
        return res;
    }
    // Snapshot only: the answer may be stale as soon as the lock is released.
    bool empty() const {
        std::lock_guard<std::mutex> lock(m);
        return data_queue.empty();
    }
};
// Move-only, type-erased holder for a nullary callable. Unlike std::function
// it accepts move-only callables such as std::packaged_task.
class function_wrapper {
private:
    // Type-erased call interface.
    struct impl_base {
        virtual void call() = 0;
        virtual ~impl_base() {}
    };
    // Concrete holder that owns a callable of (decayed, non-reference) type F.
    template<typename F>
    struct impl_type : impl_base {
        F f;
        explicit impl_type(F f_) : f(std::move(f_)) {}
        void call() override { f(); }
    };
    std::unique_ptr<impl_base> impl;

public:
    // Wraps any nullary callable. It is always stored by value: rvalues are
    // moved in and lvalues are copied. Decaying F keeps a reference type out of
    // impl_type. Intentionally implicit so callables convert to task objects.
    template<typename F>
    function_wrapper(F&& f)
        : impl(new impl_type<typename std::decay<F>::type>(std::forward<F>(f))) {}
    // Invokes the stored callable. Precondition: *this is non-empty
    // (a default-constructed or moved-from wrapper holds a null impl).
    void operator()() { impl->call(); }
    function_wrapper() = default;
    function_wrapper(function_wrapper&& other) noexcept
        : impl(std::move(other.impl)) {}
    function_wrapper& operator=(function_wrapper&& other) noexcept {
        impl = std::move(other.impl);
        return *this;
    }
    // Copying is disabled. The non-const lvalue overloads are required too:
    // without them, the forwarding constructor above would be a better match
    // for a non-const function_wrapper lvalue than the const& overload.
    function_wrapper(function_wrapper const& other) = delete;
    function_wrapper(function_wrapper& other) = delete;
    function_wrapper& operator=(function_wrapper const& other) = delete;
    function_wrapper& operator=(function_wrapper& other) = delete;
};
class work_stealing_queue {
private:
typedef function_wrapper data_type;
std::deque<data_type> the_queue;
mutable std::mutex the_mutex;
public:
work_stealing_queue() = default;
work_stealing_queue(work_stealing_queue const& other) = delete;
work_stealing_queue& operator=(work_stealing_queue const& other) = delete;
void push(data_type data) {
std::lock_guard<std::mutex> lock(the_mutex);
the_queue.push_front(std::move(data));
}
bool try_pop(data_type& res) {
std::lock_guard<std::mutex> lock(the_mutex);
if (the_queue.empty()) { return false; }
else {
res = std::move(the_queue.front());
the_queue.pop_front();
return true;
}
}
bool try_steal(data_type& res) {
std::lock_guard<std::mutex>lock(the_mutex);
if (the_queue.empty()) { return false; }
else {
res = std::move(the_queue.back());
the_queue.pop_back();
return true;
}
}
bool empty() {
std::lock_guard<std::mutex>lock(the_mutex);
return the_queue.empty();
}
};
class thread_pool {
private:
typedef function_wrapper task_type;
threadsafe_queue<task_type>pool_work_queue;
std::vector<std::unique_ptr<work_stealing_queue>>queues;
static thread_local work_stealing_queue* local_work_queue;
static thread_local unsigned my_index;
std::atomic<bool>done;
std::vector<std::thread> threads;
join_threads joiner;
bool pop_task_from_local_queue(task_type& task) {
return local_work_queue && local_work_queue->try_pop(task);
}
bool pop_task_from_pool_queue(task_type& task) {
return pool_work_queue.try_pop(task);
}
bool pop_task_from_other_thread_queue(task_type& task) {
for (unsigned i = 0; i < queues.size(); ++i) {
unsigned const index = (my_index + i + 1) % queues.size();
if (queues[index]->try_steal(task)) { return true; }
}
return false;
}
public:
void run_pending_task() {
task_type task;
if (pop_task_from_local_queue(task) || pop_task_from_pool_queue(task)
|| pop_task_from_other_thread_queue(task)) {
task();
}
else { std::this_thread::yield(); }
}
void worker_thread(unsigned const my_index_) {
my_index = my_index_;
local_work_queue = queues[my_index].get();
while (!done) { run_pending_task(); }
}
template<typename FunctionType>
std::future<typename std::result_of<FunctionType()>::type>submit(FunctionType f) {
typedef typename std::result_of<FunctionType()>::type result_type;
std::packaged_task<result_type()>task(std::move(f));
std::future<result_type>res(task.get_future());
if (local_work_queue) {
local_work_queue->push(std::move(task));
}
else {
pool_work_queue.push(std::move(task));
}
return res;
}
thread_pool() :
done(false), joiner(threads) {
unsigned const thread_count = std::thread::hardware_concurrency();
try {
for (unsigned i = 0; i < thread_count; ++i) {
queues.push_back(std::unique_ptr<work_stealing_queue>(new work_stealing_queue));
}
for (unsigned i = 0; i < thread_count; ++i) {
threads.push_back(std::thread(&thread_pool::worker_thread, this, i));
}
}
catch (...) {
done = true;
throw;
}
}
~thread_pool() { done.store(true); }
};
// Out-of-class definitions of thread_pool's per-thread statics. The explicit
// initializers match what zero-initialization already gave them: a thread that
// is not a pool worker has no local queue and index 0.
thread_local work_stealing_queue* thread_pool::local_work_queue = nullptr;
thread_local unsigned thread_pool::my_index = 0u;
template<typename T>
struct sorter {
    // Pool that runs the recursively submitted lower-partition sorts.
    thread_pool pool;

    // Quicksorts chunk_data in ascending order (operator<) and returns the
    // result. chunk_data is consumed: its nodes are spliced away, leaving it
    // empty. The lower partition is submitted to `pool`; the calling thread
    // sorts the upper partition itself, then runs pool tasks until the lower
    // half is ready.
    std::list<T> do_sort(std::list<T>& chunk_data) {
        if (chunk_data.empty()) {
            return chunk_data;
        }

        // The first node becomes the pivot. It is spliced into the output list,
        // which keeps it alive for the reference below.
        std::list<T> sorted;
        sorted.splice(sorted.begin(), chunk_data, chunk_data.begin());
        T const& pivot = sorted.front();

        // Elements strictly less than the pivot end up before `boundary`.
        auto const boundary = std::partition(
            chunk_data.begin(), chunk_data.end(),
            [&pivot](T const& candidate) { return candidate < pivot; });

        std::list<T> lower_part;
        lower_part.splice(lower_part.end(), chunk_data, chunk_data.begin(), boundary);

        // The bind object stores its own (moved) copy of lower_part and passes
        // it to do_sort by reference when the task runs.
        std::future<std::list<T> > lower_sorted =
            pool.submit(std::bind(&sorter::do_sort, this, std::move(lower_part)));

        // Meanwhile this thread sorts what remains (elements not less than the pivot).
        std::list<T> upper_sorted(do_sort(chunk_data));
        sorted.splice(sorted.end(), upper_sorted);

        // Help run pending tasks instead of blocking until the lower half is done.
        while (lower_sorted.wait_for(std::chrono::seconds(0)) == std::future_status::timeout) {
            pool.run_pending_task();
        }
        sorted.splice(sorted.begin(), lower_sorted.get());
        return sorted;
    }
};
template<typename T>
std::list<T> parallel_quick_sort(std::list<T> input) {
if (input.empty()) { return input; }
sorter<T> s;
return s.do_sort(input);
}
// Demo: builds the list 10 9 ... 1, prints it, sorts it in parallel and
// prints the result.
int main() {
    std::list<int> values;
    for (int n = 1; n <= 10; ++n) {
        values.push_front(n);  // prepending yields descending order
    }

    // Prints each element followed by a space, then ends the line.
    auto print_all = [](std::list<int> const& xs) {
        for (int v : xs) {
            std::cout << v << " ";
        }
        std::cout << std::endl;
    };

    std::cout << "initial list:\n";
    print_all(values);
    std::list<int> const sorted = parallel_quick_sort(values);
    std::cout << "sorted list:\n";
    print_all(sorted);
    return 0;
}
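// 【线程池】使用工作窃取的线程池:实现并行快速排序
// [Thread pool] A work-stealing thread pool: implementing parallel quicksort.
// 最新推荐文章于 2024-11-05 17:16:24 发布
// (Source-page residue: "latest recommended article published 2024-11-05 17:16:24".)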