tensorflow之RunHandler

RunHandler是用来调度OP的。在执行多个Session::Run()时,就用RunHandler来调度执行。其中的线程来自GlobalPool。使用RunHandler就不必直接使用GlobalPool了,而且RunHandler在调度上做了优化。

RunHandler只能通过RunHandlerPool::Get()来获得。

// Handle used to schedule the work of a single Session::Run() call.
// Handles are obtained through RunHandlerPool::Get(); the constructor is
// private and RunHandlerPool::Impl is a friend, so the pool is what creates
// them.
//
// NOTE(review): naming RunHandlerPool::Impl in the friend declaration requires
// RunHandlerPool to be declared earlier in the translation unit; in this
// excerpt RunHandlerPool is declared further below.
class RunHandler {
 public:
  // Schedules the closure `fn` as inter-op work.
  void ScheduleInterOpClosure(std::function<void()> fn);
  // Returns a thread-pool interface through which intra-op work can be
  // scheduled for this handler. Ownership of the returned object is not
  // visible here; presumably it is owned by the handler — TODO confirm.
  thread::ThreadPoolInterface* AsIntraThreadPoolInterface();

  ~RunHandler();

 private:
  // Implementation class; only forward-declared here (pimpl).
  class Impl;
  friend class RunHandlerPool::Impl;

  explicit RunHandler(Impl* impl);

 };

// Schedules `fn` as inter-op work. Nothing is executed here: the closure is
// only enqueued, via pool_impl_, on the shared RunHandlerThreadPool under this
// handler's ThreadWorkSource, with is_blocking = true (inter-op work).
void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
  VLOG(3) << "Scheduling inter work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), true,
                                                        std::move(fn));
}

// Schedules `fn` as intra-op work: the closure is enqueued (not run) on the
// shared RunHandlerThreadPool under this handler's ThreadWorkSource, submitted
// with is_blocking = false.
void RunHandler::Impl::ScheduleIntraOpClosure(std::function<void()> fn) {
  VLOG(3) << "Scheduling intra work for " << tws()->GetTracemeId();
  // Intra-op closures are always submitted as non-blocking work.
  constexpr bool kIsBlocking = false;
  auto* const thread_pool = pool_impl_->run_handler_thread_pool();
  thread_pool->AddWorkToQueue(tws(), kIsBlocking, std::move(fn));
}

RunHandlerPool


// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers
// that can be used for tracking inter-op work for a given Session::Run().
// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
// 'active' when its unique_ptr is returned by Get() and is being used by a
// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
//
// Expected usage:
//
// * Create a single RunHandlerPool (say run_handler_pool_).
//
// * When a Session::Run() is invoked, obtain a handler by:
// auto handler = run_handler_pool_->Get();
//
// * Use handler for scheduling all inter-op work by:
// handler->ScheduleInterOpClosure(closure);
//
// This class is thread safe.
class RunHandlerPool {
 public:
  // Creates a pool given only the inter-op thread count. How the intra-op
  // thread count is chosen in this case is not visible here — see the
  // implementation.
  explicit RunHandlerPool(int num_inter_op_threads);

  // Creates a pool with explicit inter-op and intra-op thread counts.
  RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads);
  ~RunHandlerPool();

  // Returns an inactive RunHandler from the pool.
  //
  // RunHandlers in RunHandlerPool are initially 'inactive'.
  // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
  // and is being used by a client.  It becomes 'inactive' once more when the
  // unique_ptr is destroyed.
  //
  // Will block unless there is an inactive handler.
  //
  // NOTE(review): the precise roles of `step_id` (presumably the id of the
  // Session::Run step being served), `timeout_in_ms` (presumably a bound on
  // the wait, 0 meaning unbounded) and `options` live in the implementation;
  // confirm there before relying on them.
  std::unique_ptr<RunHandler> Get(
      int64_t step_id = 0, int64_t timeout_in_ms = 0,
      const RunOptions::Experimental::RunHandlerPoolOptions& options =
          RunOptions::Experimental::RunHandlerPoolOptions());

  // Get the priorities for active handlers. The return result is with the same
  // order of the active handler list.
  std::vector<int64_t> GetActiveHandlerPrioritiesForTesting() const;

 private:
  // Implementation class; only forward-declared here (pimpl).
  class Impl;
  // RunHandler is allowed to access the pool's private members.
  friend class RunHandler;

  // Owned implementation object.
  std::unique_ptr<Impl> impl_;
};

  • RunHandler里有个Impl。
  • RunHandlerPool里也有个Impl。
  • 真正实现功能的是Impl。
  • Impl在头文件中只有前置声明(class Impl;),具体定义放在实现文件中,即pimpl惯用法。

RunHandlerThreadPool是线程池,真正执行任务


// Thread pool that actually executes the closures scheduled through
// RunHandlers. Its threads are split into "blocking" threads (which handle
// both inter-op and intra-op work) and "non-blocking" threads (intra-op work
// only); blocking threads may additionally be grouped into sub thread pools.
class RunHandlerThreadPool {
 public:
  // Per-thread bookkeeping record.
  struct PerThread {
    constexpr PerThread() : pool(nullptr), thread_id(-1) {}
    RunHandlerThreadPool* pool;  // Parent pool, or null for normal threads.
    int thread_id;               // Worker thread index in pool.
  };

  // `waiters_mu` and `queue_waiters` are stored as raw pointers; presumably
  // they are owned by the caller and must outlive this pool — TODO confirm.
  RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
                       Env* env, const ThreadOptions& thread_options,
                       const string& name,
                       Eigen::MaxSizeVector<mutex>* waiters_mu,
                       Eigen::MaxSizeVector<Waiter>* queue_waiters);

  ~RunHandlerThreadPool();

  // Creates the worker threads and starts them running WorkerLoop().
  void Start();

  // Test-only start variant; semantics are defined in the implementation.
  void StartOneThreadForTesting();

  // Enqueues `fn` on work source `tws`. RunHandler::Impl passes
  // is_blocking = true for inter-op work and false for intra-op work.
  void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
                      std::function<void()> fn);

  // Set work queues from which the thread 'tid' can steal its work.
  // The request with start_request_idx will be attempted first. Other requests
  // will be attempted in FIFO order based on their arrival time.
  void SetThreadWorkSources(
      int tid, int start_request_idx, uint64 version,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);

  // Presumably returns the PerThread record of the calling thread — confirm.
  PerThread* GetPerThread();

  // Presumably the pool-local index of the calling thread — confirm.
  int CurrentThreadId() const;

  int NumThreads() const;

  int NumBlockingThreads() const;

  int NumNonBlockingThreads() const;

  // Main loop of worker `thread_id`. Start() passes
  // may_steal_blocking_work = true for blocking threads, false otherwise.
  void WorkerLoop(int thread_id, bool may_steal_blocking_work);

  // Search tasks from Requests range searching_range_start to
  // searching_range_end. If there is no tasks in the search range and
  // may_steal_blocking_work is true, then search from all requests.
  // NOTE(review): `task_from_blocking_queue` and `tws` look like output
  // parameters describing where the returned task came from — confirm.
  Task FindTask(
      int searching_range_start, int searching_range_end, int thread_id,
      int sub_thread_pool_id, int max_blocking_inflight,
      bool may_steal_blocking_work,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
      bool* task_from_blocking_queue, ThreadWorkSource** tws);

  // Presumably parks the calling worker until new work may be available —
  // confirm against the implementation.
  void WaitForWork(bool is_blocking, int thread_id,
                   int32_t max_blocking_inflight);

  // Sub-thread-pool variant of WaitForWork (semantics in the implementation).
  void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);

 private:
  // State kept for each worker thread.
  struct ThreadData {
    ThreadData();
    // Guards new_thread_work_sources.
    mutex mu;
    // Presumably the version of the work-source list most recently published
    // via SetThreadWorkSources — confirm.
    uint64 new_version;
    condition_variable sources_not_empty;
    // The worker thread itself; created in Start() via env_.CreateThread().
    std::unique_ptr<Thread> thread;
    int current_index;
    // Work sources published for this thread; protected by `mu`.
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        new_thread_work_sources TF_GUARDED_BY(mu);

    uint64 current_version;
    // Should only be accessed by one thread.
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        current_thread_work_sources;

    // Index of the sub thread pool this thread belongs to; set in Start().
    int sub_thread_pool_id;
  };

  const int num_threads_;
  // Threads with index < num_blocking_threads_ are blocking threads.
  const int num_blocking_threads_;
  const int num_non_blocking_threads_;
  Eigen::MaxSizeVector<ThreadData> thread_data_;
  // Used by Start() to create the worker threads.
  internal::RunHandlerEnvironment env_;
  // Reset to false by Start().
  std::atomic<bool> cancelled_;
  // Prefix for worker thread names.
  string name_;
  Eigen::MaxSizeVector<mutex>* waiters_mu_;
  Eigen::MaxSizeVector<Waiter>* queue_waiters_;

  bool use_sub_thread_pool_;
  // Start() assigns thread i to the first sub pool j with
  // i < num_threads_in_sub_thread_pool_[j] (the last sub pool otherwise), so
  // the entries act as exclusive upper bounds on the thread index.
  std::vector<int> num_threads_in_sub_thread_pool_;

  // Threads in each sub thread pool will search tasks from the given
  // start_request_percentage to end_request_percentage in a round robin
  // fashion.
  std::vector<double> sub_thread_pool_start_request_percentage_;
  std::vector<double> sub_thread_pool_end_request_percentage_;
};

//还是使用env_去创建thread
void RunHandlerThreadPool::Start() {
  cancelled_ = false;
  int num_blocking_threads = num_blocking_threads_;
  for (int i = 0; i < num_threads_; i++) {
    int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
    for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
      if (i < num_threads_in_sub_thread_pool_[j]) {
        sub_thread_pool_id = j;
        break;
      }
    }
    thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
    const bool is_blocking_thread = (i < num_blocking_threads) ? true : false;
    // The blocking threads will handle both inter and intra op workload;
    // non-blocking thread will handle intra op workload only; and the
    // sub thread pool is only provided for blocking threads.
    // Name the threads accordingly.
    thread_data_[i].thread.reset(env_.CreateThread(
        [this, is_blocking_thread, i, sub_thread_pool_id]() {
          WorkerLoop(i, is_blocking_thread);
        },
        is_blocking_thread
            ? strings::StrCat(name_, "_blocking_thread_", sub_thread_pool_id)
            : strings::StrCat(name_, "_non_blocking_thread")));
  }
}

基本用法

tensorflow/core/framework/run_handler_test.cc

用法如同一个线程池:从pool中获取handler后,即可通过它提交inter-op和intra-op任务。


// Basic usage: a RunHandlerPool is used like a thread pool that closures are
// submitted to. num_handlers client threads each obtain a handler and schedule
// num_threads inter-op closures plus num_threads intra-op closures. All
// inter-op closures of handler 2 wait on a barrier sized num_threads, so the
// test only completes if that handler can run num_threads inter-op closures
// concurrently.
TEST(RunHandlerUtilTest, TestBasicScheduling) {
  int num_threads = 2;
  int num_handlers = 10;
  // Create the pool with num_threads inter-op and num_threads intra-op
  // threads.
  auto pool = std::make_unique<RunHandlerPool>(num_threads, num_threads);

  // RunHandler should always be able to run num_threads inter closures
  absl::Barrier barrier(num_threads);

  // Each handler schedules 2 * num_threads closures, and every closure
  // decrements this counter exactly once.
  BlockingCounter counter(2 * num_handlers * num_threads);

  thread::ThreadPool test_pool(Env::Default(), "test", num_handlers);
  for (int i = 0; i < num_handlers; ++i) {
    test_pool.Schedule([&counter, &barrier, &pool, i, num_threads]() {
      auto handler = pool->Get(i);
      // The closures below capture local_counter by reference; waiting on it
      // keeps it (and the active handler) alive until they have all run.
      BlockingCounter local_counter(2 * num_threads);
      auto intra_thread_pool = handler->AsIntraThreadPoolInterface();

      for (int j = 0; j < num_threads; ++j) {
        handler->ScheduleInterOpClosure(
            [&local_counter, &counter, &barrier, i]() {
              if (i == 2) {
                barrier.Block();
              }
              counter.DecrementCount();
              local_counter.DecrementCount();
            });
        intra_thread_pool->Schedule([&local_counter, &counter]() {
          counter.DecrementCount();
          local_counter.DecrementCount();
        });
      }
      local_counter.Wait();
    });
  }
  counter.Wait();
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值