mxnet代码解析之dependency engine

mxnet的执行引擎用于序列化一系列有依赖关系的运算,目前有三种引擎:NaiveEngine,ThreadedEnginePooled,ThreadedEnginePerDevice。这三种引擎都派生自基类Engine,第一种引擎没有在mxnet中真正使用,而后两种引擎并不直接继承于Engine,而是继承一个中间类ThreadedEngine。ThreadedEnginePooled引擎用一个全局线程池完成所有设备的一般运算,用另一个线程池完成所有copy运算;ThreadedEnginePerDevice引擎为每一个设备分配固定数量的线程,并用专门的线程完成copy运算。

在剖析engine之前,首先需要知道基本单元Var以及Op。

Var

Var就是用作virtual tag的基类,真正的实现在ThreadedVar 中:

// Concrete variable used by the threaded engines. A ThreadedVar keeps a
// linked list of pending operations (reads and writes) that depend on it,
// and is pooled via ObjectPoolAllocatable for fast allocation.
class ThreadedVar final : public Var,
                          public common::ObjectPoolAllocatable<ThreadedVar> {
 public:
  ......
  ......
  // Register opr_block as a pending read on this variable.
  inline void AppendReadDependency(OprBlock* opr_block);
  // Register opr_block as a pending write on this variable.
  inline void AppendWriteDependency(OprBlock* opr_block);
  // Called when one read on this variable finishes; may dispatch the next write.
  template <typename Dispatcher>
  inline void CompleteReadDependency(Dispatcher dispatcher);
  // Called when the active write finishes; returns true if the variable
  // was scheduled for deletion and has now been destroyed.
  template <typename Dispatcher>
  inline bool CompleteWriteDependency(Dispatcher dispatcher);
  ......
  ......
 private:
  // Guards all mutable state below.
  std::mutex m_;
  // Number of reads not yet completed; kWriteTriggered while a write runs.
  int num_pending_reads_{0};
  // Tail sentinel of the pending-operation list (always an empty node).
  VersionedVarBlock* head_{nullptr};
  // Oldest pending write in the list, or nullptr if none.
  VersionedVarBlock* pending_write_{nullptr};
  // Set when the variable should be destroyed after the pending write completes.
  bool to_delete_{false};
  // Sentinel value of num_pending_reads_ marking "a write is in flight".
  static constexpr int kWriteTriggered = -1;
  // A read may start immediately iff no write is pending.
  inline bool is_ready_to_read() const {
    return pending_write_ == nullptr;
  }
};  // class ThreadedVar

其中,VersionedVarBlock是ThreadedVar类中LinkedList的基本单元:

// One node of a ThreadedVar's pending-operation linked list.
struct VersionedVarBlock
    : public common::ObjectPoolAllocatable<VersionedVarBlock> {
  // Next node in the list (nullptr only for the tail sentinel).
  VersionedVarBlock* next{nullptr};
  // The operation that is waiting on this dependency slot.
  OprBlock* trigger{nullptr};
  // True if this node represents a write dependency, false for a read.
  bool write{false};
  DEFINE_ENGINE_DEBUG_INFO(VersionedVarBlock);
};  // struct VersionedVarBlock

trigger指向等待在这个VersionedVarBlock上的操作(OprBlock),即这个依赖记录是由哪个操作引起的。

ThreadedVar 中,num_pending_reads_ 表示还未执行的读依赖;head_指向队列的尾部,是一个哨兵(空对象);pending_write_指向最老的写依赖;

AppendReadDependency
// Register a pending read of this variable by opr_block. If no write is
// queued, the read is immediately marked ready (its wait counter drops);
// otherwise it is appended to the dependency list behind the write.
inline void ThreadedVar::AppendReadDependency(OprBlock* opr_block) {
  std::lock_guard<std::mutex> lock{m_};
  if (pending_write_ != nullptr) {
    // A write is queued ahead of this read: append it to the tail of the
    // dependency list behind the sentinel node.
    auto&& tail_sentinel = VersionedVarBlock::New();
    assert(head_->next == nullptr);
    assert(head_->trigger == nullptr);
    assert(head_->write == false);
    head_->next = tail_sentinel;
    head_->trigger = opr_block;
    head_ = tail_sentinel;
    return;
  }
  // invariant: is_ready_to_read()
  CHECK_GE(num_pending_reads_, 0);
  // STATE CHANGE: one more read is in flight.
  ++num_pending_reads_;
  // This var is ready; the op may still be waiting on other vars.
  opr_block->decr_wait();
}
AppendWriteDependency
// Register a pending write of this variable by opr_block. The write is
// always appended to the dependency list; if it is the only pending write
// and no reads are in flight, it is triggered immediately.
inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) {
    auto&& new_var_block = VersionedVarBlock::New();
    std::lock_guard<std::mutex> lock{m_};
    // Append the write dependency at the tail of the list.
    // invariant.
    assert(head_->next == nullptr);
    assert(head_->trigger == nullptr);
    assert(head_->write == false);
    // attach to head.
    head_->next = new_var_block;
    head_->trigger = opr_block;
    head_->write = true;
    // If there was no pending write before this one...
    if (pending_write_ == nullptr) {
        pending_write_ = head_;
        CHECK_GE(num_pending_reads_, 0);
        // ...and no reads are in flight, the write can start right away.
        if (num_pending_reads_ == 0) {
            // STATE CHANGE
            opr_block->decr_wait();
            num_pending_reads_ = kWriteTriggered;
        }
    } else {
        // An earlier write exists, so reads must be pending or a write active.
        CHECK_NE(num_pending_reads_, 0);
    }
    head_ = new_var_block;
}
CompleteReadDependency
template <typename Dispatcher>
inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) {
  OprBlock *trigger = nullptr;
  {
    std::lock_guard<std::mutex> lock{m_};
    CHECK_GT(num_pending_reads_, 0);
    //如果所有读依赖全部执行完成
    if (--num_pending_reads_ == 0) {
      //如果队列中还有写依赖
      if (pending_write_ != nullptr) {
        // STATE CHANGE
        trigger = pending_write_->trigger;
        num_pending_reads_ = kWriteTriggered;//表示进入写状态
      }
    }
  }
  if (trigger != nullptr && trigger->decr_wait() == 0) {
    dispatcher(trigger);//执行真正的操作,一般是PushToExecute
  }
}
CompleteWriteDependency
// Mark the active write of this variable as finished. Detaches the finished
// write node plus the run of reads that follow it, dispatches those reads
// (and, if there are none, the next write) outside the lock, and returns
// true iff the variable was marked to_delete_ and has been destroyed.
template <typename Dispatcher>
inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
  // this is lock scope
  VersionedVarBlock *old_pending_write, *end_of_read_chain;
  OprBlock* trigger_write = nullptr;
  {
    std::lock_guard<std::mutex> lock{m_};
    // invariants
    assert(head_->next == nullptr);
    assert(pending_write_ != nullptr);
    CHECK_EQ(num_pending_reads_, kWriteTriggered);

    // really delete
    if (to_delete_) {
      VersionedVarBlock *head = pending_write_->next;
      VersionedVarBlock::Delete(pending_write_);
      assert(head_ == head);
      VersionedVarBlock::Delete(head);
      return true;
    }
    // detach pending write
    old_pending_write = pending_write_;
    // search for chains to trigger
    end_of_read_chain = old_pending_write->next;
    // reset to 0 pending reads
    num_pending_reads_ = 0;
    // The finished write may be followed by several reads: walk the list
    // until the next write (or the tail sentinel head_) is found, counting
    // the reads on the way.
    while (end_of_read_chain != head_ &&
           end_of_read_chain->write == false) {
      ++num_pending_reads_;
      end_of_read_chain = end_of_read_chain->next;
    }  // BUGFIX: this closing brace was missing in the transcription; without
       // it the if/else below was nested inside the loop and the function's
       // braces did not balance.
    // No further write queued on this variable.
    if (end_of_read_chain == head_) {
      pending_write_ = nullptr;
    } else {
      // check if there is pending reads, if not trigger write
      assert(end_of_read_chain->write == true);
      pending_write_ = end_of_read_chain;
      if (num_pending_reads_ == 0) {
        // mark write as already activated in this var
        num_pending_reads_ = kWriteTriggered;
        trigger_write = end_of_read_chain->trigger;
      }
    }
  }
  // Out of the lock scope from here on:
  // pending_write_ and num_pending_reads_ must not be modified any more.
  // The sub-list [old_pending_write, end_of_read_chain) has been detached
  // from this Var; old_pending_write is the write that just completed.
  VersionedVarBlock *cur_head = old_pending_write->next;
  VersionedVarBlock::Delete(old_pending_write);

  // dispatch all the events
  // The reads between the two pointers can proceed in parallel.
  while (cur_head != end_of_read_chain) {
    if (cur_head->trigger->decr_wait() == 0) {
      dispatcher(cur_head->trigger);
    }
    auto prev = cur_head;
    cur_head = cur_head->next;
    assert(cur_head != nullptr);
    VersionedVarBlock::Delete(prev);
  }
  if (trigger_write != nullptr && trigger_write->decr_wait() == 0) {
    dispatcher(trigger_write);
  }
  return false;
}

Op

Opr是引擎operator的基类,ThreadEngine中真正的operator是ThreadedOpr,它表示某一个操作依赖的变量、函数及函数属性 :

// Concrete operator used by the threaded engines: the async function to run
// plus the variables it reads (const_vars) and mutates (mutable_vars).
struct ThreadedOpr final : public Opr,
                           public common::ObjectPoolAllocatable<ThreadedOpr> {
  // The asynchronous function performing the real computation.
  Engine::AsyncFn fn;
  // Variables this operator only reads.
  std::vector<ThreadedVar*> const_vars;
  // Variables this operator writes.
  std::vector<ThreadedVar*> mutable_vars;
  // Scheduling property of fn (normal, copy, prioritized, async, ...).
  FnProperty prop;
  // Whether this is a temporary operator that is deleted right after
  // the operation completes.
  bool temporary{false};
  // Downcast a base Opr pointer to a ThreadedOpr pointer.
  inline static ThreadedOpr* CastFromBase(Opr* ptr) {
    return ptr->Cast<ThreadedOpr>();
  }
  // define possible debug information
  DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
};  // struct ThreadedOpr

而向添加读写依赖的函数传递的是OprBlock对象,它统计了当前op还未就绪的依赖个数wait,其初始值为len(const_vars)+len(mutable_vars)+1(多出的1由push线程在全部依赖注册完成后减去,见下文Push函数),wait减到0才开始进行真正的计算:

// One scheduled execution of a ThreadedOpr: tracks how many dependencies
// (vars plus the pusher itself) are still outstanding before it can run.
struct OprBlock : public common::ObjectPoolAllocatable<OprBlock> {
  /*!
   * \brief wait number of pending tasks this OprBlock is waiting for.
   *  Initialized by the pusher to #const_vars + #mutable_vars + 1.
   */
  std::atomic<int> wait{0};
  /*! \brief Pointer to information on performing real operation */
  ThreadedOpr* opr{nullptr};
  /*! \brief The context this operator */
  Context ctx;
  /*! \brief priority of the function */
  int priority{0};  // BUGFIX: was uninitialized; brace-init to a sane default
  // define possible debug information
  DEFINE_ENGINE_DEBUG_INFO(OprBlock);
  /*!
   * \brief call this function to decrease the wait counter.
   * \return the wait counter after the decrement.
   */
  inline int decr_wait() {
    // check invariant, avoid over-triggering
    int ret = --wait;
    CHECK_GE(ret, 0);
    return ret;
  }
};  // struct OprBlock

无论是var还是opr,内存分配的操作都交给了ObjectPoolAllocatable类,ObjectPoolAllocatable是ObjectPool类的简化调用,而ObjectPool真正实现了快速的内存分配和释放,它使用placement new在已分配的内存上重复构建对象,从而避免了频繁向系统申请和释放内存所带来的开销与内存碎片。

engine

以下是engine类的全貌,需要调用engine的模块,如graph_executor、sgd、ndarray等,调用Get()来使用,在调用之前必须由对应的create函数创建。

// Abstract dependency engine interface. Callers (graph_executor, sgd,
// ndarray, ...) obtain the singleton via Get() and push operations that
// declare which variables they read (const) and mutate.
class MXNET_API Engine {
 public:
  // Notify the engine that the process is shutting down.
  virtual void NotifyShutdown() = 0;
  // Allocate a new dependency-tracking variable.
  virtual VarHandle NewVariable() = 0;
  // Create a reusable operator from an async function and its dependencies.
  virtual OprHandle NewOperator(AsyncFn fn,
                                std::vector<VarHandle> const& const_vars,
                                std::vector<VarHandle> const& mutable_vars,
                                FnProperty prop = FnProperty::kNormal) = 0;
  // Delete an operator created by NewOperator.
  virtual void DeleteOperator(OprHandle op) = 0;
  // Schedule a previously created operator on exec_ctx.
  virtual void Push(OprHandle op, Context exec_ctx, int priority = 0) = 0;
  // Schedule a one-shot async function with the given dependencies.
  virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
                         std::vector<VarHandle> const& const_vars,
                         std::vector<VarHandle> const& mutable_vars,
                         FnProperty prop = FnProperty::kNormal,
                         int priority = 0) = 0;
  // Schedule deletion of a variable once all its pending ops finish.
  virtual void DeleteVariable(SyncFn delete_fn,
                              Context exec_ctx,
                              VarHandle var) = 0;
  // Block until all operations touching var have completed.
  virtual void WaitForVar(VarHandle var) = 0;
  // Block until every pushed operation has completed.
  virtual void WaitForAll() = 0;
  virtual ~Engine() noexcept(false) {}
  // Singleton accessors; the engine must be created before first use.
  static Engine* Get();
  static std::shared_ptr<Engine> _GetSharedRef();
  // Convenience wrapper: adapt a synchronous function to PushAsync by
  // invoking the completion callback right after it returns.
  template<typename SyncFn>
  inline void PushSync(SyncFn exec_fn, Context exec_ctx,
                       std::vector<VarHandle> const& const_vars,
                       std::vector<VarHandle> const& mutable_vars,
                       FnProperty prop = FnProperty::kNormal,
                       int priority = 0) {
    this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
        exec_fn(ctx);
        on_complete();
      }, exec_ctx, const_vars, mutable_vars, prop, priority);
  }

 protected:
  // Build a completion callback bound to this engine and an opaque param.
  inline CallbackOnComplete CreateCallback(
      void (*callback)(Engine *, void *), void *param) {
    CallbackOnComplete ret;
    ret.callback_ = callback;
    ret.engine_ = this;
    ret.param_ = param;
    return ret;
  }
};  // class Engine

Engine类中的FnProperty表示push进引擎的operator的函数属性,创建operator、push操作都需要这个属性,如下:

// Scheduling property of a function pushed into the engine.
enum class FnProperty {
  kNormal,          // ordinary computation
  kCopyFromGPU,     // device-to-host copy
  kCopyToGPU,       // host-to-device copy
  kCPUPrioritized,  // prioritized CPU task
  kAsync            // asynchronous function (completes via callback later)
}; 

引擎最主要的执行操作就是push及其一系列的拓展,以ThreadedEngine的push函数为例:

// Schedule an operator: wrap it in an OprBlock, register it on every
// variable it depends on, and hand it to PushToExecute once ready.
void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority) {
  ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
  OprBlock* opr_block = OprBlock::New();
  opr_block->opr = threaded_opr;

  // Wait count = all var dependencies + 1; the extra 1 is held by this
  // pusher thread and released below, so the op cannot fire before every
  // dependency has been registered.
  opr_block->wait.store(static_cast<int>(
      threaded_opr->const_vars.size() +
      threaded_opr->mutable_vars.size() + 1));
  opr_block->ctx = exec_ctx;
  opr_block->priority = priority;
  ++pending_;
  // Add read dependencies.
  for (auto&& i : threaded_opr->const_vars) {
    i->AppendReadDependency(opr_block);
  }
  // Add write dependencies.
  for (auto&& i : threaded_opr->mutable_vars) {
    i->AppendWriteDependency(opr_block);
  }
  // Release the pusher's token; if all deps were already satisfied,
  // execute right away (pusher_thread = true).
  if (opr_block->decr_wait() == 0) {
    this->PushToExecute(opr_block, true);
  }
}

其中opr_block即为需要压入的op,将op添加到它所依赖的Var的队列中,包括读队列和写队列,最后由不同的引擎执行不同的PushToExecute函数,不同的PushToExecute函数都会调用ThreadedEngine的ExecuteOprBlock函数。

ThreadedEngine继承了Engine,并定义了Threaded类引擎需要的变量和其子类的通用操作,Engine没有定义的部分大致如下:

// Common base of the threaded engines (Pooled / PerDevice): owns the shared
// bookkeeping state and the per-operation execution path.
class ThreadedEngine : public Engine {
 protected:
  // How an OprBlock is scheduled onto worker threads is engine-specific;
  // each subclass supplies its own implementation.
  virtual void PushToExecute(OprBlock* opr_block, bool pusher_thread) = 0;
  // Run the operator held by opr_block on the calling worker thread; the
  // completion callback (OnCompleteStatic) updates the variable queues.
  void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
    ThreadedOpr* threaded_opr = opr_block->opr;
    CallbackOnComplete callback = this->CreateCallback(
        ThreadedEngine::OnCompleteStatic, threaded_opr);
    if (!shutdown_phase_) {
      try {
        threaded_opr->fn(run_ctx, callback);
      } catch(dmlc::Error &e) {
            ......
      }
    } else {
      // During shutdown the function is skipped, but the callback still
      // fires so dependency bookkeeping stays consistent.
      callback();
    }

    OprBlock::Delete(opr_block);
  }

 private:
  // Sanity check: a var must not appear in both const_vars and mutable_vars.
  void CheckDuplicate(std::vector<VarHandle> const& const_vars,
                      std::vector<VarHandle> const& mutable_vars);

  inline void OnComplete(ThreadedOpr* threaded_opr);
  static void OnCompleteStatic(Engine *engine, void *threaded_opr);
  /*!
   * \brief Number of pending operations.
   */
  std::atomic<int> pending_{0};
  /*! \brief whether we want to kill the waiters */
  std::atomic<bool> kill_{false};
  /*! \brief whether it is during shutdown phase*/
  std::atomic<bool> shutdown_phase_{false};
  /*!\brief show more information from engine actions */
  bool engine_info_{false};
  /*!
   * \brief Mutex and condition_variable,
   *  used to Notify waits for single or all variables.
   */
  std::mutex finished_m_;
  std::condition_variable finished_cv_;

  /*!
   * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
   * See also #309 (https://github.com/dmlc/mxnet/issues/309)
   */
  std::shared_ptr<common::ObjectPool<ThreadedOpr> >       objpool_opr_ref_;
  std::shared_ptr<common::ObjectPool<OprBlock> >          objpool_blk_ref_;
  std::shared_ptr<common::ObjectPool<VersionedVarBlock> > objpool_varblk_ref_;
  std::shared_ptr<common::ObjectPool<ThreadedVar> >       objpool_var_ref_;
  /*!
   * \brief Disallow copy construction and assignment.
   */
  DISALLOW_COPY_AND_ASSIGN(ThreadedEngine);
};  // class ThreadedEngine  -- BUGFIX: the class definition was missing its ';'

ThreadedEngine的OnCompleteStatic调用OnComplete:

// Static trampoline for CallbackOnComplete: recovers the typed engine and
// operator from the void*-based callback signature and forwards to OnComplete.
void ThreadedEngine::OnCompleteStatic(
    Engine *engine, void *threaded_opr) {
  auto* self = static_cast<ThreadedEngine*>(engine);
  auto* opr = static_cast<ThreadedOpr*>(threaded_opr);
  self->OnComplete(opr);
}

OnComplete是在op操作完成后更新Var队列:

// Called after an operator finishes: releases its read and write
// dependencies (possibly dispatching newly-ready ops), updates the pending
// counter, and wakes threads blocked in WaitForVar/WaitForAll.
inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
  // Mark complete for read variables
  for (auto&& i : threaded_opr->const_vars) {
    i->CompleteReadDependency([this](OprBlock* opr) {
        this->PushToExecute(opr, false);
      });
  }
  // Mark complete for write variables.
  for (auto&& i : threaded_opr->mutable_vars) {
    bool to_delete = i->CompleteWriteDependency(
        [this](OprBlock* opr) {
          this->PushToExecute(opr, false);
        });
    if (to_delete) {
      ThreadedVar::Delete(i);
    }
  }
  int npending;
  {
    std::unique_lock<std::mutex> lock{finished_m_};
    npending = --pending_;
  }
  CHECK_GE(npending, 0);
  if (npending == 0) {
    // no need to grab lock when notify.
    finished_cv_.notify_all();
  }
  // delete operator if it is temporary
  if (threaded_opr->temporary) {
    ThreadedOpr::Delete(threaded_opr);
  }
}

可以看出,OnComplete函数向CompleteReadDependency 和 CompleteWriteDependency传递了采用lambda表达式表示的匿名函数(dispatcher),dispatcher主要是调用PushToExecute。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值