SyncCopyFromCPU、SyncCopyToCPU
SyncCopyFromCPU 从一块连续的内存区域同步拷贝数据到 NDArray。CPU 路径在拷贝前调用 WaitToWrite,等待当前 NDArray 上所有已排队的读/写操作完成后再执行写入;GPU 路径则把拷贝操作 push 到引擎后,调用 WaitToRead 阻塞等待该写操作完成。
// Synchronously copy `size` elements from a contiguous CPU buffer into this
// NDArray. Blocks until the copy has fully completed before returning.
void NDArray::SyncCopyFromCPU(const void *data, size_t size) const {
TShape dshape = this->shape();
CHECK_EQ(dshape.Size(), size) << "Memory size do not match";// source buffer must hold exactly dshape.Size() elements
TBlob src((void*)data, dshape, cpu::kDevMask, this->dtype_, 0);// wrap the raw CPU buffer in a TBlob view
if (this->ctx().dev_mask() == cpu::kDevMask) {// destination also lives in CPU memory: copy directly on this thread
this->WaitToWrite();// wait for all pending reads/writes on this NDArray before overwriting it
RunContext rctx{this->ctx(), nullptr};// CPU run context; no device stream needed
TBlob dst = this->data();// destination CPU memory
ndarray::Copy<cpu, cpu>(src, &dst, Context::CPU(), Context::CPU(), rctx);// synchronous CPU-to-CPU copy
} else {// destination lives on a GPU: schedule the copy through the engine
#if MXNET_USE_CUDA
Engine::Get()->PushAsync(
// NOTE: capture-by-reference is safe here only because WaitToRead() below
// keeps this stack frame (and `src`) alive until the pushed op completes.
[&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
TBlob dst = this->data();// destination GPU memory
ndarray::Copy<cpu, gpu>(src, &dst, Context::CPU(), this->ctx(), rctx);
// Wait GPU kernel to complete
rctx.get_stream<gpu>()->Wait();
on_complete();// tell the engine this operation is done
}, this->ctx(), {}, {this->var()},// this NDArray's var is listed as mutable (it is written)
FnProperty::kCopyToGPU, 0, "SyncCopyCPU2GPU");// priority 0
this->WaitToRead();// blocks until the write pushed above has finished
#else
LOG(FATAL) << "GPU is not enabled";
#endif
}
}
WaitToRead 和 WaitToWrite 的核心等待逻辑都在 WaitForVar 中:当引擎为 naive_engine 时,所有操作同步执行,WaitForVar 为空实现;当引擎为 threaded_engine 时实现如下:
// Block the calling thread until every operation currently queued on `var`
// has completed. Implemented by pushing a marker op that reads `var` and
// waiting on a condition variable until that marker runs.
void ThreadedEngine::WaitForVar(VarHandle var) {
BulkFlush();//flush current bulk to execution
ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
if (threaded_var->ready_to_read()) {
// Nothing pending on the variable: surface any stored exception and return.
ThrowException(threaded_var);
return;
}
if (engine_info_) {
LOG(INFO) << "Wait for " << threaded_var;
debug_wait_var_ = threaded_var;
}
std::atomic<bool> done{false};
// Push a no-op that depends on `var` as a read; it can only execute once all
// ops queued on `var` ahead of it have finished.
this->PushAsync([this, &done](RunContext, CallbackOnComplete on_complete) {
if (engine_info_) {
LOG(INFO) << "Sync is executed";
}
{
// Store `done` while holding the mutex so the waiter below cannot check
// the predicate between our store and our notify (missed-wakeup guard).
std::unique_lock<std::mutex> lock{finished_m_};
done.store(true);
}
finished_cv_.notify_all();
if (engine_info_) {
LOG(INFO) << "Sync is notified";
}
on_complete();
}, Context::CPU(), {var}, {}, FnProperty::kNormal, 0, "WaitForVar", true);
{
std::unique_lock<std::mutex> lock{finished_m_};
// Wake when the marker op has run, or when the engine is being killed.
finished_cv_.wait(lock, [this, &done]() {
return done.load() || kill_.load();
});
}
// Re-throw any exception recorded on the variable by a failed operation.
ThrowException(threaded_var);
}
SyncCopyToCPU 把 NDArray 的数据拷贝到内存。CPU 路径在拷贝前调用 WaitToRead 等待未完成的写操作;GPU 路径则把拷贝 push 到引擎后调用 WaitToWrite——因为 WaitToWrite 会等待所有已排队的读和写(包括刚 push 的这个读操作)完成:
// Synchronously copy this NDArray's contents into a contiguous CPU buffer.
// Blocks until the copy has fully completed before returning.
void NDArray::SyncCopyToCPU(void *data, size_t size) const {
TShape dshape = this->shape();
CHECK_EQ(dshape.Size(), size) << "Memory size do not match";// destination buffer must hold exactly dshape.Size() elements
TBlob dst(data, dshape, cpu::kDevMask, this->dtype_, 0);// wrap the raw CPU buffer in a TBlob view
if (this->ctx().dev_mask() == cpu::kDevMask) {// source lives in CPU memory: copy directly on this thread
this->WaitToRead();// wait for pending writes on this NDArray before reading it
RunContext rctx{this->ctx(), nullptr};// CPU run context; no device stream needed
NDArray src = *this;
#if MXNET_USE_MKLDNN == 1
if (src.IsMKLDNNData())
src = this->Reorder2Default();// convert from MKL-DNN internal layout to default layout first
#endif
ndarray::Copy<cpu, cpu>(src.data(), &dst,
Context::CPU(), Context::CPU(), rctx);
} else {// source lives on a GPU: schedule the copy through the engine
#if MXNET_USE_CUDA
Engine::Get()->PushAsync(
// NOTE: capture-by-reference is safe here only because WaitToWrite() below
// keeps this stack frame (and `dst`) alive until the pushed op completes.
[&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
ndarray::Copy<gpu, cpu>(this->data(), &dst,
this->ctx(), Context::CPU(), rctx);
// Wait GPU kernel to complete
rctx.get_stream<gpu>()->Wait();
on_complete();
}, this->ctx(), {this->var()}, {},// this NDArray's var is only read (const dependency)
FnProperty::kCopyFromGPU, 0, "SyncCopyGPU2CPU");
// WaitToWrite (not WaitToRead): it waits for queued reads as well, so it
// also covers the read operation pushed just above.
this->WaitToWrite();
#else
LOG(FATAL) << "GPU is not enabled";
#endif
}
}
以上代码频繁用到 PushAsync。它的实现随引擎类型而不同,threaded 型引擎中的实现为:
// Queue an asynchronous function on the engine with explicit read
// (const_vars) and write (mutable_vars) dependencies. `fn` runs once all of
// its dependencies are satisfied and must invoke its completion callback
// when finished. `wait` marks the op as a blocking marker (see WaitForVar).
void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop,
int priority,
const char* opr_name,
bool wait) {
#if MXNET_USE_CUDA
if (exec_ctx.dev_mask() == gpu::kDevMask) {
if (device_count_ < 0) {
// Lazily query the number of visible GPUs on first GPU push.
int tmp = -1;
cudaGetDeviceCount(&tmp);// query GPU count
device_count_ = tmp;
CHECK_GT(device_count_, 0) << "GPU usage requires at least 1 GPU";
}
CHECK_LT(exec_ctx.dev_id, device_count_)
<< "Invalid GPU Id: " << exec_ctx.dev_id
<< ", Valid device id should be less than device_count: "
<< device_count_;
}// validated: at least one GPU exists and the requested device id is in range
#endif
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);// package fn + dependency/property info into an operator
opr->temporary = true;// one-shot operator created for this single call (not reused by the caller)
const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative);
Push(opr, exec_ctx, priority, profiling);
}
// Schedule an operator: register it as a reader/writer on each of its
// dependency variables, then dispatch immediately if nothing is pending.
void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
  // Flush any partially-built bulk so this operator is scheduled on its own.
  BulkFlush();
  ThreadedOpr* opr = ThreadedOpr::CastFromBase(op);
  OprBlock* block = OprBlock::New();
  block->opr = opr;
  // Wait count = one slot per dependency variable, plus one sentinel that we
  // decrement ourselves below once registration is finished.
  const size_t num_deps = opr->const_vars.size() + opr->mutable_vars.size();
  block->wait.store(static_cast<int>(num_deps + 1));
  block->ctx = exec_ctx;
  block->priority = priority;
  block->profiling = profiling;
  ++pending_;
  // Register this op as a reader of every const var...
  for (auto&& var : opr->const_vars) {
    var->AppendReadDependency(block);
  }
  // ...and as a writer of every mutable var.
  for (auto&& var : opr->mutable_vars) {
    var->AppendWriteDependency(block);
  }
  // Drop the sentinel; if every dependency was already satisfied during
  // registration, the count hits zero and we dispatch right away.
  if (block->decr_wait() == 0) {
    this->PushToExecute(block, true);
  }
}