【C++ 并发:第二版】:线程池——阻塞队列

1、包装任务函数

自定义一个带有函数调用运算符 operator() 的类型擦除类来包装任务函数。之所以不直接使用 std::function,是因为它要求可调用对象可复制,而线程池中提交的 std::packaged_task 只能移动【function_wrapper.hpp】:

#pragma once
#include <memory>
#include <type_traits>
#include <utility>

namespace tp
{
    /// Move-only, type-erased wrapper around a nullary callable.
    ///
    /// std::function requires its target to be copyable, so it cannot hold a
    /// move-only std::packaged_task. This wrapper only requires the target to
    /// be movable (rvalue argument) or copyable (lvalue argument).
    class function_wrapper
    {
    public:
        /// Creates an empty wrapper. Calling operator() on it dereferences a
        /// null pointer, so only invoke wrappers that hold a callable.
        function_wrapper() = default;

        function_wrapper(function_wrapper&& other) noexcept
            : impl(std::move(other.impl))
        {}

        /// Wraps any nullary callable.
        ///
        /// The stored type is the decayed type of F: rvalues are moved in and
        /// lvalues are copied in. (Storing impl_type<F> directly would make
        /// F a reference type for lvalue arguments, which failed to compile.)
        template<typename F>
        explicit function_wrapper(F&& f)
            : impl(new impl_type<typename std::decay<F>::type>(std::forward<F>(f)))
        {}

        /// Invokes the wrapped callable. Precondition: the wrapper is non-empty.
        void operator()() { impl->call(); }

        function_wrapper& operator=(function_wrapper&& other) noexcept
        {
            impl = std::move(other.impl);
            return *this;
        }

        function_wrapper(const function_wrapper&) = delete;
        // The non-const lvalue overload is deleted explicitly so that the
        // forwarding constructor above is never selected to "copy" a wrapper.
        function_wrapper(function_wrapper&) = delete;
        function_wrapper& operator=(const function_wrapper&) = delete;

    private:
        // Type-erased interface to the stored callable.
        struct impl_base
        {
            virtual void call() = 0;
            virtual ~impl_base() {}
        };

        // Holds a callable of concrete (non-reference) type F.
        template<typename F>
        struct impl_type : impl_base
        {
            F f;
            template<typename G>
            explicit impl_type(G&& g) : f(std::forward<G>(g)) {}
            void call() override { f(); }
        };

        std::unique_ptr<impl_base> impl;
    };
}
2、RAII管理线程汇入

借助 RAII,在离开作用域时自动汇入(join)所有仍可汇入的线程【join_threads.hpp】:

#pragma once
#include <vector>
#include <thread>

namespace tp
{
    /// Scope guard that joins every still-joinable thread in a vector when it
    /// is destroyed.
    ///
    /// Only a reference to the vector is kept, so the vector must outlive this
    /// guard; threads appended to it after construction are joined as well.
    class join_threads
    {
    public:
        explicit join_threads(std::vector<std::thread>& threads_)
            : threads(threads_)
        {
        }

        ~join_threads()
        {
            for (auto it = threads.begin(); it != threads.end(); ++it)
            {
                // Skip default-constructed or already joined/detached threads.
                if (!it->joinable())
                    continue;
                it->join();
            }
        }

    private:
        // Non-owning reference to the caller's thread list.
        std::vector<std::thread>& threads;
    };
}
3、线程安全任务队列

使用头、尾两个互斥量细粒度加锁、并支持阻塞等待的线程安全队列【thread_safe_queue.hpp】:

#pragma once
#include <condition_variable>
#include <memory>
#include <mutex>
#include <utility>

namespace tp
{
    /// Unbounded multi-producer/multi-consumer queue.
    ///
    /// Implemented as a singly linked list that always ends in a dummy node.
    /// head_mutex guards `head` (consumer side) and tail_mutex guards `tail`
    /// (producer side), so one push and one pop can run concurrently.
    /// Lock order, where both are taken: head_mutex, then tail_mutex.
    template<typename T>
    class thread_safe_queue
    {
    public:
        thread_safe_queue() :
            head(new node), tail(head.get())
        {}
        thread_safe_queue(const thread_safe_queue& other) = delete;
        thread_safe_queue& operator=(const thread_safe_queue& other) = delete;

        /// Non-blocking pop. Returns the front element, or nullptr if empty.
        std::shared_ptr<T> try_pop()
        {
            std::unique_ptr<node> old_head = try_pop_head();
            return old_head != nullptr ? old_head->data : nullptr;
        }

        /// Non-blocking pop that moves the front element into `value`.
        /// Returns false and leaves `value` untouched if the queue is empty.
        bool try_pop(T& value)
        {
            std::unique_ptr<node> old_head = try_pop_head(value);
            return old_head != nullptr;
        }

        /// Blocks until an element is available, then pops and returns it.
        std::shared_ptr<T> wait_and_pop()
        {
            std::unique_ptr<node> const old_head = wait_pop_head();
            return old_head->data;
        }

        /// Blocks until an element is available, then moves it into `value`.
        /// The returned pointer refers to the moved-from element. It is kept
        /// only for source compatibility; callers should read `value`.
        std::shared_ptr<T> wait_and_pop(T& value)
        {
            std::unique_ptr<node> const old_head = wait_pop_head(value);
            return old_head->data;
        }

        /// Appends `new_value` and wakes one waiting consumer.
        /// (Previously a member template whose parameter shadowed the class
        /// parameter T, which is ill-formed.)
        void push(T new_value)
        {
            std::shared_ptr<T> new_data(std::make_shared<T>(std::move(new_value)));
            std::unique_ptr<node> p(new node);
            {
                std::lock_guard<std::mutex> tail_lock(tail_mutex);
                tail->data = std::move(new_data);
                node* const new_tail = p.get();
                tail->next = std::move(p);
                tail = new_tail;
            }
            // Waiters evaluate their predicate while holding head_mutex, not
            // tail_mutex. Briefly taking head_mutex here ensures that a waiter
            // which has just seen an empty queue has already blocked in wait()
            // before we notify, so the wake-up cannot be lost.
            {
                std::lock_guard<std::mutex> head_lock(head_mutex);
            }
            data_cond.notify_one();
        }

        /// Snapshot emptiness check; the result may be stale immediately.
        bool empty()
        {
            std::lock_guard<std::mutex> head_lock(head_mutex);
            return (head.get() == get_tail());
        }

    private:
        struct node
        {
            std::shared_ptr<T> data;
            std::unique_ptr<node> next;
        };

        std::mutex head_mutex;
        std::unique_ptr<node> head;
        std::mutex tail_mutex;
        node* tail;  // always points at the dummy node
        std::condition_variable data_cond;

    private:
        node* get_tail()
        {
            std::lock_guard<std::mutex> tail_lock(tail_mutex);
            return tail;
        }

        // Unlinks and returns the head node. Caller must hold head_mutex and
        // must have checked that the queue is non-empty.
        std::unique_ptr<node> pop_head()
        {
            std::unique_ptr<node> old_head = std::move(head);
            head = std::move(old_head->next);
            return old_head;
        }

        // Waits until the queue is non-empty; returns with head_mutex held.
        std::unique_lock<std::mutex> wait_for_data()
        {
            std::unique_lock<std::mutex> head_lock(head_mutex);
            data_cond.wait(head_lock, [&] { return head.get() != get_tail(); });
            return head_lock;
        }

        std::unique_ptr<node> wait_pop_head()
        {
            std::unique_lock<std::mutex> head_lock(wait_for_data());
            return pop_head();
        }

        std::unique_ptr<node> wait_pop_head(T& value)
        {
            std::unique_lock<std::mutex> head_lock(wait_for_data());
            value = std::move(*head->data);
            return pop_head();
        }

        std::unique_ptr<node> try_pop_head()
        {
            std::lock_guard<std::mutex> head_lock(head_mutex);
            if (head.get() == get_tail())
            {
                return nullptr;
            }
            return pop_head();
        }

        std::unique_ptr<node> try_pop_head(T& value)
        {
            std::lock_guard<std::mutex> head_lock(head_mutex);
            if (head.get() == get_tail())
            {
                return nullptr;
            }
            value = std::move(*head->data);
            return pop_head();
        }
    };
}
4、任务窃取

基于锁的任务窃取队列【work_stealing_queue.h】:

#pragma once
#include <deque>
#include <mutex>

#include "function_wrapper.hpp"

namespace tp
{
    /// Mutex-protected task deque, one per worker, used for work stealing.
    ///
    /// The owning worker pushes and pops at the front (newest task first);
    /// other workers steal from the back (oldest task first). Every operation
    /// locks the single internal mutex.
    class work_stealing_queue
    {
        using data_type = function_wrapper;

    public:
        work_stealing_queue();
        work_stealing_queue(const work_stealing_queue&) = delete;
        work_stealing_queue& operator=(const work_stealing_queue&) = delete;

        /// Owner side: adds a task at the front.
        void push(data_type data);
        /// Snapshot emptiness check; may be stale by the time it returns.
        bool empty() const;
        /// Owner side: takes the front task into `res`; false if empty.
        bool try_pop(data_type& res);
        /// Thief side: takes the back task into `res`; false if empty.
        bool try_steal(data_type& res);

    private:
        std::deque<data_type> the_queue;
        // mutable so that the const empty() can lock it.
        mutable std::mutex the_mutex;
    };
}

让无任务可做的线程从其他线程的任务队列中“窃取”任务:队列所属线程在队首压入和弹出任务,窃取方则从队尾取走任务,双方尽量不在同一端竞争【work_stealing_queue.cpp】:

#include "work_stealing_queue.h"

namespace tp
{
    work_stealing_queue::work_stealing_queue() = default;

    // Owner side: newest task goes to the front.
    void work_stealing_queue::push(data_type data)
    {
        std::lock_guard<std::mutex> guard(the_mutex);
        the_queue.push_front(std::move(data));
    }

    bool work_stealing_queue::empty() const
    {
        std::lock_guard<std::mutex> guard(the_mutex);
        return the_queue.empty();
    }

    // Owner side: take from the front (LIFO with respect to push()).
    bool work_stealing_queue::try_pop(data_type& res)
    {
        std::lock_guard<std::mutex> guard(the_mutex);
        bool const has_task = !the_queue.empty();
        if (has_task)
        {
            res = std::move(the_queue.front());
            the_queue.pop_front();
        }
        return has_task;
    }

    // Thief side: take from the back, i.e. the oldest task.
    bool work_stealing_queue::try_steal(data_type& res)
    {
        std::lock_guard<std::mutex> guard(the_mutex);
        bool const has_task = !the_queue.empty();
        if (has_task)
        {
            res = std::move(the_queue.back());
            the_queue.pop_back();
        }
        return has_task;
    }
}
5、线程池

最终目标线程池【thread_pool.h】:

#pragma once
#include <iostream>
#include <atomic>
#include <memory>
#include <thread>
#include <utility>
#include <vector>
#include <future>

#include "function_wrapper.hpp"
#include "join_threads.hpp"

#ifdef LOCK_FREE
#include "lock_free_queue.hpp"
#else
#include "thread_safe_queue.hpp"
#endif

#include "work_stealing_queue.h"

namespace tp
{
    /// Thread pool with one shared queue plus one work-stealing queue per worker.
    ///
    /// A worker looks for work in this order: its own local queue, the shared
    /// pool queue, then the other workers' queues (stealing from their backs).
    class thread_pool
    {
        typedef function_wrapper task_type;

    public:
        thread_pool();
        ~thread_pool();

        /// Runs one queued task if one can be found, otherwise yields.
        /// Public, so a non-worker thread may call it as well; such a thread
        /// has no local queue and only searches the pool and worker queues.
        void run_pending_task();

        /// Queues `f` for execution and returns a future for its result.
        ///
        /// On a worker thread (local_work_queue set) the task goes to that
        /// worker's local queue; on any other thread it goes to the shared
        /// pool queue. The result type is computed with decltype rather than
        /// std::result_of (deprecated in C++17, removed in C++20); both name
        /// the same type.
        template<typename fun_type>
        std::future<decltype(std::declval<fun_type>()())> submit(fun_type f)
        {
            typedef decltype(std::declval<fun_type>()()) result_type;
            // `f` is a by-value sink: move it into the task instead of copying,
            // which also allows move-only callables.
            std::packaged_task<result_type()> task(std::move(f));
            std::future<result_type> res(task.get_future());
            if (local_work_queue)
            {
                local_work_queue->push(task_type(std::move(task)));
            }
            else
            {
                pool_work_queue.push(task_type(std::move(task)));
            }
            return res;
        }

    private:
        void worker_thread(unsigned my_index_);
        bool pop_task_from_local_queue(task_type& task);
        bool pop_task_from_pool_queue(task_type& task);
        bool pop_task_from_other_thread_queue(task_type& task);

    private:
        // The calling worker's own queue and index, set when a worker starts.
        // NOTE(review): these are static members, shared by every thread_pool
        // instance; a worker of pool A that submits to pool B pushes into its
        // own (pool A) local queue.
        static thread_local work_stealing_queue* local_work_queue;
        static thread_local unsigned my_index;

        std::atomic_bool done;
#ifdef LOCK_FREE
        lock_free_queue<task_type> pool_work_queue;
#else
        thread_safe_queue<task_type> pool_work_queue;
#endif
        // One work-stealing queue per worker, indexed by worker index.
        std::vector<std::unique_ptr<work_stealing_queue> > queues;
        std::vector<std::thread> threads;
        // Declared last, so destroyed first: the workers are joined before
        // the queues they read from are destroyed.
        join_threads joiner;
    };

}

【thread_pool.cpp】:

#include "thread_pool.h"

namespace tp
{
    // Every thread starts without a local queue; worker_thread() sets both.
    thread_local work_stealing_queue* thread_pool::local_work_queue = nullptr;
    thread_local unsigned thread_pool::my_index = 0;

    thread_pool::thread_pool()
        : done(false)
        , joiner(threads)
    {
        // hardware_concurrency() may return 0 when the value is unknown; a pool
        // with no workers would never run a task, so use at least one.
        unsigned thread_count = std::thread::hardware_concurrency();
        if (thread_count == 0)
            thread_count = 1;

        try
        {
            // Create every queue BEFORE starting any worker. Workers index into
            // `queues` and iterate over all of it while stealing as soon as they
            // run, so the vector must not be modified (push_back may reallocate)
            // once a worker exists.
            queues.reserve(thread_count);
            for (unsigned i = 0; i < thread_count; ++i)
            {
                queues.push_back(std::unique_ptr<work_stealing_queue>(new work_stealing_queue));
            }

            threads.reserve(thread_count);
            for (unsigned i = 0; i < thread_count; ++i)
            {
                threads.push_back(std::thread(&thread_pool::worker_thread, this, i));
            }
        }
        catch (...)
        {
            // Stop any workers already started; `joiner` joins them as the
            // constructed members are destroyed during unwinding.
            done = true;
            throw;
        }
    }

    // Signals the workers to stop; `joiner` (destroyed first among the
    // members) then joins them. Tasks still queued are never run, so their
    // futures report std::future_errc::broken_promise.
    thread_pool::~thread_pool()
    {
        done = true;
    }

    void thread_pool::run_pending_task()
    {
        task_type task;
        if (pop_task_from_local_queue(task) ||
            pop_task_from_pool_queue(task) ||
            pop_task_from_other_thread_queue(task))
        {
            task();
        }
        else
        {
            std::this_thread::yield();
        }
    }

    // Worker main loop: bind this thread to its queue, then run tasks until
    // the pool is shut down.
    void thread_pool::worker_thread(unsigned my_index_)
    {
        my_index = my_index_;
        local_work_queue = queues[my_index].get();
        while (!done)
        {
            run_pending_task();
        }
    }

    bool thread_pool::pop_task_from_local_queue(task_type& task)
    {
        return local_work_queue && local_work_queue->try_pop(task);
    }

    bool thread_pool::pop_task_from_pool_queue(task_type& task)
    {
        return pool_work_queue.try_pop(task);
    }

    // Tries every worker queue once, starting after this thread's own index
    // (wrapping around) so that different workers probe victims in different
    // orders.
    bool thread_pool::pop_task_from_other_thread_queue(task_type& task)
    {
        for (unsigned i = 0; i < queues.size(); ++i)
        {
            unsigned const index = (my_index + i + 1) % queues.size();
            if (queues[index]->try_steal(task))
            {
                return true;
            }
        }
        return false;
    }
}
6、测试代码
#include <chrono>
#include <future>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>

#include "thread_pool.h"
#include "function_wrapper.hpp"
#include "thread_safe_queue.hpp"

template<typename Iterator, typename T>
struct accumulate_block
{
    /// Sequentially sums [begin, end), starting from a value-initialized T.
    T operator()(Iterator begin, Iterator end)
    {
        T initial = T();
        return std::accumulate(begin, end, initial);
    }
};

/// Sums [first, last) in parallel on a tp::thread_pool.
///
/// The range is cut into blocks of `block_size` elements. All blocks but the
/// last are submitted to the pool, the last one is summed on the calling
/// thread, and the partial sums are added to `init`.
///
/// T is the accumulator type (previously `init` and the running totals were
/// hard-coded to int). It must be wide enough for the total and for every
/// block sum, because signed overflow is undefined behaviour.
template<typename Iterator, typename T>
T parallel_accumulate(Iterator first, Iterator last, T init)
{
    unsigned long const length = std::distance(first, last);

    if (!length)
        return init;

    // Too small a block yields too many tasks: switching and contention then
    // dominate the actual work.
    unsigned long const block_size = 120000;
    unsigned long const num_blocks = (length + block_size - 1) / block_size;  // ceil(length / block_size)

    std::vector<std::future<T> > futures(num_blocks - 1);
    tp::thread_pool pool;

    Iterator block_start = first;
    for (unsigned long i = 0; i < (num_blocks - 1); ++i)
    {
        Iterator block_end = block_start;
        std::advance(block_end, block_size);
        futures[i] = pool.submit([=] { return accumulate_block<Iterator, T>()(block_start, block_end); });
        block_start = block_end;
    }
    T const last_result = accumulate_block<Iterator, T>()(block_start, last);
    T result = init;
    for (unsigned long i = 0; i < (num_blocks - 1); ++i)
    {
        result += futures[i].get();
    }
    result += last_result;
    return result;
}

int main()
{
    std::vector<int> data(500000);
    std::iota(std::begin(data), std::end(data), 0);

    // The sum of 0..499999 is 124999750000, which does not fit in a 32-bit
    // int (and neither does any 120000-element block sum), so accumulate in
    // long long. For a sequential baseline, time a plain loop or
    // std::accumulate(std::begin(data), std::end(data), 0LL) the same way.
    auto start = std::chrono::steady_clock::now();
    auto result = parallel_accumulate<std::vector<int>::iterator, long long>(std::begin(data), std::end(data), 0LL);
    auto end = std::chrono::steady_clock::now();
    auto duration = std::chrono::duration<double>(end - start).count();

    std::cout << result << '\n';
    std::cout << "elapsed: " << duration << " s\n";
    return 0;
}
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值