【线程池】使用工作窃取的线程池:实现并行快速排序

#include<algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <deque>
#include<functional>
#include<future>
#include <iostream>
#include<list>
#include <memory>
#include<mutex>
#include<queue>
#include<thread>
#include <type_traits>
#include <utility>
#include<vector>
// Scope guard over a vector of threads: when the guard is destroyed it joins
// every thread in the referenced vector that is still joinable.
// The guard only holds a reference, so the vector must outlive the guard.
class join_threads {
public:
	explicit join_threads(std::vector<std::thread>& threads_)
		: threads(threads_) {}

	// Joins each joinable thread; threads that are not joinable are skipped.
	~join_threads() {
		for (auto it = threads.begin(); it != threads.end(); ++it) {
			if (!it->joinable()) {
				continue;
			}
			it->join();
		}
	}

private:
	std::vector<std::thread>& threads; // not owned
};

// Mutex-protected FIFO queue with blocking and non-blocking pops.
// Elements are stored as std::shared_ptr<T>; the pointer-returning pops hand
// out that pointer, while the reference-taking pops move the element into the
// caller's variable.
template<class T>
class threadsafe_queue {
private:
	mutable std::mutex m;                       // guards data_queue
	std::queue<std::shared_ptr<T>> data_queue;  // elements held through shared_ptr
	std::condition_variable data_cond;          // notified once per push
public:
	threadsafe_queue() {}

	// Copies the pointers held by `other` while holding other's mutex; the
	// elements themselves end up shared between the two queues.
	threadsafe_queue(const threadsafe_queue& other) {
		std::lock_guard<std::mutex> lock(other.m);
		data_queue = other.data_queue;
	}
	threadsafe_queue& operator=(const threadsafe_queue&) = delete;

	// Appends new_value and wakes one waiting consumer.
	void push(T new_value) {
		// Allocate outside the lock so the critical section stays short.
		std::shared_ptr<T> data(std::make_shared<T>(std::move(new_value)));
		// The queue itself must only be modified while the mutex is held.
		std::lock_guard<std::mutex> lock(m);
		data_queue.push(std::move(data));
		data_cond.notify_one();
	}

	// Blocks until an element is available, then moves it into `value`.
	void wait_and_pop(T& value) {
		std::unique_lock<std::mutex> lock(m);
		data_cond.wait(lock, [this] { return !data_queue.empty(); });
		value = std::move(*data_queue.front());
		data_queue.pop();
	}

	// Blocks until an element is available and returns the pointer to it.
	std::shared_ptr<T> wait_and_pop() {
		std::unique_lock<std::mutex> lock(m);
		data_cond.wait(lock, [this] { return !data_queue.empty(); });
		std::shared_ptr<T> res = data_queue.front();
		data_queue.pop();
		return res;
	}

	// Non-blocking pop: returns false (leaving `value` untouched) when empty.
	bool try_pop(T& value) {
		std::lock_guard<std::mutex> lock(m);
		if (data_queue.empty()) return false;
		value = std::move(*data_queue.front());
		data_queue.pop();
		return true;
	}

	// Non-blocking pop: returns a null pointer when the queue is empty.
	std::shared_ptr<T> try_pop() {
		std::lock_guard<std::mutex> lock(m);
		if (data_queue.empty()) return nullptr;
		std::shared_ptr<T> res = data_queue.front();
		data_queue.pop();
		return res;
	}

	// Snapshot of emptiness; it may be stale as soon as the lock is released.
	bool empty() const {
		std::lock_guard<std::mutex> lock(m);
		return data_queue.empty();
	}
};
// Move-only, type-erased wrapper around any callable invocable as `f()`.
// Unlike std::function it does not require the callable to be copyable, so it
// can hold move-only callables such as std::packaged_task.
class function_wrapper {
private:
	// Type-erased interface through which the stored callable is invoked.
	struct impl_base {
		virtual void call() = 0;
		virtual ~impl_base() {}
	};

	std::unique_ptr<impl_base> impl;

	// Concrete holder; F is always a decayed (non-reference) type, so the
	// callable is owned by value.
	template<typename F>
	struct impl_type : impl_base {
		F f;

		template<typename U>
		explicit impl_type(U&& f_) : f(std::forward<U>(f_)) {}

		void call() override { f(); }
	};

public:
	// Wraps `f`: rvalues are moved in, lvalues are copied in. The stored type
	// is the decayed type of F, so no reference to the caller's object is
	// kept. The constructor is disabled for function_wrapper itself so that
	// it never competes with the move constructor or the deleted copies.
	template<typename F,
		typename = typename std::enable_if<
			!std::is_same<typename std::decay<F>::type, function_wrapper>::value>::type>
	function_wrapper(F&& f)
		: impl(new impl_type<typename std::decay<F>::type>(std::forward<F>(f))) {}

	// Invokes the stored callable.
	// Throws std::bad_function_call if the wrapper is empty (default
	// constructed or moved from) instead of dereferencing a null pointer.
	void operator()() {
		if (!impl) {
			throw std::bad_function_call();
		}
		impl->call();
	}

	function_wrapper() = default;
	function_wrapper(function_wrapper&& other) noexcept
		: impl(std::move(other.impl)) {}
	function_wrapper& operator=(function_wrapper&& other) noexcept {
		impl = std::move(other.impl);
		return *this;
	}

	function_wrapper(function_wrapper const& other) = delete;
	function_wrapper(function_wrapper& other) = delete;

	function_wrapper& operator=(function_wrapper const& other) = delete;
	function_wrapper& operator=(function_wrapper& other) = delete;

};

class work_stealing_queue {
private:
	typedef function_wrapper data_type;

	std::deque<data_type> the_queue;
	mutable std::mutex the_mutex;

public:
	work_stealing_queue() = default;
	work_stealing_queue(work_stealing_queue const& other) = delete;
	work_stealing_queue& operator=(work_stealing_queue const& other) = delete;

	void push(data_type data) {
		std::lock_guard<std::mutex> lock(the_mutex);
		the_queue.push_front(std::move(data));
	}

	bool try_pop(data_type& res) {
		std::lock_guard<std::mutex> lock(the_mutex);

		if (the_queue.empty()) { return false; }
		else {
			res = std::move(the_queue.front());
			the_queue.pop_front();

			return true;
		}
	}

	bool try_steal(data_type& res) {
		std::lock_guard<std::mutex>lock(the_mutex);

		if (the_queue.empty()) { return false; }
		else {
			res = std::move(the_queue.back());
			the_queue.pop_back();

			return true;
		}
	}

	bool empty() {
		std::lock_guard<std::mutex>lock(the_mutex);
		return the_queue.empty();
	}

};

class thread_pool {
private:
	typedef function_wrapper task_type;

	threadsafe_queue<task_type>pool_work_queue;
	std::vector<std::unique_ptr<work_stealing_queue>>queues;

	static thread_local work_stealing_queue* local_work_queue;
	static thread_local unsigned my_index;

	std::atomic<bool>done;
	std::vector<std::thread> threads;
	join_threads joiner;

	bool pop_task_from_local_queue(task_type& task) {
		return local_work_queue && local_work_queue->try_pop(task);
	}

	bool pop_task_from_pool_queue(task_type& task) {
		return pool_work_queue.try_pop(task);
	}

	bool pop_task_from_other_thread_queue(task_type& task) {
		for (unsigned i = 0; i < queues.size(); ++i) {
			unsigned const index = (my_index + i + 1) % queues.size();
			if (queues[index]->try_steal(task)) { return true; }
		}
		return false;
	}


public:

	void run_pending_task() {
		task_type task;
		if (pop_task_from_local_queue(task) || pop_task_from_pool_queue(task)
			|| pop_task_from_other_thread_queue(task)) {
			task();
		}
		else { std::this_thread::yield(); }
	}

	void worker_thread(unsigned const my_index_) {
		my_index = my_index_;
		local_work_queue = queues[my_index].get();

		while (!done) { run_pending_task(); }
	}


	template<typename FunctionType>
	std::future<typename std::result_of<FunctionType()>::type>submit(FunctionType f) {

		typedef typename std::result_of<FunctionType()>::type result_type;

		std::packaged_task<result_type()>task(std::move(f));
		std::future<result_type>res(task.get_future());

		if (local_work_queue) {
			local_work_queue->push(std::move(task));
		}
		else {
			pool_work_queue.push(std::move(task));
		}
		return res;
	}

	thread_pool() :
		done(false), joiner(threads) {
		unsigned const thread_count = std::thread::hardware_concurrency();
		try {
			for (unsigned i = 0; i < thread_count; ++i) {
				queues.push_back(std::unique_ptr<work_stealing_queue>(new work_stealing_queue));
			}
			for (unsigned i = 0; i < thread_count; ++i) {
				threads.push_back(std::thread(&thread_pool::worker_thread, this, i));
			}
		}
		catch (...) {
			done = true;
			throw;
		}
	}

	~thread_pool() { done.store(true); }
};

thread_local work_stealing_queue* thread_pool::local_work_queue;
thread_local unsigned thread_pool::my_index;

// Quick sort over std::list<T> that hands the lower partition to a thread pool
// and sorts the upper partition on the calling thread.
template<typename T>
struct sorter {

	thread_pool pool;

	// Sorts chunk_data and returns the sorted elements as a new list.
	// chunk_data is consumed: its elements are spliced out during the sort.
	std::list<T> do_sort(std::list<T>& chunk_data) {

		if (chunk_data.empty()) {
			return chunk_data;
		}

		// The first element becomes the pivot and seeds the output list.
		std::list<T> sorted;
		sorted.splice(sorted.begin(), chunk_data, chunk_data.begin());
		T const& pivot = sorted.front();

		// Elements smaller than the pivot are moved in front of `split`.
		typename std::list<T>::iterator split = std::partition(
			chunk_data.begin(), chunk_data.end(),
			[&](T const& candidate) { return candidate < pivot; });

		std::list<T> lower_part;
		lower_part.splice(lower_part.end(), chunk_data, chunk_data.begin(), split);

		// Lower half is submitted to the pool; upper half is sorted right here.
		std::future<std::list<T> > lower_sorted =
			pool.submit(std::bind(&sorter::do_sort, this, std::move(lower_part)));

		std::list<T> upper_sorted(do_sort(chunk_data));
		sorted.splice(sorted.end(), upper_sorted);

		// Rather than blocking on the future, keep running queued tasks until
		// the lower half has been sorted.
		for (;;) {
			if (lower_sorted.wait_for(std::chrono::seconds(0)) != std::future_status::timeout) {
				break;
			}
			pool.run_pending_task();
		}

		sorted.splice(sorted.begin(), lower_sorted.get());
		return sorted;
	}
};

template<typename T>
std::list<T> parallel_quick_sort(std::list<T> input) {
	if (input.empty()) { return input; }

	sorter<T> s;
	return s.do_sort(input);
}

// Demo: builds the list 10, 9, ..., 1, prints it, sorts it with
// parallel_quick_sort and prints the result.
int main() {
	std::list<int> numbers;
	for (int value = 10; value >= 1; --value) {
		numbers.push_back(value);
	}

	// Prints the heading, then every item followed by a space, then a newline.
	auto print_all = [](const char* heading, const std::list<int>& items) {
		std::cout << heading;
		for (auto it = items.begin(); it != items.end(); ++it) {
			std::cout << *it << " ";
		}
		std::cout << std::endl;
	};

	print_all("initial list:\n", numbers);

	auto sorted = parallel_quick_sort(numbers);

	print_all("sorted list:\n", sorted);

	return 0;
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值