mxnet的执行引擎(engine)根据操作之间的读写依赖关系来调度一系列函数的执行。目前有三种引擎:NaiveEngine、ThreadedEnginePooled 和 ThreadedEnginePerDevice。这三种引擎都派生自基类 Engine。其中第一种引擎在 mxnet 中并不真正用于实际运行(一般仅用于调试);后两种引擎并不直接继承 Engine,而是继承中间类 ThreadedEngine。ThreadedEnginePooled 引擎用一个全局线程池完成所有设备上的一般运算,另用一个线程池完成所有 copy 运算;ThreadedEnginePerDevice 引擎为每个设备分配固定数量的线程,并用专门的线程完成 copy 运算。
在剖析engine之前,首先需要知道基本单元Var以及Op。
Var
Var就是用作virtual tag的基类,真正的实现在ThreadedVar 中:
/*!
 * \brief Concrete implementation of Var used by ThreadedEngine.
 *
 * Each ThreadedVar keeps a singly linked queue of VersionedVarBlock nodes
 * recording the operations waiting to read or write it, plus a counter of
 * reads that have been granted but not yet completed. All mutable state
 * below is guarded by m_. Instances come from an object pool
 * (ObjectPoolAllocatable) and are created/destroyed via New()/Delete().
 */
class ThreadedVar final : public Var,
public common::ObjectPoolAllocatable<ThreadedVar> {
public:
......
......
// Register opr_block as a reader of this variable.
inline void AppendReadDependency(OprBlock* opr_block);
// Register opr_block as a writer of this variable.
inline void AppendWriteDependency(OprBlock* opr_block);
// Called when a granted read finishes; may trigger the pending write and
// hand its OprBlock to dispatcher.
template <typename Dispatcher>
inline void CompleteReadDependency(Dispatcher dispatcher);
// Called when the triggered write finishes; hands the operations it
// releases to dispatcher. Returns true when to_delete_ was set and the
// remaining blocks were freed, in which case the caller deletes the var.
template <typename Dispatcher>
inline bool CompleteWriteDependency(Dispatcher dispatcher);
......
......
private:
// Guards every field below.
std::mutex m_;
// Number of reads granted but not yet completed, or kWriteTriggered while
// the pending write has been triggered.
int num_pending_reads_{0};
// Despite its name, the TAIL of the dependency queue: an empty sentinel
// block into which the next dependency is written before a fresh sentinel
// is appended.
VersionedVarBlock* head_{nullptr};
// Oldest write dependency that has not completed yet; nullptr if none.
VersionedVarBlock* pending_write_{nullptr};
// Marks the variable for deletion; checked in CompleteWriteDependency
// (it is set elsewhere, not shown in this excerpt).
bool to_delete_{false};
// Value of num_pending_reads_ meaning "the pending write has been triggered".
static constexpr int kWriteTriggered = -1;
// A read can be granted immediately only when no write is queued.
inline bool is_ready_to_read() const {
return pending_write_ == nullptr;
}
}; // class ThreadedVar
其中,VersionedVarBlock 是 ThreadedVar 中依赖链表(LinkedList)的节点类型:
/*!
 * \brief Node of the per-variable dependency queue kept by ThreadedVar.
 *
 * Allocated from an object pool via New()/Delete().
 */
struct VersionedVarBlock
: public common::ObjectPoolAllocatable<VersionedVarBlock> {
// Next node in the queue; nullptr for the tail sentinel.
VersionedVarBlock* next{nullptr};
// Operation that registered this dependency; nullptr for the tail sentinel.
OprBlock* trigger{nullptr};
// true for a write dependency, false for a read dependency.
bool write{false};
DEFINE_ENGINE_DEBUG_INFO(VersionedVarBlock);
}; // struct VersionedVarBlock
trigger 指向登记该依赖的操作(OprBlock),即这个 VersionedVarBlock 是由哪个操作引起的;write 表示它是写依赖(true)还是读依赖(false)。
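ThreadedVar 中,num_pending_reads_ 表示已被放行但尚未执行完成的读依赖个数,当它等于 kWriteTriggered(-1)时,表示当前的写依赖已被触发;head_ 虽然名为 head,实际指向队列的尾部,是一个哨兵(空节点),新的依赖总是写入这个哨兵,然后再追加一个新的哨兵;pending_write_ 指向尚未完成的最早的写依赖,为 nullptr 时表示没有待处理的写依赖。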
AppendReadDependency
/*!
 * \brief Register opr_block as a reader of this variable.
 *
 * If no write is queued, the read is granted at once: it is counted in
 * num_pending_reads_ and opr_block's wait counter is decremented (the
 * operation may still be waiting on other variables). Otherwise the read is
 * stored in the tail sentinel and a fresh sentinel is appended, so that it is
 * released when the preceding write completes.
 */
inline void ThreadedVar::AppendReadDependency(OprBlock* opr_block) {
  std::lock_guard<std::mutex> guard{m_};
  if (pending_write_ != nullptr) {
    // A write is queued ahead of this read: turn the current sentinel into a
    // read node and append a new empty sentinel behind it.
    auto&& sentinel = VersionedVarBlock::New();
    assert(head_->next == nullptr);
    assert(head_->trigger == nullptr);
    assert(head_->write == false);
    VersionedVarBlock* const read_node = head_;
    read_node->next = sentinel;
    read_node->trigger = opr_block;
    head_ = sentinel;
    return;
  }
  // invariant: is_ready_to_read()
  CHECK_GE(num_pending_reads_, 0);
  // STATE CHANGE: grant the read immediately.
  ++num_pending_reads_;
  // This operation may still be waiting on other variables.
  opr_block->decr_wait();
}
AppendWriteDependency
/*!
 * \brief Register opr_block as a writer of this variable.
 *
 * The write is always recorded in the current tail sentinel, after which a
 * new sentinel becomes the tail. If no other write is pending, this write
 * becomes pending_write_; if in addition no reads are in flight, it is
 * granted at once (opr_block's wait counter is decremented and
 * num_pending_reads_ is set to kWriteTriggered).
 */
inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) {
// The new sentinel is allocated outside the critical section.
auto&& new_var_block = VersionedVarBlock::New();
std::lock_guard<std::mutex> lock{m_};
// Append the write dependency at the tail of the queue.
// invariant: head_ is an empty sentinel.
assert(head_->next == nullptr);
assert(head_->trigger == nullptr);
assert(head_->write == false);
// attach to head (the tail sentinel), turning it into a write node.
head_->next = new_var_block;
head_->trigger = opr_block;
head_->write = true;
// No earlier write is pending: this write becomes the oldest pending write.
if (pending_write_ == nullptr) {
pending_write_ = head_;
CHECK_GE(num_pending_reads_, 0);
// No reads in flight either: trigger the write now.
if (num_pending_reads_ == 0) {
// STATE CHANGE
opr_block->decr_wait();
num_pending_reads_ = kWriteTriggered;
}
} else {
// An earlier write is pending; the counter is asserted non-zero here
// (reads still in flight, or that write already triggered).
CHECK_NE(num_pending_reads_, 0);
}
// Advance the tail to the new empty sentinel.
head_ = new_var_block;
}
CompleteReadDependency
template <typename Dispatcher>
inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) {
OprBlock *trigger = nullptr;
{
std::lock_guard<std::mutex> lock{m_};
CHECK_GT(num_pending_reads_, 0);
//如果所有读依赖全部执行完成
if (--num_pending_reads_ == 0) {
//如果队列中还有写依赖
if (pending_write_ != nullptr) {
// STATE CHANGE
trigger = pending_write_->trigger;
num_pending_reads_ = kWriteTriggered;//表示进入写状态
}
}
}
if (trigger != nullptr && trigger->decr_wait() == 0) {
dispatcher(trigger);//执行真正的操作,一般是PushToExecute
}
}
CompleteWriteDependency
/*!
 * \brief Mark the currently triggered write of this variable as finished.
 *
 * Under the lock:
 *  - if the variable is marked for deletion (to_delete_), free the finished
 *    write block and the tail sentinel and return true;
 *  - otherwise detach the finished write, count the run of reads that
 *    follows it up to the next write (or the tail sentinel), and update
 *    pending_write_ / num_pending_reads_. If the next write has no reads
 *    ahead of it, it is triggered immediately.
 * Outside the lock, the detached blocks are freed and every released
 * operation whose wait counter reaches 0 is handed to dispatcher.
 *
 * \param dispatcher callable invoked with an OprBlock* that is ready to run.
 * \return true if the variable should be deleted by the caller.
 */
template <typename Dispatcher>
inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
  VersionedVarBlock *old_pending_write, *end_of_read_chain;
  OprBlock* trigger_write = nullptr;
  {
    // this is lock scope
    std::lock_guard<std::mutex> lock{m_};
    // invariants
    assert(head_->next == nullptr);
    assert(pending_write_ != nullptr);
    CHECK_EQ(num_pending_reads_, kWriteTriggered);
    // really delete
    if (to_delete_) {
      VersionedVarBlock *head = pending_write_->next;
      VersionedVarBlock::Delete(pending_write_);
      assert(head_ == head);
      VersionedVarBlock::Delete(head);
      return true;
    }
    // detach pending write
    old_pending_write = pending_write_;
    // search for chains to trigger
    end_of_read_chain = old_pending_write->next;
    // reset to 0 pending reads
    num_pending_reads_ = 0;
    // The finished write may be followed by several reads: walk the list
    // until the next write (or the tail sentinel), counting those reads.
    while (end_of_read_chain != head_ &&
           end_of_read_chain->write == false) {
      ++num_pending_reads_;
      end_of_read_chain = end_of_read_chain->next;
    }
    // Decide the new state only after the whole read run has been scanned.
    // (Doing this inside the loop leaves pending_write_ pointing at the block
    // deleted below when no read follows, and never triggers a write that
    // directly follows the finished one.)
    if (end_of_read_chain == head_) {
      // No further write is queued.
      pending_write_ = nullptr;
    } else {
      // check if there is pending reads, if not trigger write
      assert(end_of_read_chain->write == true);
      pending_write_ = end_of_read_chain;
      if (num_pending_reads_ == 0) {
        // mark write as already activated in this var
        num_pending_reads_ = kWriteTriggered;
        trigger_write = end_of_read_chain->trigger;
      }
    }
  }  // end of lock scope
  // From here on, pending_write_ and num_pending_reads_ must not be touched.
  // The list segment [old_pending_write, end_of_read_chain) has been detached
  // from this variable and is exclusively owned by this call.
  VersionedVarBlock *cur_head = old_pending_write->next;
  VersionedVarBlock::Delete(old_pending_write);
  // Release every read in the detached run; they may execute concurrently.
  while (cur_head != end_of_read_chain) {
    if (cur_head->trigger->decr_wait() == 0) {
      dispatcher(cur_head->trigger);
    }
    auto prev = cur_head;
    cur_head = cur_head->next;
    assert(cur_head != nullptr);
    VersionedVarBlock::Delete(prev);
  }
  if (trigger_write != nullptr && trigger_write->decr_wait() == 0) {
    dispatcher(trigger_write);
  }
  return false;
}
Op
Opr 是引擎中 operator 的基类,ThreadedEngine 中真正的 operator 是 ThreadedOpr,它记录了某个操作要执行的函数、所依赖的变量(只读变量和可写变量)以及函数属性:
/*!
 * \brief Operator representation used by ThreadedEngine.
 *
 * Bundles the function to run, the variables it reads (const_vars) and
 * writes (mutable_vars), and its FnProperty. Allocated from an object pool.
 */
struct ThreadedOpr final : public Opr,
public common::ObjectPoolAllocatable<ThreadedOpr> {
// Asynchronous function executed for this operator.
Engine::AsyncFn fn;
// Variables registered as read dependencies.
std::vector<ThreadedVar*> const_vars;
// Variables registered as write dependencies.
std::vector<ThreadedVar*> mutable_vars;
// Property of the function.
FnProperty prop;
// Whether this is a temporary operator, deleted once the operation completes.
bool temporary{false};
// Downcast a base-class Opr pointer to ThreadedOpr.
inline static ThreadedOpr* CastFromBase(Opr* ptr) {
return ptr->Cast<ThreadedOpr>();
}
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
}; // struct ThreadedOpr
而添加读写依赖的函数接收的是 OprBlock 对象,它通过 wait 计数记录当前 op 还在等待的依赖个数。Push 中将该计数的初始值设为 len(const_vars)+len(mutable_vars)+1(多出的 1 保证在依赖尚未全部登记完之前该 op 不会被提前执行),只有当计数减到 0 时才开始进行真正的计算:
/*!
 * \brief Per-push execution record passed to the dependency functions.
 *
 * Tracks how many dependencies the operation is still waiting for; the
 * operation is dispatched by whoever brings the counter to 0 via decr_wait().
 */
struct OprBlock : public common::ObjectPoolAllocatable<OprBlock> {
/*!
 * \brief wait number of pending tasks this OprBlock is waiting for.
 *        Atomic because it may be decremented from several threads.
 */
std::atomic<int> wait{0};
/*! \brief Pointer to information on performing real operation */
ThreadedOpr* opr{nullptr};
/*! \brief The context this operator runs in */
Context ctx;
/*! \brief priority of the function (not default-initialized; set by Push) */
int priority;
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(OprBlock);
/*!
 * \brief call this function to decrease the wait counter.
 * \return the wait counter after the decrement.
 */
inline int decr_wait() {
// check invariant, avoid over trigger
int ret = --wait;
CHECK_GE(ret, 0);
return ret;
}
}; // struct OprBlock
无论是 var 还是 opr,内存分配都交给了 ObjectPoolAllocatable 类。ObjectPoolAllocatable 是对 ObjectPool 的简化封装,而 ObjectPool 真正实现了快速的内存分配和释放:它使用 placement new 在已分配的内存上重复构建对象,从而避免了频繁向系统申请和释放内存所带来的开销和内存碎片(对象本身仍会被正常构造和析构)。
engine
以下是 Engine 类的主要接口。需要使用引擎的模块(如 graph_executor、sgd、ndarray 等)通过 Get() 获取引擎实例;在此之前,引擎必须已由对应的 create 函数创建。
/*!
 * \brief Abstract dependency-engine interface.
 *
 * Concrete engines (e.g. ThreadedEngine subclasses) implement the pure
 * virtual methods; callers obtain the global instance through Get().
 */
class MXNET_API Engine {
public:
// Notify the engine that shutdown is starting (exact semantics are
// implementation-defined; not shown here).
virtual void NotifyShutdown() = 0;
// Create a new variable used to express dependencies between operations.
virtual VarHandle NewVariable() = 0;
// Create an operator that runs fn, reading const_vars and writing
// mutable_vars; prop defaults to FnProperty::kNormal.
virtual OprHandle NewOperator(AsyncFn fn,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal) = 0;
// Delete an operator created by NewOperator.
virtual void DeleteOperator(OprHandle op) = 0;
// Push a previously created operator for execution on exec_ctx.
virtual void Push(OprHandle op, Context exec_ctx, int priority = 0) = 0;
// Push an asynchronous function together with its read/write dependencies.
// The function receives a CallbackOnComplete it must invoke when done.
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0) = 0;
// Schedule deletion of var; presumably delete_fn runs on exec_ctx once
// pending operations on var finish (implementation not shown).
virtual void DeleteVariable(SyncFn delete_fn,
Context exec_ctx,
VarHandle var) = 0;
// Presumably blocks until pending operations on var have completed.
virtual void WaitForVar(VarHandle var) = 0;
// Presumably blocks until all pushed operations have completed.
virtual void WaitForAll() = 0;
virtual ~Engine() noexcept(false) {}
// Return the global engine instance.
static Engine* Get();
// Return a shared reference to the global engine.
static std::shared_ptr<Engine> _GetSharedRef();
// Convenience wrapper: adapt a synchronous function to PushAsync by
// invoking the completion callback right after exec_fn returns.
template<typename SyncFn>
inline void PushSync(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0) {
this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
exec_fn(ctx);
on_complete();
}, exec_ctx, const_vars, mutable_vars, prop, priority);
}
protected:
// Build a CallbackOnComplete holding callback, this engine and param.
inline CallbackOnComplete CreateCallback(
void (*callback)(Engine *, void *), void *param) {
CallbackOnComplete ret;
ret.callback_ = callback;
ret.engine_ = this;
ret.param_ = param;
return ret;
}
}; // class Engine
Engine 接口中用到的 FnProperty 表示 push 进引擎的 operator 的函数属性。创建 operator 和 push 操作时都会用到这个属性(默认值为 kNormal),其定义如下:
/*!
 * \brief Property of a function pushed into the engine.
 *
 * Accepted by NewOperator / PushAsync / PushSync, where kNormal is the
 * default. The explicit values only spell out the implicit numbering.
 */
enum class FnProperty {
  kNormal = 0,          // default property
  kCopyFromGPU = 1,     // presumably a copy out of GPU memory
  kCopyToGPU = 2,       // presumably a copy into GPU memory
  kCPUPrioritized = 3,  // presumably a CPU op scheduled with priority
  kAsync = 4            // presumably an asynchronous function
};
引擎最主要的执行操作就是 push 及其一系列扩展,以 ThreadedEngine 的 Push 函数为例:
/*!
 * \brief Push a previously created operator for execution on exec_ctx.
 *
 * Creates an OprBlock whose wait counter starts at
 * |const_vars| + |mutable_vars| + 1, registers it as a reader/writer of each
 * variable, then removes the extra 1. The extra count guarantees that the
 * counter cannot reach 0 inside Append*Dependency (which ignore the result
 * of decr_wait()); the operation is dispatched either here, by the final
 * decrement, or later by a Complete*Dependency dispatcher.
 */
void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority) {
  ThreadedOpr* const opr = ThreadedOpr::CastFromBase(op);
  const auto num_deps = opr->const_vars.size() + opr->mutable_vars.size();
  OprBlock* const block = OprBlock::New();
  block->opr = opr;
  block->wait.store(static_cast<int>(num_deps + 1));
  block->ctx = exec_ctx;
  block->priority = priority;
  ++pending_;
  // Read dependencies first, then write dependencies.
  for (auto&& var : opr->const_vars) {
    var->AppendReadDependency(block);
  }
  for (auto&& var : opr->mutable_vars) {
    var->AppendWriteDependency(block);
  }
  // Drop the guard count; if nothing else is outstanding, run it now.
  if (block->decr_wait() == 0) {
    this->PushToExecute(block, true);
  }
}
其中 opr_block 即为需要压入的 op:Push 把它以读依赖或写依赖的形式追加到它所依赖的每个 Var 的依赖队列中,最后调用 decr_wait() 减去初始多加的 1;若计数变为 0,说明所有依赖都已满足,于是调用 PushToExecute。不同的引擎实现了不同的 PushToExecute 函数,但它们最终都会调用 ThreadedEngine 的 ExecuteOprBlock 函数。
ThreadedEngine继承了Engine,并定义了Threaded类引擎需要的变量和其子类的通用操作,Engine没有定义的部分大致如下:
/*!
 * \brief Common base of the threaded engines (ThreadedEnginePooled,
 *        ThreadedEnginePerDevice).
 *
 * Only the members not declared in Engine are shown in this excerpt.
 * Subclasses decide where a ready operation runs by implementing
 * PushToExecute.
 */
class ThreadedEngine : public Engine {
 protected:
  /*!
   * \brief Hand an OprBlock whose dependencies are all satisfied to the
   *        engine for execution.
   * \param opr_block the ready operation.
   * \param pusher_thread true when called from Push(), false when called
   *        from a completion dispatcher in OnComplete().
   */
  virtual void PushToExecute(OprBlock* opr_block, bool pusher_thread) = 0;
  /*!
   * \brief Run the operation held by opr_block, then free opr_block.
   *
   * The function receives a callback that forwards to OnComplete() through
   * OnCompleteStatic. During the shutdown phase the function is skipped and
   * the callback is invoked directly, so dependencies are still released.
   */
  void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
    ThreadedOpr* threaded_opr = opr_block->opr;
    CallbackOnComplete callback = this->CreateCallback(
        ThreadedEngine::OnCompleteStatic, threaded_opr);
    if (!shutdown_phase_) {
      try {
        threaded_opr->fn(run_ctx, callback);
      } catch(dmlc::Error &e) {
        ......  // error handling elided in this excerpt
      }
    } else {
      callback();
    }
    OprBlock::Delete(opr_block);
  }

 private:
  // Validate the dependency lists of a new operator (presumably rejects
  // duplicated variables; body not shown).
  void CheckDuplicate(std::vector<VarHandle> const& const_vars,
                      std::vector<VarHandle> const& mutable_vars);
  // Release the dependencies of a finished operator and update counters.
  inline void OnComplete(ThreadedOpr* threaded_opr);
  // Static trampoline stored in CallbackOnComplete; forwards to OnComplete.
  static void OnCompleteStatic(Engine *engine, void *threaded_opr);
  /*!
   * \brief Number of pending operations
   *        (incremented in Push, decremented in OnComplete).
   */
  std::atomic<int> pending_{0};
  /*! \brief whether we want to kill the waiters */
  std::atomic<bool> kill_{false};
  /*! \brief whether it is during shutdown phase */
  std::atomic<bool> shutdown_phase_{false};
  /*! \brief show more information from engine actions */
  bool engine_info_{false};
  /*!
   * \brief Mutex and condition_variable,
   *        used to Notify waits for single or all variables.
   */
  std::mutex finished_m_;
  std::condition_variable finished_cv_;
  /*!
   * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
   * See also #309 (https://github.com/dmlc/mxnet/issues/309)
   */
  std::shared_ptr<common::ObjectPool<ThreadedOpr> > objpool_opr_ref_;
  std::shared_ptr<common::ObjectPool<OprBlock> > objpool_blk_ref_;
  std::shared_ptr<common::ObjectPool<VersionedVarBlock> > objpool_varblk_ref_;
  std::shared_ptr<common::ObjectPool<ThreadedVar> > objpool_var_ref_;
  /*!
   * \brief Disallow copy construction and assignment.
   */
  DISALLOW_COPY_AND_ASSIGN(ThreadedEngine);
};  // class ThreadedEngine
ThreadedEngine的OnCompleteStatic调用OnComplete:
/*!
 * \brief Function-pointer trampoline stored in CallbackOnComplete.
 *
 * Recovers the ThreadedEngine and ThreadedOpr from the opaque pointers and
 * forwards to OnComplete.
 */
void ThreadedEngine::OnCompleteStatic(Engine *engine, void *threaded_opr) {
  auto* const self = static_cast<ThreadedEngine*>(engine);
  auto* const opr = static_cast<ThreadedOpr*>(threaded_opr);
  self->OnComplete(opr);
}
OnComplete 在 op 执行完成后被调用,用于更新 Var 的依赖队列,并在所有操作都完成时唤醒等待者:
/*!
 * \brief Bookkeeping performed after an operator's function has finished.
 *
 * Releases the operator's read and write dependencies (dispatching any
 * operation that becomes ready via PushToExecute(opr, false)), deletes
 * variables whose CompleteWriteDependency reports they should be deleted,
 * decrements the pending-operation count, notifies waiters when it reaches
 * 0, and finally frees the operator if it is temporary.
 */
inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
  // Shared dispatcher: not called from the pushing thread, hence `false`.
  auto dispatch = [this](OprBlock* opr) {
    this->PushToExecute(opr, false);
  };
  // Mark complete for read variables.
  for (auto&& var : threaded_opr->const_vars) {
    var->CompleteReadDependency(dispatch);
  }
  // Mark complete for write variables.
  for (auto&& var : threaded_opr->mutable_vars) {
    const bool delete_var = var->CompleteWriteDependency(dispatch);
    if (delete_var) {
      ThreadedVar::Delete(var);
    }
  }
  int remaining;
  {
    std::unique_lock<std::mutex> lock{finished_m_};
    remaining = --pending_;
  }
  CHECK_GE(remaining, 0);
  // No need to hold the lock while notifying.
  if (remaining == 0) {
    finished_cv_.notify_all();
  }
  // Delete the operator if it is temporary.
  if (threaded_opr->temporary) {
    ThreadedOpr::Delete(threaded_opr);
  }
}
可以看出,OnComplete 函数向 CompleteReadDependency 和 CompleteWriteDependency 传递了用 lambda 表达式表示的匿名函数(dispatcher),dispatcher 的作用就是调用 PushToExecute(opr, false),把已就绪的 op 交给引擎执行。