TF session 会话分析

Author : jingwenyi
Date: 2016-6-10
E-mail : 1124427436@qq.com
代码版本:tensorflow-tensorflow-v0.8.0rc0-147-gcf7ce8a

目录
第一节:tensorflow/examples/label_image/ 实例代码的运行
第二节:代码运行流程跟踪
1、图表分析运行流程
2、本地运行代码流程跟踪
3、分布式运行代码跟踪
第三节:分布式服务进程分析
1、分布式服务进程图解
2、分布式服务进程启动跟踪
3、ENQUEUE_REQUEST 宏分析
第四节:DoRecvTensor 和 DoRunGraph 处理分析
1、DoRecvTensor
2、DoRunGraph

第一节:tensorflow/examples/label_image/实例代码的运行
(关于运行步骤的具体详情,请看tensorflow/examples/label_image/README.md)
在 tensorflow/examples/label_image/data 下有这么一张图片:
tf测试图片
这个测试用例最后得到的结果是:图中这位帅哥有 60% 多的概率被识别为 “military uniform”(军装),也就是很可能是一名军人。以下是在 “TF源码分析小组服务器” 上测试的流程和结果:
(服务器地址:101.200.210.101)
1、从谷歌服务器下载预训练好的 Inception 模型数据
wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip -O tensorflow/examples/label_image/data/inception_dec_2015.zip
2、解压下载的数据
unzip tensorflow/examples/label_image/data/inception_dec_2015.zip -d tensorflow/examples/label_image/data/
3、编译 label_image 测试实例
bazel build tensorflow/examples/label_image/...
4、运行测试程序,用预训练模型识别测试图片(以上各步骤前后依赖,需按顺序在前台执行)
bazel-bin/tensorflow/examples/label_image/label_image

输出结果为:
I tensorflow/examples/label_image/main.cc:207] military uniform (866): 0.647299
I tensorflow/examples/label_image/main.cc:207] suit (794): 0.0477195
I tensorflow/examples/label_image/main.cc:207] academic gown (896): 0.0232407
I tensorflow/examples/label_image/main.cc:207] bow tie (817): 0.0157355
I tensorflow/examples/label_image/main.cc:207] bolo tie (940): 0.0145023

注:关于main.cc 中代码分析推荐博客:TensorFlow在图像识别中的应用
博客地址:http://www.csdn.net/article/2015-12-16/2826496?_t=t

第二节:代码运行流程跟踪
Session 的运行分本地和分布式两种。默认(SessionOptions 的 target 为空)创建的是本地的 DirectSession;只有把 gRPC 服务地址指定为 target 时,才会使用分布式的 GrpcSession。
1、图表分析运行流程
本地代码运行流程图
分布式代码运行流程图

2、本地运行代码流程跟踪
//Main.cc
main(int argc, char argv[])
{
session->Run({{input_layer, resized_tensor}},
{output_layer}, {}, &outputs);
}
—->
//session.h
virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>
outputs) = 0;
—->
//direct_session.cc
Status DirectSession::Run(const RunOptions& run_options,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor> outputs,
RunMetadata
run_metadata) {
……
for (const auto& item : executors_and_keys->items) {
//runAsync 开始异步运行
item.executor->RunAsync(args, barrier->Get());
}
…………..
}
—————————->
//executor.cc
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
(new ExecutorState(args, this))->RunAsync(done);
}
————————————>
void ExecutorState::RunAsync(Executor::DoneCallback done) {
…………..
ScheduleReady(ready, nullptr);
…………………
}
————————->
void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
std::deque<TaggedNode>* inline_ready) {
…………..
for (auto& tagged_node : ready) {
//process 函数开始处理
runner
(std::bind(&ME::Process, this, tagged_node, scheduled_usec));
}
…………………
}
———————->
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {

………….
if (item.kernel_is_async) {
//异步计算
device->ComputeAsync(async, ctx, done);
…………..
}
Else
{
//同步执行方法
device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
}
}

//看下同步是怎样执行的
//device.h
(不同设备的 Compute 实现不一样,但最终都会调用 op_kernel->Compute(context))
virtual void Compute(OpKernel op_kernel, OpKernelContext context) {
op_kernel->Compute(context);
}
——>
//op_kernel.h
virtual void Compute(OpKernelContext context) = 0;
(到这里就直接运行对应op 的操作了,比如main.cc 中的softmax op)
———————->
//Softmax_op.h
void Compute(OpKernelContext
context) override {
const Tensor& logits_in = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
errors::InvalidArgument(“logits must be 2-dimensional”));
Tensor* softmax_out = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, logits_in.shape(), &softmax_out));
if (logits_in.NumElements()) {
functor::SoftmaxFunctor<Device, T> functor;
functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
softmax_out->matrix<T>(), log
);
}
}

3、分布式运行代码跟踪
//Main.cc
main(int argc, char argv[])
{
session->Run({{input_layer, resized_tensor}},
{output_layer}, {}, &outputs);
}
—->
//session.h
virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>
outputs) = 0;
—->
//grpc_session.cc
Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs) {
RunOptions run_options;
run_options.set_timeout_in_ms(options
.config.operation_timeout_in_ms());
return Run(run_options, inputs, output_tensor_names, target_node_names,
outputs, nullptr);
}
—->
Status GrpcSession::Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor> outputs,
RunMetadata
run_metadata) {
……

TFRETURN_IF_ERROR(RunProto(&call_options, &req, &resp));
…..
}
—->
Status GrpcSession::RunProto(CallOptions call_options, RunStepRequest req,
RunStepResponse* resp) {
if (handle
.empty()) {
return errors::InvalidArgument(“A session is not created yet….”);
}

req->setsession_handle(handle);
return master->RunStep(call_options, req, resp);
}
—->
//master_interface.h
virtual Status RunStep(CallOptions call_options,
const RunStepRequest
request,
RunStepResponse response) = 0;
——>
//grpc_remote_master.cc
Status RunStep(CallOptions
call_options, const RunStepRequest request,
RunStepResponse
response) override {
::grpc::ClientContext ctx;
SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub
->RunStep(&ctx, request, response));
}
//grpc_master_service.cc
void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>
call) {
CallOptions callopts = new CallOptions;
call->SetCancelCallback(call_opts { call_opts->StartCancel(); });
master_impl
->RunStep(call_opts, &call->request, &call->response,
call, call_opts {
call->ClearCancelCallback();
delete call_opts;
call->SendResponse(ToGrpcStatus(status));
});
ENQUEUE_REQUEST(RunStep, true);
}
——>
//master.cc
void Master::RunStep(CallOptions
opts, const RunStepRequest req,
RunStepResponse
resp, MyClosure done) {
SchedClosure(this, start_time, session, opts, req, resp, done {
//这里开始运行
Status status = session->Run(opts, req, resp);
uint64 donetime = env->env->NowMicros();
done(status);
mutexlock l(mu);
//把本次 step 的耗时记入“最近 1000 步”的统计中
last1000_steps.AddValue((donetime - start_time) / 1e9);
++step_count
;
});
}
—->
//master_session_interface.h
virtual Status Run(CallOptions opts, const RunStepRequest req,
RunStepResponse resp) = 0;
//master_session.cc
Status MasterSession::Run(CallOptions
opts, const RunStepRequest req,
RunStepResponse
resp) {
UpdateLastAccessTime();
{
mutex_lock l(mu
);
++numrunning;
}
//本地执行
Status status = DoRunWithLocalExecution(opts, req, resp);
{
mutexlock l(mu);
—numrunning;
if (numrunning == 0) {
numrunning_is_zero.notifyall();
}
}
return status;
}
—>
Status MasterSession::DoRunWithLocalExecution(CallOptions opts,
const RunStepRequest
req,
RunStepResponse* resp) {
…….
TF_RETURN_IF_ERROR(rcg->RunPartitions(
env
, stepid, count, execution_state.get(), &pss, opts, req, resp));
……..
}
——>
Status MasterSession::ReffedClientGraph::RunPartitions(
const MasterEnv
env, int64 stepid, int64 execution_count,
SimpleGraphExecutionState execution_state, PerStepState pss,
CallOptions call_opts, const RunStepRequest& req, RunStepResponse resp) {
for (int i = 0; i < num; ++i) {
const Part& part = partitions
[i];
RunManyGraphs::Call call = calls.get(i);
TRACEPRINTF(“Partition %d %s”, i, part.name.c_str());
//异步运行graph
part.worker->RunGraphAsync(
&call->opts, &call->req, &call->resp,
std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
}
}
————————>
//worker_interface.h
virtual void RunGraphAsync(CallOptions
opts, const RunGraphRequest request,
RunGraphResponse
response,
StatusCallback done) = 0;
———————->
//grpc_remote_worker.cc
void RunGraphAsync(CallOptions call_opts, const RunGraphRequest request,
RunGraphResponse response, StatusCallback done) override {
IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncRunGraph,
done, call_opts);
}
———————>
void IssueRequest(const RequestMessage
request, ResponseMessage response,
AsyncMethod<RequestMessage, ResponseMessage> async_method,
StatusCallback done, CallOptions
call_opts = nullptr) {
…..
auto rpc = (stub
.get()->async_method)(context, request, cq).release();
…..
}
——————->
//grpc_worker_service.cc
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
//调度线程执行DoRunGraph
env
->compute_pool->Schedule(this, call { DoRunGraph(call); });
ENQUEUE_REQUEST(RunGraph, true);
}
—————->
//threadpool.cc
void ThreadPool::Impl::Schedule(std::function<void()> fn) {
uint64 id = 0;
if (port::Tracing::IsActive()) {
id = port::Tracing::UniqueId();
port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
id);
}

mutexlock l(mu);
pending.push_back({fn, id});
if (!waiters
.empty()) {
Waiter* w = waiters.back();
waiters
.pop_back();
w->ready = true;
w->cv.notify_one();
}
}
——————————>
//线程池工作循环
void ThreadPool::Impl::WorkerLoop() {
// Set the processor flag to flush denormals to zero
port::ScopedFlushDenormal flush;

port::Tracing::RegisterCurrentThread(name.c_str());
mutex_lock l(mu
);
Waiter w;
while (true) {
while (pending.empty()) {
// Wait for work to be assigned to me
//等待工作
w.ready = false;
waiters
.pushback(&w);
while (!w.ready) {
w.cv.wait(l);
}
}
// Pick up pending work
//获得未处理的任务
Task t = pending
.front();
pending.pop_front();
if (t.fn == nullptr) {
break;
}
mu
.unlock();
//t.fn 调用任务的处理函数
if (t.id != 0) {
port::Tracing::ScopedActivity region(
port::Tracing::EventCategory::kRunClosure, t.id);
t.fn();
} else {
t.fn();
}
mu_.lock();
}
}

——————————>
//grpc_worker_service.cc
void DoRunGraph(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
//这里才开始干正事
…………..
}

第三节:分布式服务进程分析
1、分布式服务进程图解
分布式服务进程

2、分布式服务进程启动跟踪
//grpc_tensorflow_server.cc
int main(int argc, char* argv[]) {
…….
//初始化server 对象
tensorflow::NewServer(server_def, &server);
//启动服务
server->Start();
//管理服务
server->Join();
}
———————>
//grpc_server_lib.cc
在初始化server 对象的时候,调用了GrpcServer的create函数
Status GrpcServer::Create(….)
{
ret->Init();
}
———————>
Status GrpcServer::Init() {
……
//创建 gRPC master 服务对象
master_service
= NewGrpcMasterService(&masterenv, &builder);
//创建 gRPC worker 服务对象
workerservice = NewGrpcWorkerService(&workerenv, &builder);
//创建 worker 缓存(基于 gRPC channel 访问各个 worker)
workerenv.workercache = NewGrpcWorkerCache(channel_cache.release());
//消息传递相关对象
worker_env
.graphmgr = new GraphMgr(&worker_env);
workerenv.rendezvousmgr = new RpcRendezvousMgr(&worker_env);
//计算线程池
workerenv.computepool = ComputePool(sess_opts);
………
}
——————>
Start 启动服务
Status GrpcServer::Start() {
//启动master service 线程
master_thread
.reset(
env->StartThread(ThreadOptions(), “TF_master_service”,
[this] { master_service
->HandleRPCsLoop(); }));
//启动了worker service 线程
workerthread.reset(env->StartThread(ThreadOptions()”TF_worker_service”,
[this] { worker_service
->HandleRPCsLoop(); }));
}
———————>
线程池是运行graph处理函数的地方
//process_util.cc
thread::ThreadPool ComputePool(const SessionOptions& options) {
//返回进程内共享的(static)计算线程池,只在第一次调用时创建
static thread::ThreadPool
compute_pool = InitComputePool(options);
return compute_pool;
}
——————->
static thread::ThreadPool InitComputePool(const SessionOptions& options) {
int32 inter_op_parallelism_threads =
options.config.inter_op_parallelism_threads();
if (inter_op_parallelism_threads == 0) {
// Default to using the number of cores available in the process.
//获取cpu 的个数
inter_op_parallelism_threads = port::NumSchedulableCPUs();
}
//分配线程池对象
return new thread::ThreadPool(Env::Default(), “Compute”,
inter_op_parallelism_threads);
}
————————->
//threadpool.cc
hreadPool::ThreadPool(Env
env, const ThreadOptions& thread_options,
const string& name, int num_threads) {
CHECK_GE(num_threads, 1);
impl
.reset(
new ThreadPool::Impl(env, threadoptions, “tf“ + name, num_threads));
}
————>

ThreadPool::Impl::Impl(Env* env, const ThreadOptions& threadoptions,
const string& name, int num_threads)
: name
(name) {
//启动 num_threads 个工作线程(默认等于可调度的 CPU 个数)
for (int i = 0; i < numthreads; i++) {
threads
.push_back(
env->StartThread(thread_options, name, this { WorkerLoop(); }));
}
}
———————->
//线程池工作循环,这里是调用graph 处理函数的地方
void ThreadPool::Impl::WorkerLoop() {
// Set the processor flag to flush denormals to zero
port::ScopedFlushDenormal flush;

port::Tracing::RegisterCurrentThread(name.c_str());
mutex_lock l(mu
);
Waiter w;
while (true) {
while (pending.empty()) {
// Wait for work to be assigned to me
//等待工作
w.ready = false;
waiters
.pushback(&w);
while (!w.ready) {
w.cv.wait(l);
}
}
// Pick up pending work
//获得未处理的任务
Task t = pending
.front();
pending.pop_front();
if (t.fn == nullptr) {
break;
}
mu
.unlock();
//t.fn 调用任务的处理函数
if (t.id != 0) {
port::Tracing::ScopedActivity region(
port::Tracing::EventCategory::kRunClosure, t.id);
t.fn();
} else {
t.fn();
}
mu_.lock();
}
}

//schedule 把需要处理的graph 加入到pending 向量尾部
void ThreadPool::Impl::Schedule(std::function<void()> fn) {
……..
pending
.push_back({fn, id});
…….
}
——————->
gRPC master 服务线程
//grpc_master_service.cc
void HandleRPCsLoop() {
//注册创建 session 请求的回调函数
ENQUEUE_REQUEST(CreateSession, true);
//注册扩展 session 请求的回调函数
ENQUEUE_REQUEST(ExtendSession, false);
for (int i = 0; i < 100; ++i) {
//注册在 session 中执行一个 step 请求的回调函数
ENQUEUE_REQUEST(RunStep, true);
}
ENQUEUE_REQUEST(CloseSession, false);
//注册列举可用设备请求的回调函数
ENQUEUE_REQUEST(ListDevices, false);
ENQUEUE_REQUEST(Reset, false);

void* tag;
bool ok;
//while 循环监听完成队列上的回调
while (cq_->Next(&tag, &ok)) {
UntypedCall<GrpcMasterService>::Tag* callback_tag =
static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
if (callback_tag) {
callback_tag->OnCompleted(this, ok);
delete callback_tag;
} else {
// NOTE(mrry): A null `callback_tag` indicates that this is
// the shutdown alarm.
cq_->Shutdown();
}
}
}
———————>
gRPC worker 服务线程
//grpc_worker_service.cc
void HandleRPCsLoop() {

  
  
//注册获取状态请求的回调函数
ENQUEUE_REQUEST(GetStatus, false);
ENQUEUE_REQUEST(CleanupAll, false);
//注册 graph 请求的回调函数
ENQUEUE_REQUEST(RegisterGraph, false);
ENQUEUE_REQUEST(DeregisterGraph, false);
//接收 tensor 请求的回调函数
for (int i = 0; i < 1000; ++i) {
ENQUEUE_REQUEST(RecvTensor, true);
}
//运行 graph 请求的回调函数
for (int i = 0; i < 100; ++i) {
ENQUEUE_REQUEST(RunGraph, true);
}
ENQUEUE_REQUEST(CleanupGraph, false);
ENQUEUE_REQUEST(Logging, false);
ENQUEUE_REQUEST(Tracing, false);
void* tag;
bool ok;
//while 循环等待 call 完成
while (cq_->Next(&tag, &ok)) {
UntypedCall<GrpcWorkerService>::Tag* callback_tag =
static_cast<UntypedCall<GrpcWorkerService>::Tag*>(tag);
if (callback_tag) {
callback_tag->OnCompleted(this, ok);
delete callback_tag;
} else {
// NOTE(mrry): A null `callback_tag` indicates that this is
// the shutdown alarm.
cq_->Shutdown();
}
}

}
3、ENQUEUE_REQUEST 宏分析
以 ENQUEUE_REQUEST(GetStatus, false); 为例分析这个宏
ENQUEUE_REQUEST(GetStatus, false)—-> 展开后,如下
do{
mutex_lock l(shutdown_mu
);
if(!isshutdown)
{
//执行了call 模板中对应的EnqueueRequest 函数
Call<GrpcWorkerService, grpc::WorkerService::AsyncService,GetStatusRequest, GetStatusResponse>::
EnqueueRequest(&workerservice,
cq_,
//session 发出的请求
&grpc::WorkerService::AsyncService::RequestGetStatus,
//请求的处理函数
&GrpcWorkerService::GetStatusHandler,
false);
}
}

看看 EnqueueRequest 函数,该函数在 tensorflow/core/distributed_runtime/rpc/grpc_call.h 中
Call 是个模板类
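下面 Call 模板的四个模板参数,与上面宏展开中 Call<GrpcWorkerService, grpc::WorkerService::AsyncService, GetStatusRequest, GetStatusResponse> 的四个模板实参一一对应:
template <class Service, class GrpcService, class RequestMessage, class ResponseMessage>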
static void EnqueueRequest(GrpcService grpc_service,
::grpc::ServerCompletionQueue
cq,
EnqueueFunction enqueue_function,
HandleRequestFunction handle_request_function,
bool supports_cancel) {
//初始化了一个call 对象
auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
handle_request_function);
//判读是否支持取消,如果支持,则注册取消处理函数
if (supports_cancel) {
call->RegisterCancellationHandler();
}
//支持了request 的处理函数
(grpc_service->*enqueue_function)
(&call->ctx
,
&call->request,
&call->responder,
cq,
cq,
new typename UntypedCall<Service>::Tag(call, &UntypedCall<Service>::RequestReceived)
);
call->Unref();
}
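看下 GetStatusHandler 这个处理函数。
EnqueueRequest 把 session 发来的请求(RequestGetStatus)和 server 端的处理函数(GetStatusHandler)绑定起来,注册到 gRPC 完成队列 cq_ 上;当 session 发出 GetStatus 请求时,就会执行 GetStatusHandler。GetStatusHandler 再把实际工作调度到计算线程池(compute_pool)中执行,并在末尾再次调用 ENQUEUE_REQUEST(GetStatus, false),为下一个 GetStatus 请求重新注册。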
void GetStatusHandler(WorkerCall<GetStatusRequest, GetStatusResponse>* call) {
env
->computepool->Schedule(
this, call {
//获得设备管理对象
DeviceMgr* dm = env
->device_mgr;
std::vector<DeviceAttributes> devices;
//获取所有设备的属性
dm->ListDeviceAttributes(&devices);
call->response.mutable_device_attributes()->Reserve(devices.size());
//把每个设备的属性交换(Swap)到响应消息中
for (size_t i = 0; i < devices.size(); i++) {
call->response.add_device_attributes()->Swap(&devices[i]);
}
call->SendResponse(::grpc::Status::OK);
}
);
ENQUEUE_REQUEST(GetStatus, false);
}

第四节:DoRecvTensor和DoRunGraph处理分析
1、DoRecvTensor
//grpc_worker_service.cc
DoRecvTensor(….)
{
//准备接收tensor,获取设备信息
Status s = PrepareRecvTensor(key, &src_dev);
env
->rendezvous_mgr->RecvLocalAsync(….,done);

}
—————————->
//base_rendezvous_mgr.cc
void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, const string& key,
Rendezvous::DoneCallback done) {
BaseRemoteRendezvous* rendez = FindOrCreate(step_id);
rendez->RecvLocalAsync(
key, rendez, done {
rendez->Unref();
//这里这个done 调用了DoRecvTensor中的回调函数
done(s, send_args, recv_args, v, dead);
});
}

———————————->
//DoRecvTensor执行的具体内容
this, call, src_dev
{
call->ClearCancelCallback();
Status s = status;
if (s.ok()) {
// DMA can only be used for Tensors that do not fall into
// the following three odd edge cases: 1) a zero-size
// buffer, 2) a dead tensor which has an uninit value, and
// 3) the tensor has the on_host allocation attribute,
// i.e. it’s in CPU RAM independent of its assigned
// device type.
// const size_t bytes = is_dead ? 0 : val.TotalBytes();
//回复当前设备的信息
const bool on_host = send_args.alloc_attrs.on_host();
const DeviceContext send_dev_context = send_args.device_context;
call->response.set_is_dead(is_dead);
StatusCallback response_ready = call {
// The value is now ready to be returned on the wire.
call->response.set_send_start_micros(Env::Default()->NowMicros());
call->SendResponse(ToGrpcStatus(s));
};
{
// Non-DMA cases.
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
//如果有gpu ,把数据cp 给gpu
CHECK(send_dev_context)
<< “send dev name: “ << src_dev->name()
<< “ gpu_info: “ << src_dev->tensorflow_gpu_device_info();
// “val” is on a GPU. Uses GPUUtil to fill the response proto.
GPUUtil::SetProtoFromGPU(val, src_dev, send_dev_context,
call->response.mutable_tensor(),
is_dead, response_ready);
} else {
// “val” is in CPU memory.
//cpu 的处理方式
TensorProto
proto = call->response.mutable_tensor();
val.AsProtoTensorContent(proto);
response_ready(Status::OK());
}
}
} else {
// !s.ok()
call->SendResponse(ToGrpcStatus(s));
}
}

2、DoRunGraph
//Grpc_worker_service.cc
void DoRunGraph(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {

…..
//准备运行graph,cpu分配
Status s = PrepareRunGraph(call->request, &in, out);
…..
//异步执行
env->graph_mgr->ExecuteAsync(
//执行的graph 处理函数
call->request.graph_handle(),
step_id,//每一步唯一的id号
call->request.exec_opts(),
collector,//收集状态
cm,
in,
out,
//执行的状态回调函数
this, call, cm, out, token
{
call->ClearCancelCallback();
{
mutex_lock l(mu
);
cancellationmanager->DeregisterCallback(token);
}
delete cm;

  
  
if (s.ok()) {

//执行完成之后的数据拷贝
for (const auto& p : out) {
const string& key = p.first;
const Tensor& val = p.second;
auto
recv = call->response.add_recv();
recv->set_key(key);
// TODO(zhifengc): Deal with gpu -> cpu copy.
//把输出 tensor 序列化到响应的 proto 中(GPU -> CPU 的拷贝尚未处理,见上面的 TODO)
TensorProto* proto = recv->mutable_val();
val.AsProtoField(proto);
}
}
delete out;
call->SendResponse(ToGrpcStatus(s));
}
);
}

……

}

———————————->

看以下异步执行代码
//graph_mgr.cc
void GraphMgr::ExecuteAsync(…..)
{
//查找graph 的处理函数
auto iter = table
.find(handle);
//通过stepid 查找到执行的位置
Rendezvous* rendezvous = worker_env
->rendezvousmgr->Find(step_id);
//启动并行执行
{
//创建 barrier:num_units 个并行执行单元全部完成后,才调用 RunAllDone
ExecutorBarrier barrier = new ExecutorBarrier(
num_units,
rendezvous,
std::bind(&ME::RunAllDone, this, item,
rendezvous,out, done, std::placeholders::_1));
//获取线程池的指针
thread::ThreadPool
pool = worker_env
->compute_pool;
//加入到待处理队列中
args.runner = pool> fn) { pool->Schedule(fn); };
//开始异步执行
for (const auto& unit : item->units) {
unit.root->RunAsync(args, barrier->Get());
}
}

//这下面的执行就跟上面本地执行的步骤一样了
//executor.cc
void ExecutorState::RunAsync(Executor::DoneCallback done) {
…….
//调度运行线程池中准备好的所有ops
ScheduleReady(ready, nullptr);
……
}

————————->
void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
std::deque<TaggedNode>* inlineready) {
…………..
for (auto& tagged_node : ready) {
//process 函数开始处理
runner
(std::bind(&ME::Process, this, tagged_node, scheduled_usec));
}
…………………
}
———————->
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {

………….
// 大多数 kernel 是同步的(kernel_is_async 为 false);Recv 等异步 kernel(AsyncOpKernel)会走 ComputeAsync 分支
if (item.kernel_is_async) {
//异步计算
device->ComputeAsync(async, ctx, done);
…………..
}
Else
{
//同步执行方法
device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
}
}

//看下同步是怎样执行的
//device.h
(不同设备的 Compute 实现不一样,但最终都会调用 op_kernel->Compute(context))
virtual void Compute(OpKernel op_kernel, OpKernelContext context) {
op_kernel->Compute(context);
}
——>
//op_kernel.h
virtual void Compute(OpKernelContext context) = 0;
(到这里就直接运行对应op 的操作了,比如main.cc 中的softmax op)
———————->
//Softmax_op.h
void Compute(OpKernelContext
context) override {
const Tensor& logits_in = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
errors::InvalidArgument(“logits must be 2-dimensional”));
Tensor* softmax_out = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, logits_in.shape(), &softmax_out));
if (logits_in.NumElements()) {
functor::SoftmaxFunctor<Device, T> functor;
functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
softmax_out->matrix<T>(), log
);
}
}
