caffe2源码分析第一篇----threadpool

从本篇开始分析总结caffe2的源码,caffe2中大量使用了modern c++的新特性,所以在学习神经网络的知识的同时,也同步学习modern c++的特性。最近正好在工作中需要用到threadpool, 所以就先从threadpool开始分析了。分析的方法包括源码分析和最后的归纳总结。
1.先来看threadpool的目录结构,源码在caffe2/utils/threadpool目录下。包含
ThreadPoolCommon.h 主要是针对不同设备(android ios mac等)的宏定义,来同时适配所有的设备。

pthreadpool.h, 定义了一套c接口,创建threadpool等, 还有一系列便于多维同步运算的函数(利用threadpool)。这里创建threadpool的函数创建的就是ThreadPool.h里面定义的threadpool。
pthreadpool.cc 多维同步运算的函数实现
pthreadpool_impl.cc 实现了pthreadpool.h中定义的pthreadpool_create等函数。
所以以上三个文件是利用下面几个文件定义的threadpool类来实现的一套便于多维同步运算的c接口。这里就不分析其源码了。主要分析以下三个文件的源码。
下面三个文件就是threadpool类的真正实现:
WorkerPool.h
ThreadPool.h
ThreadPool.cc

源码分析:
1.WorkerPool.h //本文件创建的workerpool就是后面threadpool的基础

#pragma once

#include <atomic>
#include <condition_variable>
#include <thread>
#include "c10/util/thread_name.h"
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"

#if defined(_MSC_VER) // mocrosoft推出的针对c/c++的扩展宏定义
#include <intrin.h>
#endif

namespace caffe2 {

/*这里的code参考了google的gemmlowp, gemm是深度学习经常使用的广义矩阵乘积操作*/
// Uses code derived from gemmlowp,
// https://github.com/google/gemmlowp/blob/6c91e1ed0c2eff1182d804310b92911fe9c18019/internal/multi_thread_gemm.h
// Changes:
// - allocation-free execute()
// - Use RAII where possible.  RAII就是提倡将资源的管理释放与生存周期绑定,比如shared_ptr就是典型的RAII思想,出了生存范围自动系够。
// - Run the first task on the main thread (since that is the largest task).
// - removed custom allocator.
// - Removed some ifdef's
// - cache-line align Worker.
// - use std::atomic instead of volatile and custom barriers.
// - use std::mutex/std::condition_variable instead of raw pthreads.

constexpr size_t kGEMMLOWPCacheLineSize = 64;

/*这里定义了一个模板类,在类的内部提供了两个public的static的成员函数alloc和release, 用来分配指定字节对齐的内存
(这里是kGEMMLOWPCacheLineSize = 64)和释放该内存。下面根据不同的平台,调用不同的分配对齐内存的函数。
由于其内部没有成员变量,只有两个static函数,所以该类使用时不用创建对象,所以其实就是定义了两个模板函数来
提供分配对齐内存的方法。这里可以用任何类型来实例化模板,然后就可以在对齐的地址上创建任何对象了*/
template <typename T>
struct AllocAligned {
  // Allocate a T aligned at an `align` byte address
  template <typename... Args>
  static T* alloc(Args&&... args) {
    void* p = nullptr;
/*会返回kGEMMLOWPCacheLineSize对其的地址,这里就是64字节对齐的地址*/
#if defined(__ANDROID__)
    p = memalign(kGEMMLOWPCacheLineSize, sizeof(T));
#elif defined(_MSC_VER)
    p = _aligned_malloc(sizeof(T), kGEMMLOWPCacheLineSize);
#else
    posix_memalign((void**)&p, kGEMMLOWPCacheLineSize, sizeof(T)); 
#endif

    if (p) {
    /*这里是new的一种比较少用的用法,就是在指定的内存上(这里就是p所指向的地址)构建对象T。
    参数使用完美转发(std::forward)来支持参数的copy和move。Args&&... args实现了可变参数的万能引用,即可以传入左值,又可以传入右值
    (关于万能引用不了解的可以查阅相关资料)。*/
      return new (p) T(std::forward<Args>(args)...); 
    }

    return nullptr;
  }

  // Free a T previously allocated via AllocAligned<T>::alloc()
  static void release(T* p) {
    if (p) {
      p->~T();
#if defined(_MSC_VER)
      _aligned_free((void*)p);
#else
      free((void*)p);
#endif
    }
  }
};

/*下面两个struct就是利用上面的在对齐地址上构建对象的类来创建Unique_ptr,因为这里是通过特殊的对齐内存
方式分配的内存,不能简单的用delete来析构,所以需要指定deleter。下面第一个struct就是定义了一个deleter来
析构该unique_ptr维护的内存。*/
// Deleter object for unique_ptr for an aligned object
template <typename T>
struct AlignedDeleter {
  void operator()(T* p) const { AllocAligned<T>::release(p); }
};

// make_unique that guarantees alignment
template <typename T>
struct MakeAligned {
  template <typename... Args>
  /*这里就创建了一个make uniqie_ptr的函数,利用这个函数,就可以创建任何类型的Unique_ptr,并且内部指定了deleter. 
   这个就是外部要用的主要函数。利用uniqie_ptr即可以直接使用,又可以用来初始化shared_ptr,非常方便*/
  static std::unique_ptr<T, AlignedDeleter<T>> make(Args&&... args) {
      return std::unique_ptr<T, AlignedDeleter<T>>(
        AllocAligned<T>::alloc(std::forward<Args>(args)...));
  }
};

/*下面定义了一系列多个nop的宏,一个nop就是一个空的机器指令,可以用于快速的,消耗非常少的短暂等待*/
const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;

#if defined(_MSC_VER)
#define GEMMLOWP_NOP __nop();
#else
#define GEMMLOWP_NOP "nop\n"
#endif

#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)

inline int Do256NOPs() {
#if defined(_MSC_VER)
  GEMMLOWP_NOP64;
#else
  asm volatile(GEMMLOWP_NOP64);
#endif
  return 64;
}

#undef GEMMLOWP_STRING_CONCAT_4
#undef GEMMLOWP_NOP256
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP

// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that that
// new value has been taken by *var at some point during the
// execution of this function. There is no guarantee that this is
// still the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting for the given condvar, guarded
// by the given mutex.
//
// The idea of doing some initial busy-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting help ensuring that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
//
/*这个函数就是利用condition_variable来等待某个变量变化,这个变量必须为atomic变量,但这里为了速度优化,
就在利用condition之前,先尝试做了nop来短暂等待,尝试之后如果还没有变化,则利用condition来wait。
这里面用到了一个CHECK_NE函数,这是caffe里面的打印的宏,用于debug。这一组函数如下:
CHECK_EQ(x,y)<<"x!=y",EQ即equation,意为“等于”,函数判断是否x等于y,当x!=y时,函数打印出x!=y。
  CHECK_NE(x,y)<<"x=y",NE即not equation,意为“不等于”,函数判断是否x不等于y,当x=y时,函数打印出x=y。
  CHECK_LE(x,y) <<"x>=y",LE即lower equation,意为小于等于,函数判断是否x小于等于y。当x>=y时,函数打印x>=y。
  CHECK_LT(x,y)<<"x>=y",LT即为lower to ,意为小于,函数判断是否x小于y,当x>y时,函数打印x>y。
  CHECK_GE(x,y) <<"x<=y",GE即为great equation,意为大于。判断意义根据上述可推导出。
  CHECK_GT(x,y) <<"x<=y",*/
  /*另外下面用到了一些诸如std::memory_order_relaxed, std::memory_order_acquire,  std::atomic_thread_fence()等特性,
  这些都是为了设置内存一致性的,因为编译器优化的时候会打乱一些代码执行顺序,所以如果对于一些有特殊执行顺序的代码,
  需要利用这些特性来告诉编译器维护内存的一致性,
  具体可以参考下面这个博客 :https://blog.csdn.net/lvdan1/article/details/54098559 */
template <typename T>
T WaitForVariableChange(std::atomic<T>* var,
                        T initial_value,
                        std::condition_variable* cond,
                        std::mutex* mutex) {
  // If we are on a platform that supports it, spin for some time.
  {
    int nops = 0;
    // First, trivial case where the variable already changed value.
    /*直接检查一次是否有变化了*/
    T new_value = var->load(std::memory_order_relaxed);
    if (new_value != initial_value) {
      std::atomic_thread_fence(std::memory_order_acquire); //内存栅栏,用于保证线程间内存一致性。
      return new_value;
    }
    // Then try busy-waiting.
    /*等待kMaxBusyWaitNOPs个nop,看看是否已经变化了*/
    while (nops < kMaxBusyWaitNOPs) {
      nops += Do256NOPs();
      new_value = var->load(std::memory_order_relaxed);
      if (new_value != initial_value) {
        std::atomic_thread_fence(std::memory_order_acquire);
        return new_value;
      }
    }
  }
/*这里就真正的利用condition来wait了,*/
  // Finally, do real passive waiting.
  {
    std::unique_lock<std::mutex> g(*mutex);
    T new_value = var->load(std::memory_order_relaxed);
    // Handle spurious wakeups.
    cond->wait(g, [&]() { //利用lambda表达式来设置返回的条件,也就是只有当值变化时来真正返回
      new_value = var->load(std::memory_order_relaxed);
      return new_value != initial_value;
    });
    DCHECK_NE(static_cast<size_t>(new_value), static_cast<size_t>(initial_value));
    return new_value;
  }
}

// A BlockingCounter lets one thread to wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
/*上面的注释已经说的比较明白了,就是利用一个thread来等待多个event发生,
主要用于主线程等待所有的worker 线程的工作都结束*/
class BlockingCounter {
 public:
  // Sets/resets the counter; initial_count is the number of
  // decrementing events that the Wait() call will be waiting for.
  void Reset(std::size_t initial_count) {
  //lock_guard资源消耗比Unique_lock小,但是不能主动解锁,所以不能用于condition_variable, unique_lock可以用于condition。
    std::lock_guard<std::mutex> g(mutex_);
    DCHECK_EQ(count_, 0);
    count_ = initial_count;
  }

  // Decrements the counter; if the counter hits zero, signals
  // the thread that was waiting for that, and returns true.
  // Otherwise (if the decremented count is still nonzero),
  // returns false.
  bool DecrementCount() {
    const auto count_value = count_.fetch_sub(1, std::memory_order_relaxed) - 1;
    DCHECK_GE(count_value, 0);
    if (count_value == 0) {
      std::lock_guard<std::mutex> g(mutex_);
      cond_.notify_one();
    }
    bool retval = count_value == 0;
    return retval;
  }

  // Waits for the N other threads (N having been set by Reset())
  // to hit the BlockingCounter.
  void Wait() {
    while (size_t count_value = count_.load(std::memory_order_relaxed)) {
      WaitForVariableChange(&count_, count_value, &cond_, &mutex_);//利用了上面定义的等待value change的函数。
    }
  }

 private:
  std::condition_variable cond_;
  std::mutex mutex_;
  std::atomic<std::size_t> count_{0};
};

// A workload for a worker.
struct Task {
  Task() {}
  virtual ~Task() {}
  virtual void Run() = 0;
};

// A worker thread.
/*里面管理一个线程,和加给该线程的task,以及condition和mutex等资源*/
class alignas(kGEMMLOWPCacheLineSize) Worker {
 public:
  enum class State : uint8_t {
    ThreadStartup, // The initial state before the thread main loop runs.
    Ready, // Is not working, has not yet received new work to do.
    HasWork, // Has work to do.
    ExitAsSoonAsPossible // Should exit at earliest convenience.
  };

  explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
      : task_(nullptr),
        state_(State::ThreadStartup),
        counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
    thread_ = caffe2::make_unique<std::thread>([this]() { this->ThreadFunc(); });
  }

  ~Worker() {
    ChangeState(State::ExitAsSoonAsPossible);
    thread_->join();
  }

  // Changes State; may be called from either the worker thread
  // or the master thread; however, not all state transitions are legal,
  // which is guarded by assertions.
  void ChangeState(State new_state) {
    std::lock_guard<std::mutex> g(state_mutex_);
    DCHECK(new_state != state_.load(std::memory_order_relaxed));
    switch (state_.load(std::memory_order_relaxed)) {
    case State::ThreadStartup:
      DCHECK(new_state == State::Ready);
      break;
    case State::Ready:
      DCHECK(new_state == State::HasWork || new_state == State::ExitAsSoonAsPossible);
      break;
    case State::HasWork:
      DCHECK(new_state == State::Ready || new_state == State::ExitAsSoonAsPossible);
      break;
    default:
      abort();
    }
    state_.store(new_state, std::memory_order_relaxed);
    state_cond_.notify_one();
    if (new_state == State::Ready) {
      counter_to_decrement_when_ready_->DecrementCount();
    }
  }

  // Thread entry point.
  void ThreadFunc() {
    c10::setThreadName("CaffeWorkersPool");
    ChangeState(State::Ready);

    // Thread main loop
    while (true) {
      // Get a state to act on
      // In the 'Ready' state, we have nothing to do but to wait until
      // we switch to another state.
      State state_to_act_upon =
          WaitForVariableChange(&state_, State::Ready, &state_cond_, &state_mutex_);

      // We now have a state to act on, so act.
      switch (state_to_act_upon) {
      case State::HasWork:
        // Got work to do! So do it, and then revert to 'Ready' state.
        DCHECK(task_.load());
        (*task_).Run();
        task_ = nullptr;
        ChangeState(State::Ready);
        break;
      case State::ExitAsSoonAsPossible:
        return;
      default:
        abort();
      }
    }
  }

  static void* ThreadFunc(void* arg) {
    static_cast<Worker*>(arg)->ThreadFunc();
    return nullptr;
  }

  // Called by the master thead to give this worker work to do.
  // It is only legal to call this if the worker
  void StartWork(Task* task) {
    DCHECK(!task_.load());
    task_ = task;
    DCHECK(state_.load(std::memory_order_acquire) == State::Ready);
    ChangeState(State::HasWork);
  }

 private:
  // The underlying thread.
  std::unique_ptr<std::thread> thread_;

  // The task to be worked on.
  std::atomic<Task*> task_;

  // The condition variable and mutex guarding state changes.
  std::condition_variable state_cond_;
  std::mutex state_mutex_;

  // The state enum tells if we're currently working, waiting for work, etc.
  std::atomic<State> state_;

  // pointer to the master's thread BlockingCounter object, to notify the
  // master thread of when this worker switches to the 'Ready' state.
  BlockingCounter* const counter_to_decrement_when_ready_;
};
/*真正的workpool了,就是内部管理几个worker,分配任务给他们,并且管理它们的生存周期,
它只有一个Execute函数来传任务进来,并且等待任务都完成才返回。所以是一个同步的接口,
添加进来的是Task的vector,这里task定义了一个虚类,所以要根据具体情况创建具体的task*/
class WorkersPool {
 public:
  WorkersPool() {}

  void Execute(const std::vector<std::shared_ptr<Task>>& tasks) {
    CAFFE_ENFORCE_GE(tasks.size(), 1);
    // One of the tasks will be run on the current thread.
    int workers_count = tasks.size() - 1;
    CreateWorkers(workers_count); //根据task个数创建对应的worker。
    DCHECK_LE(workers_count, (int)workers_.size());
    counter_to_decrement_when_ready_.Reset(workers_count);
    for (size_t task = 1; task < tasks.size(); ++task) {
      workers_[task - 1]->StartWork(tasks[task].get());
    }
    // Execute the remaining workload immediately on the current thread.
    auto& task = tasks.front();
    task->Run();//留一个任务在当前线程执行,其他的分配给其他线程。
    // Wait for the workers submitted above to finish.
    //完成当前任务后等待其他线程完成。所以是一个同步的调用,只有当所有任务都完成才返回
    counter_to_decrement_when_ready_.Wait();
  }

 private:
  // Ensures that the pool has at least the given count of workers.
  // If any new worker has to be created, this function waits for it to
  // be ready.
  void CreateWorkers(std::size_t workers_count) {
    if (workers_.size() >= workers_count) {
      return;
    }
    counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
    while (workers_.size() < workers_count) {
      workers_.push_back(MakeAligned<Worker>::make(&counter_to_decrement_when_ready_));
    }
    counter_to_decrement_when_ready_.Wait();
  }

  C10_DISABLE_COPY_AND_ASSIGN(WorkersPool);
  std::vector<std::unique_ptr<Worker, AlignedDeleter<Worker>>> workers_;
  // The BlockingCounter used to wait for the workers.
  BlockingCounter counter_to_decrement_when_ready_;
};
} // namespace caffe2

ThreadPool.h

#ifndef CAFFE2_UTILS_THREADPOOL_H_
#define CAFFE2_UTILS_THREADPOOL_H_

#include "ThreadPoolCommon.h"

#include <functional>
#include <memory>
#include <mutex>
#include <vector>

#include "caffe2/core/common.h"

//
// A work-stealing threadpool loosely based off of pthreadpool
//

namespace caffe2 {

struct Task;
class WorkersPool;

constexpr size_t kCacheLineSize = 64;

// A threadpool with the given number of threads.
// NOTE: the kCacheLineSize alignment is present only for cache
// performance, and is not strictly enforced (for example, when
// the object is created on the heap). Thus, in order to avoid
// misaligned intrinsics, no SSE instructions shall be involved in
// the ThreadPool implementation.
// Note: alignas is disabled because some compilers do not deal with
// CAFFE2_API and alignas annotations at the same time.

class CAFFE2_API /*alignas(kCacheLineSize)*/ ThreadPool {  //ThreadPool类
 public:
  static std::unique_ptr<ThreadPool> defaultThreadPool();//会根据cpu的核的个数来决定thread的个数,这样计算的速度最快
  ThreadPool(int numThreads); //传入thread的个数的构造函数。
  ~ThreadPool();
  // Returns the number of threads currently in use
  int getNumThreads() const;

  // Sets the minimum work size (range) for which to invoke the
  // threadpool; work sizes smaller than this will just be run on the
  // main (calling) thread
  /*有一个minWorkSize_的成员变量,如果tasks的个数少于这个值,则在本线程全部执行,不会分给各个线程,
  节省各个线程调度的消耗*/
  void setMinWorkSize(size_t size); 
  size_t getMinWorkSize() const { return minWorkSize_; }
  /*这个就是外部调用run task的函数,具体的解释在该函数的实现处*/
  void run(const std::function<void(int, size_t)>& fn, size_t range);

  // Run an arbitrary function in a thread-safe manner accessing the Workers
  // Pool
  /*直接调用workersPool来执行某个函数*/
  void withPool(const std::function<void(WorkersPool*)>& fn);

 private:
  mutable std::mutex executionMutex_;
  size_t minWorkSize_;
  size_t numThreads_;
  std::shared_ptr<WorkersPool> workersPool_;
  std::vector<std::shared_ptr<Task>> tasks_;
};

} // namespace caffe2

#endif // CAFFE2_UTILS_THREADPOOL_H_

ThreadPool.cc

#include "caffe2/utils/threadpool/ThreadPool.h"
#include "WorkersPool.h"
#include "caffe2/core/logging.h"

#include <cpuinfo.h>

C10_DEFINE_bool(
    caffe2_threadpool_force_inline,
    false,
    "Force to always run jobs on the calling thread");

// Whether or not threadpool caps apply to Android
C10_DEFINE_int(caffe2_threadpool_android_cap, true, "");

// Whether or not threadpool caps apply to iOS
C10_DEFINE_int(caffe2_threadpool_ios_cap, true, "");

namespace caffe2 {

// Default smallest amount of work that will be partitioned between
// multiple threads; the runtime value is configurable
constexpr size_t kDefaultMinWorkSize = 1;
/*利用获取的cpu核的参数来决定创建的线程数*/
std::unique_ptr<ThreadPool> ThreadPool::defaultThreadPool() {
  CAFFE_ENFORCE(cpuinfo_initialize(), "cpuinfo initialization failed");
  int numThreads = cpuinfo_get_processors_count();

  bool applyCap = false;
#if C10_ANDROID
  applyCap = FLAGS_caffe2_threadpool_android_cap;
#elif C10_IOS
  applyCap = FLAGS_caffe2_threadpool_ios_cap;
#endif

  if (applyCap) {
    switch (numThreads) {
#if C10_ANDROID && (CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64)
      case 4:
          switch (cpuinfo_get_core(0)->midr & UINT32_C(0xFF00FFF0)) {
            case UINT32_C(0x51002110): /* Snapdragon 820 Kryo Silver */
            case UINT32_C(0x51002010): /* Snapdragon 821 Kryo Silver */
            case UINT32_C(0x51002050): /* Snapdragon 820/821 Kryo Gold */
              /* Kryo: 2+2 big.LITTLE */
              numThreads = 2;
              break;
            default:
              /* Anything else: assume homogeneous architecture */
              numThreads = 4;
              break;
          }
        break;
#endif
      case 5:
        /* 4+1 big.LITTLE */
        numThreads = 4;
        break;
      case 6:
        /* 2+4 big.LITTLE */
        numThreads = 2;
        break;
      case 8:
        /* 4+4 big.LITTLE */
        numThreads = 4;
        break;
      case 10:
        /* 4+4+2 Min.Med.Max, running on Med cores */
        numThreads = 4;
        break;
      default:
        if (numThreads > 4) {
          numThreads = numThreads / 2;
        }
        break;
    }
  }
  LOG(INFO) << "Constructing thread pool with " << numThreads << " threads";
  return caffe2::make_unique<ThreadPool>(numThreads);//创建具体的threadPool
}

ThreadPool::ThreadPool(int numThreads)
    : minWorkSize_(kDefaultMinWorkSize), numThreads_(numThreads),
      workersPool_(std::make_shared<WorkersPool>()) {}

ThreadPool::~ThreadPool() {}

int ThreadPool::getNumThreads() const {
  std::lock_guard<std::mutex> guard(executionMutex_);
  return numThreads_;
}

// Sets the minimum work size (range) for which to invoke the
// threadpool; work sizes smaller than this will just be run on the
// main (calling) thread
void ThreadPool::setMinWorkSize(size_t size) {
  std::lock_guard<std::mutex> guard(executionMutex_);
  minWorkSize_ = size;
}

void ThreadPool::run(const std::function<void(int, size_t)>& fn, size_t range) { //range就是fn要执行的次数
  std::lock_guard<std::mutex> guard(executionMutex_);
  // If there are no worker threads, or if the range is too small (too
  // little work), just run locally
  const bool runLocally = range < minWorkSize_ ||
      FLAGS_caffe2_threadpool_force_inline || (numThreads_ == 0);
  if (runLocally) {
  /*满足本线程执行的条件,就在本线程依次执行。*/
    // Work is small enough to just run locally; multithread overhead
    // is too high
    for (size_t i = 0; i < range; ++i) {
      fn(0, i);  //这里fn的两个参数有一些不明白,后面看到具体用例的时候再回来补充
    }
    return;
  }
//下面就需要多线程跑了,所以定义了Task类的子类FnTask
/*这里具体解释一下原理,在thread中真正要跑的任务函数是std::function<void(int, size_t)>, 就是有两个参数的一个函数,
第一个参数为thread的index,也就是说如果该threadPool创建的4个线程,index就是0-3,第二个参数为任务的index,就是上面
run函数传入的range参数就是fn要运行的次数,比如这里我们假设为9, 则每个线程要跑(9+4-1)/4 = 3, 也就是基本每个线程最多跑3个任务,
则在FnTask中就有两个参数(start_和end_)来记录该task所需要跑的fn的index,这里range是9,则该fn要运行9次,则第一个线程
要跑的index为start_=0, end_=3,第二个为start_=3, end_=6, 其他依次类推。从这里我们可以看出, fn的两个参数就是上面我们解释
的,一个thread的index, 一个task的index。range就是task的数量。该函数会根据线程的数量和task的数量,平均将task分配给各个线程。
通过fn的两个参数,在fn中可以知道该fn是跑在哪个线程中(线程index),是第几个task(task index)*/
  struct FnTask : public Task {
    FnTask(){};
    ~FnTask() override{};
    const std::function<void(int, size_t)> *fn_;
    int idx_;
    size_t start_;
    size_t end_;
    void Run() override {
      for (auto i = start_; i < end_; ++i) {
        (*fn_)(idx_, i);
      }
    }
  };

  CAFFE_ENFORCE_GE(numThreads_, 1);
  const size_t unitsPerTask = (range + numThreads_ - 1) / numThreads_;//这里算一下每个线程要执行的任务数。
  tasks_.resize(numThreads_);
  for (size_t i = 0; i < numThreads_; ++i) { /根据threadNum的个数来创建task队列
    if (!tasks_[i]) {
      tasks_[i].reset(new FnTask());
    }
    auto *task = (FnTask *)tasks_[i].get();//初始化每个FnTask
    task->fn_ = &fn;
    task->idx_ = i; //thread的index
    task->start_ = std::min<size_t>(range, i * unitsPerTask);
    task->end_ = std::min<size_t>(range, (i + 1) * unitsPerTask);
    if (task->start_ >= task->end_) {
      tasks_.resize(i);
      break;
    }
    CAFFE_ENFORCE_LE(task->start_, range);
    CAFFE_ENFORCE_LE(task->end_, range);
  }
  CAFFE_ENFORCE_LE(tasks_.size(), numThreads_);
  CAFFE_ENFORCE_GE(tasks_.size(), 1);
  workersPool_->Execute(tasks_);
}
/*这里相当与将内部的workersPool_直接给外部使用*/
void ThreadPool::withPool(const std::function<void(WorkersPool*)>& f) {
  std::lock_guard<std::mutex> guard(executionMutex_);
  f(workersPool_.get());
}

} // namespace caffe2
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值