正常线程执行完任务后即会销毁,而反复创建和销毁线程的代价较高,Fixed 线程池中的线程一直存在,会不断查看任务队列是否有新任务完成,当提交任务给线程池时,若当时线程池已满会返回错误;
1. ThreadPool
线程池可以通过start 开启并传入最大线程数控制线程池规模:
// ThreadPool pool;
// pool.start(4);
void ThreadPool::start(int initThreadSize) {
initThreadSize_ = initThreadSize;
for (int i = 0; i < initThreadSize_; i++) {
auto ptr = std::make_unique<Thread>(std::bind(&ThreadPool::threadFunc, this));
threads_.emplace_back(std::move(ptr));
}
for (int i = 0; i < initThreadSize_; i++) {
threads_[i]->start();
}
线程池拥有vector 管理含Thread 资源的智能指针,Thread 类由C++11 thread 实现线程功能,thread 执行函数由ThreadPool::threadFunc 决定,因为本质上所有线程做的都是同一件事:向线程池中索要执行任务;
void ThreadPool::threadFunc() {
while(true) {
std::shared_ptr<Task> curTask;
{
std::unique_lock<std::mutex> lock(taskQueMtx_);
notEmpty_.wait(lock, [&]() -> bool { return taskQue_.size() > 0; });
curTask = taskQue_.front();
taskQue_.pop();
taskSize_--;
if (taskQue_.size() > 0) {
notEmpty_.notify_all();
}
notFull_.notify_all();
}
if (curTask != nullptr) {
curTask->exec();
}
}
ThreadPool 使用生产者消费者模型,当任务队列有任务时,notEmpty 信号量苏醒,取任务执行并通知,因为执行任务导致队列缩短,因此notFull 通知,可以来提交任务了;(注意增加的Block,不然taskQueMtx_ 会导致其他线程一直等待从而只有一个线程在实际运行;
Result ThreadPool::submitTask(std::shared_ptr<Task> ptr) {
std::unique_lock<std::mutex> lock(taskQueMtx_);
if (!notFull_.wait_for(lock, std::chrono::seconds(1), [&]()->bool{return taskQue_.size() < taskQueMaxThreshHold_;})) {
std::cerr<<"task queue is full, submitTask failed.\n";
return Result(ptr, false);
}
taskQue_.emplace(ptr);
taskSize_++;
notEmpty_.notify_all();
return Result(ptr);
}
提交任务会等待队列空出位置,不然报错,若有空余,添入队列并通知notEmpty 信号量;
2. Thread
Thread 很简单,注意开启detach;
Thread::Thread(ThreadFunc threadFunc) {
func_ = threadFunc;
}
Thread::~Thread() {
}
void Thread::start() {
std::thread t(func_);
t.detach();
}
3. Task
Task 为所有任务的父类,当有定义任务时,继承Task 并重写run() 即可:
class Task {
public:
void exec();
void setResult(Result*);
virtual Any run() = 0;
private:
Result* result_;
};
void Task::exec() {
if (result_ != nullptr) {
auto x = run();
result_->setVal(std::move(x));
}
}
void Task::setResult(Result* r) {
result_ = r;
}
考虑到线程执行完会有返回值,但是不同任务返回值可能不一样,引入Any 上帝类和 Result 结果类;
3. Any And Result
any C++17 中已经由标准实现,这儿手撸一个无值语义的Any,大致思路和下面标准类似:
any a = 1;
cout<<any_cast<int>(a)<<endl;
a = 2.2;
cout<<any_cast<double>(a)<<endl;
// >> 1
// >> 2.2
class Any {
public:
Any() = default;
~Any() = default;
Any(const Any &) = delete;
Any &operator=(const Any &) = delete;
Any(Any &&) = default;
Any &operator=(Any &&) = default;
template<typename T>
Any(T data) : base_(std::make_unique<Derived < T>>
(data)) {
std::cout<<"T is constructing "<<data<<std::endl;
}
template<typename T>
T cast() {
auto pd = dynamic_cast<Derived <T> *>(base_.get());
if (pd == nullptr) {
throw "type is unmatched";
}
return pd->get();
}
private:
class Base {
public:
virtual ~Base() = default;
};
template<typename T>
class Derived : public Base {
public:
Derived(T data) : data_(data) {}
~Derived() = default;
T get() {
return data_;
}
private:
T data_;
};
private:
std::unique_ptr<Base> base_;
};
Result 存放了Any 作为接受线程返回值,但是Result 获取有一个前提条件,任务线程必须先执行完再获取值,所以Result 生命周期必须比Task 长,而且Task 和 Result 有相互关系,Task 执行完向Result 存值,Result 持有shared_ptr 并在构造时向Task 传递;
实现的思路很简单,通过信号量同步:
Result::Result(std::shared_ptr<Task> sp, bool isValid): sp_(sp), isValid_(isValid){
sp_->setResult(this);
}
Any Result::get() {
if (!isValid_) {
return "";
}
sem_.wait();
return std::move(any_);
}
void Result::setVal(Any&& any) {
any_ = std::move(any);
sem_.post();
}
4. How to Use
class MyTask : public Task {
public:
MyTask(int begin, int end) : begin_(begin), end_(end) {}
Any run() override {
std::cout << "tid: " << std::this_thread::get_id() << " begin!\n";
int sum = 0;
for (int i = begin_; i <= end_; i++) {
sum += i;
}
std::cout << "tid: " << std::this_thread::get_id() << " end!\n";
return sum;
}
private:
int begin_;
int end_;
};
int main() {
ThreadPool pool;
pool.start(4);
Result res1 = pool.submitTask(std::make_shared<MyTask>(1, 1000));
Result res2 = pool.submitTask(std::make_shared<MyTask>(1001, 2000));
Result res3 = pool.submitTask(std::make_shared<MyTask>(2001, 3000));
int res = res1.get().cast<int>() + res2.get().cast<int>() + res3.get().cast<int>();
for (int i = 1; i <= 3000; i++) {
res -= i;
}
assert(res == 0);
exit(0);
}
assert 正常通过!