tensorflow control flow 3 --- excutor.cc 源码解读

executor和 direct session源码解读

目录

executor和 direct session源码解读

excutor.cc


这篇博客主要从 C++ 源码角度,讲tensorlfow 运行时 对control flow  原语的特殊处理。要完全理解这部分源码,需要对tensorflow control flow 的原理有些了解,建议先看我上一篇博客。

excutor.cc

先来看逻辑清晰一点的excutor.cc。直观地理解,运行时根据可用的资源,把计算图分割成若干个子图,每个子图绑定到一个device(一个cpu/gpu/tpu..etc),对应一个excutor 。excutor.cc这部分代码被单机模式(direct session)和分布式模式(distributed session)共用。

以下的描述可以参考这张图。

在读代码之前,我先列出类excutor.cc内的几个重要的类。

  • struct EdgeInfo 用来表示NodeItem里的输出边
  • struct NodeItem 表示一个op 节点
  • class GraphView // Immutable view of a Graph organized for efficient execution.
  • public Executor ExecutorImpl的基类,一个计算子图绑定一个device,对应一个Executor。
  • class ExecutorImpl : public  Executor excutor的具体实现
  • ExecutorImpl::struct ControlFlowInfo 这个类的对象维护nodeid 对应的Frame name ,一个ControlFlowInfo对象会在ExecutorImpl初始化的时候创建
  • ExecutorImpl::struct  FrameInfo   这个类的对象用来表示一个Frame(while loop)的静态信息。其中最重要的是 一个 pendingcounts数据结构的属性,该数据结构以内存紧凑的方式维护frame中的每一个node的初始pending count(该节点的输入个数及依赖边的个数),这个pendingcount 被用来初始化 FrameInfo 对应的 FrameState里的FrameState Iteration 的pendingcounts属性值。在ExecutorImpl初始化的时候,一系列的FrameInfo会被创建(准确地说,是一个Frame 到FrameInfo的map,map的每一项对应计算图中一个whileloop)
  • class ExecutorState  用源码里的注释描述,// The state associated with one invocation of ExecutorImpl::Run.
  • ExecutorState::struct IterationState  IterationState 维护FrameState(也就是一个实例化的whileloop) 里的一次迭代相关的状态变量。其中最重要的有两个。一个是pendingcounts,这个pendingcounts 的初始值是这个IterationState 所在FrameState对应的FrameInfo里的pendingcounts值,随着计算的过程,这个pendingcounts 会被修改;另一个是inputs数组,这里会存储这个iteration 里所有计算节点的输入。计算节点的输出会复制到所有消费这个输出的节点在inputs 对应的槽位中(可能复制多份)。存在inputs数组里面的tensorvalue 在被消费完后会被销毁。
  •  ExecutorState::struct FrameState  FrameState表示一个Frame的实例化状态。因为Frame 可以嵌套,嵌套在一个whileloop 里的whileloop可能会被多次实例化,一次实例化对应一个FrameState,也就是说一个whileloop,只对应一个FrameInfo,但是可能会实例化多个FrameState。FrameState在图计算的过程中动态地创建,修改和销毁。FrameState维护的重要状态是 IterationState数组。
  • ExecutorState::struct TaggedNode //A tagged node: <frame*, iter, node*,isdead>。表示一个op的实例。Frame里的一个op在Frame的每一次iteration 都会实例化一次。

接下来,讲述一下excutor的运行逻辑。

direct session 调用excutor 的虚函数RunAsync。由RunAsync开始:

  • ExecutorImpl::RunAsync                                                                                
    • ExecutorState::RunAsync          //         (new ExecutorState(args, this))->RunAsync(std::move(done))   
      • ExecutorState::ScheduleReady        //             ScheduleReady(ready, nullptr)
        • runner_            //           runner_([=]() { Process(tagged_node, scheduled_nsec); })
          • process       //          runner_([=]() { Process(tagged_node, scheduled_nsec); })
            • PrepareInputs    // PrepareInputs(item, first_input, &inputs, &input_device_contexts,
                                      &input_alloc_attrs, &is_input_dead)
            • Compute // device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
            • ProcessOutputs // ProcessOutputs(item, &ctx, &outputs, stats);
            • PropagateOutputs //  PropagateOutputs(tagged_node, &item, &outputs, &ready);
              • ExecutorState::FrameState::ActivateNodes
            • NodeDone // NodeDone(s, item.node, ready, stats, &inline_ready)

ExecutorState::RunAsync初始化运行,把计算图入度为0的opnode放入一个tagednodeReadyQueue;调用ExecutorState::ScheduleReady开始调度。ExecutorState::ScheduleReady里面把process方法包装成一个lambda传入runner_,把任务提交给线程池,runner_是ExecutorState的一个属性,由ExecutorState构造函数传入,是一个std::function<void(std::function<void()>)> 对象,负责调度任务。进入process,process 把taggednode 放入一个队列,readyqueue。进入while循环,在循环中依次调用PrepareInputsComputeProcessOutputsPropagateOutputsNodeDone,PropagateOutputs 可能在队列里加入新的元素。循环直到readyqueue为空

以上的调用链源码分析已经有人写过博客了,详情请看这位高人的博客,我这里着重解析和控制流相关的部分,主要集中在process里。以下我主要分析PrepareInput、ProcessOutputs、PropagateOutputs、NodeDone这5个方法。

PrepareInputs

源代码及我加的注释

prepareinputs 的作用,如名字所表达地那样,从IterationState里把输入地tenser及其一些附带信息取出。

//主要逻辑是把输入从first_input(数组中的一个)开始的片段中拷贝到inputs,并收集一些其他的信息。
//在调用的时候,first_input指向IterationState中inputs数组中的一个元素
//在process里调用结束后,inputs、input_device_contexts、input_alloc_attrs、is_input_dead
//这些数据会传递给opkernelcontext,在Kernel::compute里被使用。
Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
                                    TensorValueVec* inputs,
                                    DeviceContextVec* input_device_contexts,
                                    AllocatorAttributeVec* input_alloc_attrs,
                                    bool* is_input_dead) {
  const Node* node = item.node;

  inputs->clear();
  inputs->resize(item.num_inputs);
  input_device_contexts->clear();
  input_device_contexts->resize(item.num_inputs);
  input_alloc_attrs->clear();
  input_alloc_attrs->resize(item.num_inputs);

  *is_input_dead = false;

  bool is_merge = item.is_merge;
  for (int i = 0; i < item.num_inputs; ++i) {
    //从该taggednode对应的NodeItem获取一些输入值的属性
    const bool expect_ref = IsRefType(item.input_type(i));
    Entry* entry = first_input + i;
    (*input_device_contexts)[i] = entry->device_context;
    (*input_alloc_attrs)[i] = entry->alloc_attr;
    // i-th input.
    TensorValue* inp = &(*inputs)[i];
    //校验,PrepareInputs在process被调用,只有merge 和 transfer 节点才可能执行到这,因为merge
    //本身容许 无值的输入,而运输节点 需要在子图之间传递 isdead 标识(详情请看我的上一篇博客),
    // Only merge and transfer nodes can have no-value inputs.
    if (!entry->has_value) {
      if (!is_merge) {
        DCHECK(IsTransferNode(node)) << node->name() << " - input " << i;
        DCHECK(!entry->val_field_is_set) << node->name() << " - input " << i;
        entry->has_value = true;
        entry->val_field_is_set = true;
        entry->val.Init(*kEmptyTensor);
        inp->tensor = entry->val.get();
        *is_input_dead = true;
      }
      continue;
    }
    //拷贝数据,考虑引用和非引用之间的拷贝。
    if (entry->ref == nullptr) {
      if (expect_ref) {
        return AttachDef(
            errors::InvalidArgument(i, "-th input expects a ref type"),
            item.kernel->def());
      }
      inp->tensor = entry->val.get();
    } else {
      {
        mutex_lock ml(*entry->ref_mu);
        if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
          return AttachDef(errors::FailedPrecondition(
                               "Attempting to use uninitialized value ",
                               item.kernel->requested_input(i)),
                           item.kernel->def());
        }
      }
      if (expect_ref) {
        inp->mutex_if_ref = entry->ref_mu;
        inp->tensor = entry->ref;
      } else {
        // Automatically deref the tensor ref when the op expects a
        // tensor but is given a ref to a tensor.  Need to deref it
        // under the mutex.
        {
          mutex_lock l(*(entry->ref_mu));
          DCHECK(!entry->val_field_is_set);
          entry->val.Init(*entry->ref);
          entry->val_field_is_set = true;
        }
        entry->ref = nullptr;
        entry->ref_mu = nullptr;

        inp->tensor = entry->val.get();
        // The dtype of entry->ref could have been changed by another operation
        // that ran after the operation that "produced" it executed, so
        // re-validate that the type of the dereferenced tensor matches the
        // expected input type.
        if (item.input_type(i) != inp->tensor->dtype()) {
          return AttachDef(
              errors::InvalidArgument(
                  i, "-th input expects type ",
                  DataTypeString(item.input_type(i)),
                  " but automatically dereferenced input tensor has type ",
                  DataTypeString(inp->tensor->dtype())),
              item.kernel->def());
        }
      }
    }
  }
  return Status::OK();
}

Compute

compute 调用kernel 执行计算

ProcessOutputs     

processoutput把计算地opkernelcontext里地输出拷贝出来   

//Kernel 运行结束,从opKernelContext获得计算的输出,暂时存储在outputs指向的数组。
Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
                                     EntryVector* outputs,
                                     NodeExecStatsWrapper* stats) {
  const Node* node = item.node;
  DCHECK_EQ(0, outputs->size());
  outputs->resize(item.num_outputs);

  Status s = ctx->status();
  if (!s.ok()) {
    s = AttachDef(s, item.kernel->def());
    // TODO(misard) Replace with a finer-grain enabling flag once we
    // add better optional debugging support.
    if (vlog_ && VLOG_IS_ON(1)) {
      LOG(WARNING) << this << " Compute status: " << s;
      DumpState();
    }
    if (s.code() == error::RESOURCE_EXHAUSTED) {
      if (stats_collector_) {
        string err = stats_collector_->ReportAllocsOnResourceExhausted(
            s.error_message());
        s = Status(s.code(), strings::StrCat(s.error_message(), err));
      } else {
        s = Status(
            s.code(),
            strings::StrCat(
                s.error_message(),
                "\nHint: If you want to see a list of allocated tensors when "
                "OOM happens, add report_tensor_allocations_upon_oom "
                "to RunOptions for current allocation info.\n"));
      }
    }
    return s;
  }

  // Get the device_context for this node id, if it exists.
  DeviceContext* device_context = nullptr;
  if (node->id() < device_context_map_.size()) {
    device_context = device_context_map_[node->id()];
  }

  for (int i = 0; i < item.num_outputs; ++i) {
    const TensorValue val = ctx->release_output(i);
    if (val.tensor == nullptr) {
      // Unless it's a Switch or a Recv, the node must produce a
      // tensor value at i-th output.
      if (!IsSwitch(node) && !IsRecv(node)) {
        s.Update(errors::Internal("Missing ", i, "-th output from ",
                                  SummarizeNode(*node)));
      }
    } else {
      Entry* out = &((*outputs)[i]);

      // Set the device context of the output entry.
      out->device_context = device_context;

      // Set the allocator attributes of the output entry.
      out->alloc_attr = ctx->output_alloc_attr(i);

      // Sanity check of output tensor types.
      DataType dtype;
      if (val.is_ref()) {
        mutex_lock ml(*val.mutex_if_ref);
        dtype = MakeRefType(val->dtype());
      } else {
        dtype = val->dtype();
      }
      if (dtype == item.output_type(i)) {
        if (stats && val.tensor->IsInitialized()) {
          nodestats::SetOutput(stats, i, val.tensor);
        }
        if (val.is_ref()) {
          out->has_value = true;
          out->ref = val.tensor;
          out->ref_mu = val.mutex_if_ref;
          if (log_memory_) {
            Tensor to_log;
            {
              // Dereference the tensor under the lock.
              mutex_lock l(*out->ref_mu);
              to_log = *out->ref;
            }
            LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
                                          ctx->step_id(), i, to_log);
          }
        } else {
          // NOTE that std::move is used here, so val.tensor goes to
          // uninitialized state (val.tensor->IsInitialized return false).
          DCHECK(!out->val_field_is_set);
          out->has_value = true;
          out->val_field_is_set = true;
          out->val.Init(std::move(*val.tensor));
          if (log_memory_) {
            LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
                                          ctx->step_id(), i, *out->val);
          }
        }
      } else {
        s.Update(errors::Internal("Output ", i, " of type ",
                                  DataTypeString(dtype),
                                  " does not match declared output type ",
                                  DataTypeString(item.output_type(i)),
                                  " for node ", SummarizeNode(*node)));
      }
    }
    if (!val.is_ref()) {
      // If OpKernelContext returns outputs via pass-by-value, we
      // don't need this trouble.
      delete val.tensor;
    }
  }
  return s;
}

PropagateOutputs,这里是对控制流特殊处理的的核心部分。netx/enter/exit 会在这特殊处理

 

//把上一步的到的输出,拷贝给需要改输出作为输入的TaggedNode,存放在taggedNode对应的IterationState 里的inputs数组。
//这里会特殊地对待控制流算子
void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
                                     const NodeItem* item, EntryVector* outputs,
                                     TaggedNodeSeq* ready) {
  const Node* node = tagged_node.node;
  FrameState* input_frame = tagged_node.input_frame;
  const int64 input_iter = tagged_node.input_iter;
  const bool is_dead = tagged_node.is_dead;
  //如源码注释所言,沿着输出边,减小对应节点的入度,如果一个tagednode满足激活条件(对于一般节 
  //点,入度为0),就把它放到ready 队列
  // Propagates outputs along out edges, and puts newly ready nodes
  // into the ready queue.
  ready->clear();
  bool is_frame_done = false;
  FrameState* output_frame = input_frame;
  int64 output_iter = input_iter;
  //控制流相关节点特殊处理
  if (!item->is_enter_exit_or_next_iter) {
    //其他节点的处理,merge节点也走这条线,merge会在ActivateNodes方法里特殊处理。
    // Fast path for nodes types that don't need special handling
    DCHECK_EQ(input_frame, output_frame);
    // Normal path for most nodes
    mutex_lock l(input_frame->mu);
    output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
    is_frame_done = input_frame->DecrementOutstandingOpsLocked(
        &impl_->gview_, input_iter, ready);
  } else if (item->is_enter) {
    //对enter 特殊处理
    bool is_constant;
    const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
    DCHECK(s.ok()) << s;
    //enter 可能需要创建一个FrameState,准确地说,如果是第一次enter一个frame,会创建一个 
    //framestate。
    FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
    output_iter = 0;
    {
      const NodeItem* item = impl_->gview_.node(node->id());
      mutex_lock l(output_frame->mu);
      if (is_constant) {
        // Propagate to all active iterations if this is a loop invariant.
        output_frame->AddLoopInv(item, (*outputs)[0], ready);
      } else {
        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
      }
      output_frame->num_pending_inputs--;
    }
    is_frame_done =
        input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready);
  } else if (item->is_exit) {
    //对exit 节点特殊处理
    if (is_dead) {
      mutex_lock l(input_frame->mu);
      // Stop and remember this node if it is a dead exit.
      //如果是个dead exit,而且当前所在地iteration是当前的frame里走在最前面的iteration,暂时存 
      //储下来,之后在销这 
      //首次进入下一轮时销毁,或一直传到外层frame(如果当前的whileloop 在一个条件分支里。在本 
      //在当前的FrameState被销毁时才处理,
      //因为只有到最后才知道,是否时真的dead,比如说,一个whileloop循环了两次,那么这个循环里
      //的 exits 节点第一个 iteration里的实例化出来的taggednode 是dead,但是第二个iteration的 
      //实例化出的taggednode就不是dead。
      if (input_iter == input_frame->iteration_count) {
        input_frame->dead_exits.push_back(node);
      }
      //上锁同步,减小未决的op个数,并判断这个frame是否计算完成,如果是,在后面会销毁这个
      //framestate
      is_frame_done = input_frame->DecrementOutstandingOpsLocked(
          &impl_->gview_, input_iter, ready);
    } else {
      //如果是个live exist 节点,向上层frame 传递 exist 对应的tensor
      output_frame = input_frame->parent_frame;
      output_iter = input_frame->parent_iter;
      {
        mutex_lock l(output_frame->mu);
        output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
      }
      is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_,
                                                           input_iter, ready);
    }
  } else {
    //nextiteration 节点特殊处理
    DCHECK(IsNextIteration(node));
    //注意这里上锁了,以下需要修改一些线程共享变量
    mutex_lock l(input_frame->mu);
    if (is_dead) {
      //如果传入的tensor是dead,不创建下一个iteration,循环到此为止,由此可知,nextiteration的 
      //一个作用是截断循环。
      // Stop the deadness propagation.
      output_frame = nullptr;
    } else {
      if (input_iter == input_frame->iteration_count &&
          input_frame->num_outstanding_iterations ==
              input_frame->max_parallel_iterations) {
        //如果当前是走在最前面的iteration,而且未决的iteration达到了最大并行的iteration个数, 
        //那么,把这个tensor暂时保存下来,等iteration 并行度降低了再启动下一个iteration。
        // Reached the maximum for parallel iterations.
        input_frame->next_iter_roots.push_back({node, (*outputs)[0]});
        output_frame = nullptr;
      } else {
        //如果当前是走在最前面的iteration,
        // If this is a new iteration, start it.
        if (input_iter == input_frame->iteration_count) {
          //这个方法会创建一个IterationState,把input_frame->iteration_count加1。
          //另外会清空dead_exits,因为已经确定有下一轮了,没必要保存当前轮的dead  exit;
          //还会激活滞留在next_iter_roots里的nextiteration 节点。
          input_frame->IncrementIteration(&impl_->gview_, ready);
        }
        //iter加1。
        output_iter = input_iter + 1;
      }
    }
    if (output_frame != nullptr) {
      // This is the case when node is not Enter, Exit, or NextIteration.
      DCHECK(input_frame == output_frame);
      output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
    }
    is_frame_done = input_frame->DecrementOutstandingOpsLocked(
        &impl_->gview_, input_iter, ready);
  }
  //如果当前的frame完成了(while loop 循环结束),可能使得外层的frame也计算结束,
  //递归的删除frameState、iterationstate.详情请看CleanupFramesIterations。
  // At this point, this node is completely done. We also know if the
  // completion of this node makes its frame completed.
  if (is_frame_done) {
    FrameState* parent_frame = input_frame->parent_frame;
    const int64 parent_iter = input_frame->parent_iter;
    DeleteFrame(input_frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
      // So clean things up recursively.
      CleanupFramesIterations(parent_frame, parent_iter, ready);
    }
  }
}

 

   ActivateNodes           

  这里对计算完的这个taggednode,循环每一条输出边,把修改这条边的终结点的pending和dead 计数,拷贝输出作为这个节点的输入,如果这个dst 节点is ready 则放如ready 队列。                                                                                           

void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
                                              const bool is_dead, int64 iter,
                                              EntryVector* outputs,
                                              TaggedNodeSeq* ready) {
  //GraphView 和 NodeItem 分别表示计算图和图中的一个节点,
  //在计算图执行的过程中不会更改状态
  const GraphView& gview = executor->gview_;
  //拿到iter 对应的IterationState
  IterationState* iter_state = GetIteration(iter);
  // 从NodeItem 拿到输出边的信息
  const size_t num_output_edges = item->num_output_edges;
  const EdgeInfo* edges = item->output_edge_list();
  //从iter_state 拿到Entry数组 的起始指针
  Entry* input_tensors = iter_state->input_tensors;
  //对ItemNode的每一个输出边循环
  for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
    const EdgeInfo& e = edges[out_index];
     //输出边指向的节点
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = gview.node(dst_id);
    //一个handle,用来从PendingCounts 结构里存取counts
    //一个Iteration 维持一个PeningCounts 对象,用来保存NodeItem 在本次iteration 的 
    //pendingcount数据
    const PendingCounts::Handle dst_pending_id = dst_item->pending_id;
    const int src_slot = e.output_slot;

    // TODO(yuanbyu): We don't need this if we require the subgraph
    // given to an executor not to contain a sink node.
    if (dst_item->is_sink) continue;

    bool dst_dead = false;
    bool dst_ready = false;
    // True iff this input for dst is needed. We only set this input for
    // dst if this flag is true. This is needed to make the thread safety
    // analysis happy.
    const bool is_control_edge = (src_slot == Graph::kControlSlot);
    //是否需要这个输入数据
    bool dst_need_input = !is_control_edge;
    //如果dst是个merge 节点,这里特殊处理
    if (dst_item->is_merge) {
      //一个merge节点是ready,如果所有的控制边都到齐,两个输入到了一个alive或两个都是dead
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      if (is_control_edge) {
        //如果是控制边,pending 减去2,pending第一位有其他的用处
        iter_state->decrement_pending(dst_pending_id, 2);
        int count = iter_state->pending(dst_pending_id);
        int dead_cnt = iter_state->dead_count(dst_pending_id);
        //两个输入都到齐了且都是dead
        dst_dead = (dead_cnt == dst_item->num_inputs);
        //pending第一位用来标识是否有一个活的数据到了,其他位用来对控制边计数
        //一个merge节点是ready,如果所有的控制边都到齐,
        //两个输入到了一个alive(count == 0)或两个都是dead ((count == 1) && dst_dead)
        dst_ready = (count == 0) || ((count == 1) && dst_dead);
      } else {
        //如果是数据边,而且有数据输入
        if ((*outputs)[src_slot].has_value) {
          // This is a live data input.
          //先取出pending
          int count = iter_state->pending(dst_pending_id);
          //把pending 的最低位清0
          iter_state->mark_live(dst_pending_id);
          // Only the first live edge sets the input and (potentially)
          // triggers execution. The low bit of count is set if and
          // only if no live input has been used yet (mark_live clears
          // it). The node should be started if and only if this is
          // the first live input and there are no pending control
          // edges, i.e. count == 1.
          dst_ready = (count == 1);
          //第一个数据输入设置数据,如果之前已经有一个数据输入了,就不需要拷贝数据
          dst_need_input = ((count & 0x1) == 1);
        } else {
          // This is a dead data input. Note that dst_node is dead if node is
          // a dead enter. We need this to handle properly a while loop on
          // the untaken branch of a conditional.
          // TODO(yuanbyu): This is a bit hacky, but a good solution for
          // now.
          //如果是一个dead 输入,deadcount 加1
          iter_state->increment_dead_count(dst_pending_id);
          //取出deadcount
          const int dead_cnt = iter_state->dead_count(dst_pending_id);
          //如果deadcount 和数据输入的个数相等或输入边的起始节点item 是一个enter
          //这里对enter特殊考虑,是因为,如果一个whileloop 嵌套在一个cond里面,如果这个
          //whileloop 在一个untaken 分支,那么这个whileloop 里的merge 连 enter 输入都是dead
          //且没有其他的输入值到达。
          dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
          dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
          dst_need_input = false;
        }
      }
    } else {
      //如果是非merge 节点,判断是否是个dead 输入,对pendingcounts的dead 计数
      const bool increment_dead =
          (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
      int pending, dead;
      //调整pendincounts,pending 减1,dead加1或不加1
      iter_state->adjust_for_activation(dst_pending_id, increment_dead,
                                        &pending, &dead);
      dst_dead = (dead > 0);
      dst_ready = (pending == 0);
    }

    if (dst_need_input) {
     //如果需要数据,则把数据拷贝到iterationstate里相应的槽位
      const int dst_slot = e.input_slot;
      const int dst_loc = dst_item->input_start + dst_slot;
      if (e.is_last) {
        //如果这是这个输出值的最后一个拷贝,直接移动拷贝
        input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
      } else {
        input_tensors[dst_loc] = (*outputs)[src_slot];
      }
    }
    // 如果dest 节点ready,构造一个tagednode,放入ready 队列。
    //tagnode 是<ItemNode,iter,isdead>的三元组,标识一个节点的一次计算
    // Add dst to the ready queue if it's ready
    if (dst_ready) {
      if (dst_item->is_control_trigger) dst_dead = false;
      ready->push_back(TaggedNode(dst_item->node, this, iter, dst_dead));
      iter_state->outstanding_ops++;
    }
  }
}

NodeDone

计算的收尾工作,如果ready 不为空,继续循环,否则结束这个线程。

bool ExecutorState::NodeDone(const Status& s, const Node* node,
                             const TaggedNodeSeq& ready,
                             NodeExecStatsWrapper* stats,
                             TaggedNodeReadyQueue* inline_ready) {
  nodestats::SetAllEnd(stats);
  if (stats_collector_ != nullptr && !SetTimelineLabel(node, stats)) {
    // Only record non-transfer nodes.
    // Transfers 'stats' ownership to 'stats_collector_'.
    stats_collector_->Save(impl_->params_.device->name(), stats);
  } else if (stats) {
    delete stats;
  }

  bool abort_run = false;
  if (!s.ok()) {
    // Some error happened. This thread of computation is done.
    mutex_lock l(mu_);
    //如果是这个excutor第一次执行出错,标识这个状态,因为一个excutor
    //对应一个线程池,在当前线程去检查和改变全局变量status_ 需要上锁
    if (status_.ok()) {
      abort_run = true;
      status_ = s;
    }
  }
  //abort,清理一些通信资源
  if (abort_run) {
    TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
    if (rendezvous_) {
      rendezvous_->StartAbort(s);
    }
    if (collective_executor_) {
      collective_executor_->StartAbort(s);
    }
    if (cancellation_manager_) {
      cancellation_manager_->StartCancel();
    }
  }

  bool completed = false;
  const size_t ready_size = ready.size();
  //如果这个节点计算成功且没有触发其他节点,或这个计算节点出错
  //判断当前节点是否是当前唯一的未决节点,如果是,把completed设置成true
  //这里的num_outstanding_ops_.fetch_sub(1)语义应该等同于i--,但是是原子的。
  if (ready_size == 0 || !s.ok()) {
    completed = (num_outstanding_ops_.fetch_sub(1) == 1);
  } else if (ready_size > 1) {
    //设置全局变量,未决的taggednode增加ready_size - 1,这里减一是因为自身已经计算完成
    num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
  }

  // Schedule the ready nodes in 'ready'.
  if (s.ok()) {
    //如果触发了其他的taggednode(这个节点的输出是这些节点等待的最后一个输入),调度这些节点
    //回到ScheduleReady
    ScheduleReady(ready, inline_ready);
  }
  //返回这个excutor的全部计算是否完成。
  return completed;
}

 

 

 

 

 

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值