executor和 direct session源码解读
目录
这篇博客主要从 C++ 源码角度,讲tensorlfow 运行时 对control flow 原语的特殊处理。要完全理解这部分源码,需要对tensorflow control flow 的原理有些了解,建议先看我上一篇博客。
excutor.cc
先来看逻辑清晰一点的excutor.cc。直观地理解,运行时根据可用的资源,把计算图分割成若干个子图,每个子图绑定到一个device(一个cpu/gpu/tpu..etc),对应一个excutor 。excutor.cc这部分代码被单机模式(direct session)和分布式模式(distributed session)共用。
以下的描述可以参考这张图。
在读代码之前,我先列出类excutor.cc内的几个重要的类。
- struct EdgeInfo 用来表示NodeItem里的输出边
- struct NodeItem 表示一个op 节点
- class GraphView // Immutable view of a Graph organized for efficient execution.
- public Executor ExecutorImpl的基类,一个计算子图绑定一个device,对应一个Executor。
- class ExecutorImpl : public Executor excutor的具体实现
- ExecutorImpl::struct ControlFlowInfo 这个类的对象维护nodeid 对应的Frame name ,一个ControlFlowInfo对象会在ExecutorImpl初始化的时候创建
- ExecutorImpl::struct FrameInfo 这个类的对象用来表示一个Frame(while loop)的静态信息。其中最重要的是 一个 pendingcounts数据结构的属性,该数据结构以内存紧凑的方式维护frame中的每一个node的初始pending count(该节点的输入个数及依赖边的个数),这个pendingcount 被用来初始化 FrameInfo 对应的 FrameState里的FrameState Iteration 的pendingcounts属性值。在ExecutorImpl初始化的时候,一系列的FrameInfo会被创建(准确地说,是一个Frame 到FrameInfo的map,map的每一项对应计算图中一个whileloop)
- class ExecutorState 用源码里的注释描述,// The state associated with one invocation of ExecutorImpl::Run.
- ExecutorState::struct IterationState IterationState 维护FrameState(也就是一个实例化的whileloop) 里的一次迭代相关的状态变量。其中最重要的有两个。一个是pendingcounts,这个pendingcounts 的初始值是这个IterationState 所在FrameState对应的FrameInfo里的pendingcounts值,随着计算的过程,这个pendingcounts 会被修改;另一个是inputs数组,这里会存储这个iteration 里所有计算节点的输入。计算节点的输出会复制到所有消费这个输出的节点在inputs 对应的槽位中(可能复制多份)。存在inputs数组里面的tensorvalue 在被消费完后会被销毁。
- ExecutorState::struct FrameState FrameState表示一个Frame的实例化状态。因为Frame 可以嵌套,嵌套在一个whileloop 里的whileloop可能会被多次实例化,一次实例化对应一个FrameState,也就是说一个whileloop,只对应一个FrameInfo,但是可能会实例化多个FrameState。FrameState在图计算的过程中动态地创建,修改和销毁。FrameState维护的重要状态是 IterationState数组。
- ExecutorState::struct TaggedNode //A tagged node: <frame*, iter, node*,isdead>。表示一个op的实例。Frame里的一个op在Frame的每一次iteration 都会实例化一次。
接下来,讲述一下excutor的运行逻辑。
direct session 调用excutor 的虚函数RunAsync。由RunAsync开始:
- ExecutorImpl::RunAsync
- ExecutorState::RunAsync // (new ExecutorState(args, this))->RunAsync(std::move(done))
- ExecutorState::ScheduleReady // ScheduleReady(ready, nullptr)
- runner_ // runner_([=]() { Process(tagged_node, scheduled_nsec); })
- process // runner_([=]() { Process(tagged_node, scheduled_nsec); })
- PrepareInputs // PrepareInputs(item, first_input, &inputs, &input_device_contexts,
&input_alloc_attrs, &is_input_dead) - Compute // device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
- ProcessOutputs // ProcessOutputs(item, &ctx, &outputs, stats);
- PropagateOutputs // PropagateOutputs(tagged_node, &item, &outputs, &ready);
- ExecutorState::FrameState::ActivateNodes
- NodeDone // NodeDone(s, item.node, ready, stats, &inline_ready)
- PrepareInputs // PrepareInputs(item, first_input, &inputs, &input_device_contexts,
- process // runner_([=]() { Process(tagged_node, scheduled_nsec); })
- runner_ // runner_([=]() { Process(tagged_node, scheduled_nsec); })
- ExecutorState::ScheduleReady // ScheduleReady(ready, nullptr)
- ExecutorState::RunAsync // (new ExecutorState(args, this))->RunAsync(std::move(done))
ExecutorState::RunAsync初始化运行,把计算图入度为0的opnode放入一个tagednodeReadyQueue;调用ExecutorState::ScheduleReady开始调度。ExecutorState::ScheduleReady里面把process方法包装成一个lambda传入runner_,把任务提交给线程池,runner_是ExecutorState的一个属性,由ExecutorState构造函数传入,是一个std::function<void(std::function<void()>)> 对象,负责调度任务。进入process,process 把taggednode 放入一个队列,readyqueue。进入while循环,在循环中依次调用PrepareInputs、Compute、ProcessOutputs、PropagateOutputs、NodeDone,PropagateOutputs 可能在队列里加入新的元素。循环直到readyqueue为空。
以上的调用链源码分析已经有人写过博客了,详情请看这位高人的博客,我这里着重解析和控制流相关的部分,主要集中在process里。以下我主要分析PrepareInput、ProcessOutputs、PropagateOutputs、NodeDone这5个方法。
PrepareInputs
源代码及我加的注释
prepareinputs 的作用,如名字所表达地那样,从IterationState里把输入地tenser及其一些附带信息取出。
//主要逻辑是把输入从first_input(数组中的一个)开始的片段中拷贝到inputs,并收集一些其他的信息。
//在调用的时候,first_input指向IterationState中inputs数组中的一个元素
//在process里调用结束后,inputs、input_device_contexts、input_alloc_attrs、is_input_dead
//这些数据会传递给opkernelcontext,在Kernel::compute里被使用。
Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
TensorValueVec* inputs,
DeviceContextVec* input_device_contexts,
AllocatorAttributeVec* input_alloc_attrs,
bool* is_input_dead) {
const Node* node = item.node;
inputs->clear();
inputs->resize(item.num_inputs);
input_device_contexts->clear();
input_device_contexts->resize(item.num_inputs);
input_alloc_attrs->clear();
input_alloc_attrs->resize(item.num_inputs);
*is_input_dead = false;
bool is_merge = item.is_merge;
for (int i = 0; i < item.num_inputs; ++i) {
//从该taggednode对应的NodeItem获取一些输入值的属性
const bool expect_ref = IsRefType(item.input_type(i));
Entry* entry = first_input + i;
(*input_device_contexts)[i] = entry->device_context;
(*input_alloc_attrs)[i] = entry->alloc_attr;
// i-th input.
TensorValue* inp = &(*inputs)[i];
//校验,PrepareInputs在process被调用,只有merge 和 transfer 节点才可能执行到这,因为merge
//本身容许 无值的输入,而运输节点 需要在子图之间传递 isdead 标识(详情请看我的上一篇博客),
// Only merge and transfer nodes can have no-value inputs.
if (!entry->has_value) {
if (!is_merge) {
DCHECK(IsTransferNode(node)) << node->name() << " - input " << i;
DCHECK(!entry->val_field_is_set) << node->name() << " - input " << i;
entry->has_value = true;
entry->val_field_is_set = true;
entry->val.Init(*kEmptyTensor);
inp->tensor = entry->val.get();
*is_input_dead = true;
}
continue;
}
//拷贝数据,考虑引用和非引用之间的拷贝。
if (entry->ref == nullptr) {
if (expect_ref) {
return AttachDef(
errors::InvalidArgument(i, "-th input expects a ref type"),
item.kernel->def());
}
inp->tensor = entry->val.get();
} else {
{
mutex_lock ml(*entry->ref_mu);
if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
return AttachDef(errors::FailedPrecondition(
"Attempting to use uninitialized value ",
item.kernel->requested_input(i)),
item.kernel->def());
}
}
if (expect_ref) {
inp->mutex_if_ref = entry->ref_mu;
inp->tensor = entry->ref;
} else {
// Automatically deref the tensor ref when the op expects a
// tensor but is given a ref to a tensor. Need to deref it
// under the mutex.
{
mutex_lock l(*(entry->ref_mu));
DCHECK(!entry->val_field_is_set);
entry->val.Init(*entry->ref);
entry->val_field_is_set = true;
}
entry->ref = nullptr;
entry->ref_mu = nullptr;
inp->tensor = entry->val.get();
// The dtype of entry->ref could have been changed by another operation
// that ran after the operation that "produced" it executed, so
// re-validate that the type of the dereferenced tensor matches the
// expected input type.
if (item.input_type(i) != inp->tensor->dtype()) {
return AttachDef(
errors::InvalidArgument(
i, "-th input expects type ",
DataTypeString(item.input_type(i)),
" but automatically dereferenced input tensor has type ",
DataTypeString(inp->tensor->dtype())),
item.kernel->def());
}
}
}
}
return Status::OK();
}
Compute
compute 调用kernel 执行计算
ProcessOutputs
processoutput把计算地opkernelcontext里地输出拷贝出来
//Kernel 运行结束,从opKernelContext获得计算的输出,暂时存储在outputs指向的数组。
Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
EntryVector* outputs,
NodeExecStatsWrapper* stats) {
const Node* node = item.node;
DCHECK_EQ(0, outputs->size());
outputs->resize(item.num_outputs);
Status s = ctx->status();
if (!s.ok()) {
s = AttachDef(s, item.kernel->def());
// TODO(misard) Replace with a finer-grain enabling flag once we
// add better optional debugging support.
if (vlog_ && VLOG_IS_ON(1)) {
LOG(WARNING) << this << " Compute status: " << s;
DumpState();
}
if (s.code() == error::RESOURCE_EXHAUSTED) {
if (stats_collector_) {
string err = stats_collector_->ReportAllocsOnResourceExhausted(
s.error_message());
s = Status(s.code(), strings::StrCat(s.error_message(), err));
} else {
s = Status(
s.code(),
strings::StrCat(
s.error_message(),
"\nHint: If you want to see a list of allocated tensors when "
"OOM happens, add report_tensor_allocations_upon_oom "
"to RunOptions for current allocation info.\n"));
}
}
return s;
}
// Get the device_context for this node id, if it exists.
DeviceContext* device_context = nullptr;
if (node->id() < device_context_map_.size()) {
device_context = device_context_map_[node->id()];
}
for (int i = 0; i < item.num_outputs; ++i) {
const TensorValue val = ctx->release_output(i);
if (val.tensor == nullptr) {
// Unless it's a Switch or a Recv, the node must produce a
// tensor value at i-th output.
if (!IsSwitch(node) && !IsRecv(node)) {
s.Update(errors::Internal("Missing ", i, "-th output from ",
SummarizeNode(*node)));
}
} else {
Entry* out = &((*outputs)[i]);
// Set the device context of the output entry.
out->device_context = device_context;
// Set the allocator attributes of the output entry.
out->alloc_attr = ctx->output_alloc_attr(i);
// Sanity check of output tensor types.
DataType dtype;
if (val.is_ref()) {
mutex_lock ml(*val.mutex_if_ref);
dtype = MakeRefType(val->dtype());
} else {
dtype = val->dtype();
}
if (dtype == item.output_type(i)) {
if (stats && val.tensor->IsInitialized()) {
nodestats::SetOutput(stats, i, val.tensor);
}
if (val.is_ref()) {
out->has_value = true;
out->ref = val.tensor;
out->ref_mu = val.mutex_if_ref;
if (log_memory_) {
Tensor to_log;
{
// Dereference the tensor under the lock.
mutex_lock l(*out->ref_mu);
to_log = *out->ref;
}
LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
ctx->step_id(), i, to_log);
}
} else {
// NOTE that std::move is used here, so val.tensor goes to
// uninitialized state (val.tensor->IsInitialized return false).
DCHECK(!out->val_field_is_set);
out->has_value = true;
out->val_field_is_set = true;
out->val.Init(std::move(*val.tensor));
if (log_memory_) {
LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
ctx->step_id(), i, *out->val);
}
}
} else {
s.Update(errors::Internal("Output ", i, " of type ",
DataTypeString(dtype),
" does not match declared output type ",
DataTypeString(item.output_type(i)),
" for node ", SummarizeNode(*node)));
}
}
if (!val.is_ref()) {
// If OpKernelContext returns outputs via pass-by-value, we
// don't need this trouble.
delete val.tensor;
}
}
return s;
}
PropagateOutputs,这里是对控制流特殊处理的的核心部分。netx/enter/exit 会在这特殊处理
//把上一步的到的输出,拷贝给需要改输出作为输入的TaggedNode,存放在taggedNode对应的IterationState 里的inputs数组。
//这里会特殊地对待控制流算子
void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
const NodeItem* item, EntryVector* outputs,
TaggedNodeSeq* ready) {
const Node* node = tagged_node.node;
FrameState* input_frame = tagged_node.input_frame;
const int64 input_iter = tagged_node.input_iter;
const bool is_dead = tagged_node.is_dead;
//如源码注释所言,沿着输出边,减小对应节点的入度,如果一个tagednode满足激活条件(对于一般节
//点,入度为0),就把它放到ready 队列
// Propagates outputs along out edges, and puts newly ready nodes
// into the ready queue.
ready->clear();
bool is_frame_done = false;
FrameState* output_frame = input_frame;
int64 output_iter = input_iter;
//控制流相关节点特殊处理
if (!item->is_enter_exit_or_next_iter) {
//其他节点的处理,merge节点也走这条线,merge会在ActivateNodes方法里特殊处理。
// Fast path for nodes types that don't need special handling
DCHECK_EQ(input_frame, output_frame);
// Normal path for most nodes
mutex_lock l(input_frame->mu);
output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
is_frame_done = input_frame->DecrementOutstandingOpsLocked(
&impl_->gview_, input_iter, ready);
} else if (item->is_enter) {
//对enter 特殊处理
bool is_constant;
const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
DCHECK(s.ok()) << s;
//enter 可能需要创建一个FrameState,准确地说,如果是第一次enter一个frame,会创建一个
//framestate。
FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
output_iter = 0;
{
const NodeItem* item = impl_->gview_.node(node->id());
mutex_lock l(output_frame->mu);
if (is_constant) {
// Propagate to all active iterations if this is a loop invariant.
output_frame->AddLoopInv(item, (*outputs)[0], ready);
} else {
output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
}
output_frame->num_pending_inputs--;
}
is_frame_done =
input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready);
} else if (item->is_exit) {
//对exit 节点特殊处理
if (is_dead) {
mutex_lock l(input_frame->mu);
// Stop and remember this node if it is a dead exit.
//如果是个dead exit,而且当前所在地iteration是当前的frame里走在最前面的iteration,暂时存
//储下来,之后在销这
//首次进入下一轮时销毁,或一直传到外层frame(如果当前的whileloop 在一个条件分支里。在本
//在当前的FrameState被销毁时才处理,
//因为只有到最后才知道,是否时真的dead,比如说,一个whileloop循环了两次,那么这个循环里
//的 exits 节点第一个 iteration里的实例化出来的taggednode 是dead,但是第二个iteration的
//实例化出的taggednode就不是dead。
if (input_iter == input_frame->iteration_count) {
input_frame->dead_exits.push_back(node);
}
//上锁同步,减小未决的op个数,并判断这个frame是否计算完成,如果是,在后面会销毁这个
//framestate
is_frame_done = input_frame->DecrementOutstandingOpsLocked(
&impl_->gview_, input_iter, ready);
} else {
//如果是个live exist 节点,向上层frame 传递 exist 对应的tensor
output_frame = input_frame->parent_frame;
output_iter = input_frame->parent_iter;
{
mutex_lock l(output_frame->mu);
output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
}
is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_,
input_iter, ready);
}
} else {
//nextiteration 节点特殊处理
DCHECK(IsNextIteration(node));
//注意这里上锁了,以下需要修改一些线程共享变量
mutex_lock l(input_frame->mu);
if (is_dead) {
//如果传入的tensor是dead,不创建下一个iteration,循环到此为止,由此可知,nextiteration的
//一个作用是截断循环。
// Stop the deadness propagation.
output_frame = nullptr;
} else {
if (input_iter == input_frame->iteration_count &&
input_frame->num_outstanding_iterations ==
input_frame->max_parallel_iterations) {
//如果当前是走在最前面的iteration,而且未决的iteration达到了最大并行的iteration个数,
//那么,把这个tensor暂时保存下来,等iteration 并行度降低了再启动下一个iteration。
// Reached the maximum for parallel iterations.
input_frame->next_iter_roots.push_back({node, (*outputs)[0]});
output_frame = nullptr;
} else {
//如果当前是走在最前面的iteration,
// If this is a new iteration, start it.
if (input_iter == input_frame->iteration_count) {
//这个方法会创建一个IterationState,把input_frame->iteration_count加1。
//另外会清空dead_exits,因为已经确定有下一轮了,没必要保存当前轮的dead exit;
//还会激活滞留在next_iter_roots里的nextiteration 节点。
input_frame->IncrementIteration(&impl_->gview_, ready);
}
//iter加1。
output_iter = input_iter + 1;
}
}
if (output_frame != nullptr) {
// This is the case when node is not Enter, Exit, or NextIteration.
DCHECK(input_frame == output_frame);
output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
}
is_frame_done = input_frame->DecrementOutstandingOpsLocked(
&impl_->gview_, input_iter, ready);
}
//如果当前的frame完成了(while loop 循环结束),可能使得外层的frame也计算结束,
//递归的删除frameState、iterationstate.详情请看CleanupFramesIterations。
// At this point, this node is completely done. We also know if the
// completion of this node makes its frame completed.
if (is_frame_done) {
FrameState* parent_frame = input_frame->parent_frame;
const int64 parent_iter = input_frame->parent_iter;
DeleteFrame(input_frame, ready);
if (parent_frame != nullptr) {
// The completion of frame may cause completions in its parent frame.
// So clean things up recursively.
CleanupFramesIterations(parent_frame, parent_iter, ready);
}
}
}
ActivateNodes
这里对计算完的这个taggednode,循环每一条输出边,把修改这条边的终结点的pending和dead 计数,拷贝输出作为这个节点的输入,如果这个dst 节点is ready 则放如ready 队列。
void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
const bool is_dead, int64 iter,
EntryVector* outputs,
TaggedNodeSeq* ready) {
//GraphView 和 NodeItem 分别表示计算图和图中的一个节点,
//在计算图执行的过程中不会更改状态
const GraphView& gview = executor->gview_;
//拿到iter 对应的IterationState
IterationState* iter_state = GetIteration(iter);
// 从NodeItem 拿到输出边的信息
const size_t num_output_edges = item->num_output_edges;
const EdgeInfo* edges = item->output_edge_list();
//从iter_state 拿到Entry数组 的起始指针
Entry* input_tensors = iter_state->input_tensors;
//对ItemNode的每一个输出边循环
for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
const EdgeInfo& e = edges[out_index];
//输出边指向的节点
const int dst_id = e.dst_id;
const NodeItem* dst_item = gview.node(dst_id);
//一个handle,用来从PendingCounts 结构里存取counts
//一个Iteration 维持一个PeningCounts 对象,用来保存NodeItem 在本次iteration 的
//pendingcount数据
const PendingCounts::Handle dst_pending_id = dst_item->pending_id;
const int src_slot = e.output_slot;
// TODO(yuanbyu): We don't need this if we require the subgraph
// given to an executor not to contain a sink node.
if (dst_item->is_sink) continue;
bool dst_dead = false;
bool dst_ready = false;
// True iff this input for dst is needed. We only set this input for
// dst if this flag is true. This is needed to make the thread safety
// analysis happy.
const bool is_control_edge = (src_slot == Graph::kControlSlot);
//是否需要这个输入数据
bool dst_need_input = !is_control_edge;
//如果dst是个merge 节点,这里特殊处理
if (dst_item->is_merge) {
//一个merge节点是ready,如果所有的控制边都到齐,两个输入到了一个alive或两个都是dead
// A merge node is ready if all control inputs have arrived and either
// a) a live data input becomes available or b) all data inputs are
// dead. For Merge, pending's LSB is set iff a live data input has
// arrived.
if (is_control_edge) {
//如果是控制边,pending 减去2,pending第一位有其他的用处
iter_state->decrement_pending(dst_pending_id, 2);
int count = iter_state->pending(dst_pending_id);
int dead_cnt = iter_state->dead_count(dst_pending_id);
//两个输入都到齐了且都是dead
dst_dead = (dead_cnt == dst_item->num_inputs);
//pending第一位用来标识是否有一个活的数据到了,其他位用来对控制边计数
//一个merge节点是ready,如果所有的控制边都到齐,
//两个输入到了一个alive(count == 0)或两个都是dead ((count == 1) && dst_dead)
dst_ready = (count == 0) || ((count == 1) && dst_dead);
} else {
//如果是数据边,而且有数据输入
if ((*outputs)[src_slot].has_value) {
// This is a live data input.
//先取出pending
int count = iter_state->pending(dst_pending_id);
//把pending 的最低位清0
iter_state->mark_live(dst_pending_id);
// Only the first live edge sets the input and (potentially)
// triggers execution. The low bit of count is set if and
// only if no live input has been used yet (mark_live clears
// it). The node should be started if and only if this is
// the first live input and there are no pending control
// edges, i.e. count == 1.
dst_ready = (count == 1);
//第一个数据输入设置数据,如果之前已经有一个数据输入了,就不需要拷贝数据
dst_need_input = ((count & 0x1) == 1);
} else {
// This is a dead data input. Note that dst_node is dead if node is
// a dead enter. We need this to handle properly a while loop on
// the untaken branch of a conditional.
// TODO(yuanbyu): This is a bit hacky, but a good solution for
// now.
//如果是一个dead 输入,deadcount 加1
iter_state->increment_dead_count(dst_pending_id);
//取出deadcount
const int dead_cnt = iter_state->dead_count(dst_pending_id);
//如果deadcount 和数据输入的个数相等或输入边的起始节点item 是一个enter
//这里对enter特殊考虑,是因为,如果一个whileloop 嵌套在一个cond里面,如果这个
//whileloop 在一个untaken 分支,那么这个whileloop 里的merge 连 enter 输入都是dead
//且没有其他的输入值到达。
dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
dst_need_input = false;
}
}
} else {
//如果是非merge 节点,判断是否是个dead 输入,对pendingcounts的dead 计数
const bool increment_dead =
(is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
int pending, dead;
//调整pendincounts,pending 减1,dead加1或不加1
iter_state->adjust_for_activation(dst_pending_id, increment_dead,
&pending, &dead);
dst_dead = (dead > 0);
dst_ready = (pending == 0);
}
if (dst_need_input) {
//如果需要数据,则把数据拷贝到iterationstate里相应的槽位
const int dst_slot = e.input_slot;
const int dst_loc = dst_item->input_start + dst_slot;
if (e.is_last) {
//如果这是这个输出值的最后一个拷贝,直接移动拷贝
input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
} else {
input_tensors[dst_loc] = (*outputs)[src_slot];
}
}
// 如果dest 节点ready,构造一个tagednode,放入ready 队列。
//tagnode 是<ItemNode,iter,isdead>的三元组,标识一个节点的一次计算
// Add dst to the ready queue if it's ready
if (dst_ready) {
if (dst_item->is_control_trigger) dst_dead = false;
ready->push_back(TaggedNode(dst_item->node, this, iter, dst_dead));
iter_state->outstanding_ops++;
}
}
}
NodeDone
计算的收尾工作,如果ready 不为空,继续循环,否则结束这个线程。
bool ExecutorState::NodeDone(const Status& s, const Node* node,
const TaggedNodeSeq& ready,
NodeExecStatsWrapper* stats,
TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
if (stats_collector_ != nullptr && !SetTimelineLabel(node, stats)) {
// Only record non-transfer nodes.
// Transfers 'stats' ownership to 'stats_collector_'.
stats_collector_->Save(impl_->params_.device->name(), stats);
} else if (stats) {
delete stats;
}
bool abort_run = false;
if (!s.ok()) {
// Some error happened. This thread of computation is done.
mutex_lock l(mu_);
//如果是这个excutor第一次执行出错,标识这个状态,因为一个excutor
//对应一个线程池,在当前线程去检查和改变全局变量status_ 需要上锁
if (status_.ok()) {
abort_run = true;
status_ = s;
}
}
//abort,清理一些通信资源
if (abort_run) {
TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
if (rendezvous_) {
rendezvous_->StartAbort(s);
}
if (collective_executor_) {
collective_executor_->StartAbort(s);
}
if (cancellation_manager_) {
cancellation_manager_->StartCancel();
}
}
bool completed = false;
const size_t ready_size = ready.size();
//如果这个节点计算成功且没有触发其他节点,或这个计算节点出错
//判断当前节点是否是当前唯一的未决节点,如果是,把completed设置成true
//这里的num_outstanding_ops_.fetch_sub(1)语义应该等同于i--,但是是原子的。
if (ready_size == 0 || !s.ok()) {
completed = (num_outstanding_ops_.fetch_sub(1) == 1);
} else if (ready_size > 1) {
//设置全局变量,未决的taggednode增加ready_size - 1,这里减一是因为自身已经计算完成
num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
}
// Schedule the ready nodes in 'ready'.
if (s.ok()) {
//如果触发了其他的taggednode(这个节点的输出是这些节点等待的最后一个输入),调度这些节点
//回到ScheduleReady
ScheduleReady(ready, inline_ready);
}
//返回这个excutor的全部计算是否完成。
return completed;
}