torch autograd C++ source code: the main backbone of the implementation (most of it lives in Engine.cpp under torch/csrc)

struct Edge {
  /// The function this `Edge` points to.
  std::shared_ptr<Node> function;

  /// The identifier of a particular input to the function.
  uint32_t input_nr;
};
struct Node {
  // Sequence number used to correlate backward nodes with forward ops in the
  // profiler and provide determinism in the engine.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const uint64_t sequence_nr_;
  edge_list next_edges_;  // the Edges leading to the next Nodes in the backward graph
};
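A quick way to see Edge and Node from the C++ frontend is to walk the backward graph of a tiny expression by following the first Edge of every Node. This is only a sketch assuming a LibTorch build; grad_fn(), name(), next_edges() and sequence_nr() are the real APIs, the rest is just a demo.

#include <torch/torch.h>
#include <torch/csrc/autograd/function.h>  // full definition of torch::autograd::Node
#include <iostream>

int main() {
  auto x = torch::randn({2}, torch::requires_grad());
  auto y = (x * 2).sum();  // forward builds: SumBackward0 -> MulBackward0 -> AccumulateGrad
  std::shared_ptr<torch::autograd::Node> node = y.grad_fn();
  while (node) {
    std::cout << node->name() << " (sequence_nr=" << node->sequence_nr() << ")\n";
    const auto& edges = node->next_edges();  // edge_list
    if (edges.empty()) break;
    node = edges[0].function;                // follow Edge::function
  }
}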
// Note: "inputs" and "outputs" here are reversed relative to the forward pass.
auto Engine::execute(const edge_list& roots,      // edges of the forward outputs (e.g. the loss)
                     const variable_list& inputs, // seed gradients fed into the roots (grad_outputs)
                     bool keep_graph,             // keep the graph alive after this backward pass
                     bool create_graph,           // build a graph of the gradient computation itself (higher-order grads)
                     bool accumulate_grad,        // true for backward() (accumulate into .grad), false for autograd.grad()
                     const edge_list& outputs     // edges of the tensors whose gradients were requested
                     ) {
  init_local_ready_queue();
  auto graph_task = std::make_shared<GraphTask>(
      /* keep_graph */ keep_graph,
      /* create_graph */ create_graph,
      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* cpu_ready_queue */ local_ready_queue);
  compute_min_topological_nr // smallest topological nr among the output edges; nodes created before it cannot reach any output and are skipped
  compute_dependencies       // count, for every Node, how many incoming gradients it is waiting for
  init_to_execute            // fill exec_info_ to mark which nodes actually need to run (only when `outputs` is non-empty)
  execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
}
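How these arguments map onto the user-facing entry points, as a hedged LibTorch sketch: Tensor::backward() runs the engine with accumulate_grad = true and an empty `outputs`, so gradients are accumulated into .grad; torch::autograd::grad() passes the edges of the requested tensors as `outputs` with accumulate_grad = false, so the captured gradients are returned instead.

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({3}, torch::requires_grad());

  // backward(): roots = the loss's gradient edge, accumulate_grad = true,
  // `outputs` is empty and the result lands in x.grad()
  auto loss = (x * x).sum();
  loss.backward();
  std::cout << x.grad() << "\n";

  // autograd::grad(): the same engine call, but accumulate_grad = false and
  // `outputs` holds the edges of {x}; gradients are captured and returned
  auto loss2 = (x * x).sum();
  auto grads = torch::autograd::grad(/*outputs=*/{loss2}, /*inputs=*/{x});
  std::cout << grads[0] << "\n";
}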


c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  // graph_task->cpu_ready_queue_ is local_ready_queue; pick the queue for the
  // device that the root's input buffer lives on
  auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());
  queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
  thread_main(graph_task);
}
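The push() here and the pop() in thread_main below form a blocking producer/consumer queue; the real ReadyQueue wraps a std::priority_queue (ordered roughly by reentrant depth, then sequence_nr_) behind a mutex and a condition variable. A simplified FIFO stand-in (all names here are hypothetical):

#include <condition_variable>
#include <deque>
#include <mutex>
#include <utility>

template <typename Task>
struct SimpleReadyQueue {           // FIFO stand-in for the real priority-based ReadyQueue
  std::mutex mutex_;
  std::condition_variable not_empty_;
  std::deque<Task> queue_;

  void push(Task task) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push_back(std::move(task));
    }
    not_empty_.notify_one();        // wake one worker blocked in pop()
  }

  Task pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    not_empty_.wait(lock, [this] { return !queue_.empty(); });
    Task task = std::move(queue_.front());
    queue_.pop_front();
    return task;                    // thread_main receives a runnable task
  }
};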


auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
  while (graph_task == nullptr || !graph_task->future_result_->completed()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      // local_ready_queue was set up by init_local_ready_queue (in thread_init()
      // or reentrant_thread_init())
      NodeTask task = local_ready_queue->pop();
      if (!(local_graph_task = task.base_.lock())) {  // weak_ptr -> shared_ptr
        // GraphTask for function is no longer valid, skipping further
        // execution.
        continue;
      }
      evaluate_function(
          local_graph_task,
          task.fn_.get(),
          task.inputs_,
          local_graph_task->cpu_ready_queue_);
    }
  }
}
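The task.base_.lock() check works because NodeTask only holds a weak_ptr to its GraphTask: worker threads and their queues outlive a single backward() call, so a stale task must not keep a finished GraphTask alive. A minimal illustration with a stand-in struct:

#include <iostream>
#include <memory>

struct FakeGraphTask { int id; };     // stand-in for torch::autograd::GraphTask

int main() {
  std::weak_ptr<FakeGraphTask> base;  // what NodeTask::base_ stores
  {
    auto graph_task = std::make_shared<FakeGraphTask>(FakeGraphTask{42});
    base = graph_task;
    if (auto alive = base.lock())     // lock() converts weak_ptr -> shared_ptr
      std::cout << "running task for graph " << alive->id << "\n";
  }                                   // backward call finished, GraphTask destroyed
  if (!base.lock())
    std::cout << "GraphTask no longer valid, skipping execution\n";
}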
void Engine::evaluate_function(  // hook handling omitted
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  auto outputs = call_function(graph_task, func, inputs);
  // For each next edge: decrement the consumer's count in graph_task->dependencies_;
  // when it reaches 0, erase the entry and mark the consumer as ready. Depending on
  // readiness, the gradient either becomes a NodeTask on a ready queue or is parked
  // in graph_task->not_ready_.
  int num_outputs = outputs.size();
  for (const auto i : c10::irange(num_outputs)) {
    auto& output = outputs[i];
    const auto& next = func->next_edge(i);
    // If the consumer already has a buffer in not_ready_, reuse it; otherwise a
    // fresh one is created with InputBuffer input_buffer(next.function->num_inputs());
    auto& input_buffer = not_ready_it->second;
    input_buffer.add(next.input_nr,
                     std::move(output),
                     opt_parent_stream,
                     opt_next_stream);
  }
}
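The bookkeeping summarized in the comment above, as a self-contained toy (ToyNode, deliver, etc. are simplified stand-ins, not the real engine types): a consumer becomes ready only once every incoming edge has delivered its gradient into the right buffer slot.

#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyNode { std::string name; int num_inputs; };

int main() {
  ToyNode mul{"MulBackward", 2}, acc{"AccumulateGrad", 1};
  // dependencies_: how many gradients each node is still waiting for
  std::unordered_map<ToyNode*, int> dependencies{{&mul, 2}, {&acc, 1}};
  // not_ready_: partially filled input buffers, keyed by the consumer node
  std::unordered_map<ToyNode*, std::vector<double>> not_ready;
  std::queue<ToyNode*> ready_queue;

  auto deliver = [&](ToyNode* next, int input_nr, double grad) {
    auto& buffer =
        not_ready.try_emplace(next, next->num_inputs, 0.0).first->second;
    buffer[input_nr] += grad;          // analogue of InputBuffer::add
    if (--dependencies[next] == 0) {   // last pending input just arrived
      ready_queue.push(next);          // becomes a NodeTask on the ready queue
      std::cout << next->name << " is ready\n";
    }
  };

  deliver(&mul, 0, 1.0);
  deliver(&mul, 1, 2.0);   // second delivery makes MulBackward ready
  deliver(&acc, 0, 3.0);
}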
static variable_list call_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputBuffer) {
  auto& fn = *func;
  auto inputs = InputBuffer::variables(std::move(inputBuffer));
  // the actual backward computation: invoke the Node on its accumulated inputs
  auto outputs = fn(std::move(inputs));
}
init_to_execute: marking which Nodes need to run
Method 1: recursion (the algorithm sketched in a comment in the source):
  // is_needed = {fn: True for fn in outputs}             # (0)
  // seen = {}
  // def compute_is_needed(fn):
  //   for next_edge in fn.next_edges:
  //     child_fn = next_edge.fn
  //     if child_fn in seen and is_needed[child_fn]:     # (1)
  //       is_needed[fn] = true
  //     else:
  //       seen.add(child_fn)
  //       if compute_is_needed(child_fn):
  //         is_needed[fn] = true                         # (2)
  //                                                      # (3) exit for-loop
  //   return is_needed[fn]
  // compute_is_needed(graph_root)
Method 2: iterative, with an explicit stack (the actual implementation):
  while (!stack.empty()) {
    auto &frame = stack.back();
    const auto fn = frame.fn_;

    Node *child_fn = nullptr;
    while((child_fn = frame.get_next_fn()) && !seen.emplace(child_fn).second) {
      // (1) next child exists AND has already been seen
      if (nodeShouldExecute(child_fn)) {
        exec_info_[fn].needed_ = true;
      }
    }

    if (child_fn) {
      // (2) next child exists but has not been seen
      if (child_fn->topological_nr() < min_topo_nr) {
        // child created before the first output means this child cannot have
        // an edge to output
        continue;
      }
      stack.emplace_back(child_fn);
    } else {
      // (3) no next child exists for `fn` means its `needed` has already been
      // finalized. pop stack and update parent
      stack.pop_back();
      if (nodeShouldExecute(fn) && !stack.empty()) {
        exec_info_[stack.back().fn_].needed_ = true;
      }
    }
  }
}
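A runnable toy version of method 1 on a three-node graph, to see the marking outside the engine (TinyNode and everything else here are hypothetical stand-ins for the real autograd types):

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct TinyNode { std::string name; std::vector<TinyNode*> next_edges; };

int main() {
  TinyNode leaf_a{"AccumulateGrad(a)", {}}, leaf_b{"AccumulateGrad(b)", {}};
  TinyNode mul{"MulBackward", {&leaf_a, &leaf_b}};
  TinyNode sum{"SumBackward", {&mul}};

  // Gradients were requested only for `a`, so only leaf_a starts as needed
  // (this plays the role of the `outputs` edges / exec_info_ seeding).
  std::unordered_map<TinyNode*, bool> is_needed{{&leaf_a, true}};
  std::unordered_set<TinyNode*> seen;

  std::function<bool(TinyNode*)> compute_is_needed = [&](TinyNode* fn) {
    for (TinyNode* child : fn->next_edges) {
      bool child_needed;
      if (seen.count(child)) {
        child_needed = is_needed[child];          // (1) child already finalized
      } else {
        seen.insert(child);
        child_needed = compute_is_needed(child);
      }
      if (child_needed) is_needed[fn] = true;     // (2) any needed child marks fn
    }
    return is_needed[fn];
  };
  compute_is_needed(&sum);

  for (auto* n : {&sum, &mul, &leaf_a, &leaf_b})
    std::cout << n->name << " needed=" << is_needed[n] << "\n";
}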
struct InputMetadata   // records the size, device, dtype, ... of an input
struct InputBuffer {   // accumulates a list of Variables for use by a function
  std::vector<Variable> buffer;
};
struct NodeTask {      // one unit of work popped from a ReadyQueue
  std::weak_ptr<GraphTask> base_;
  std::shared_ptr<Node> fn_;
  InputBuffer inputs_;
};
// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask : std::enable_shared_from_this<GraphTask> {
  std::unordered_map<Node*, ExecInfo> exec_info_;
  std::vector<Variable> captured_vars_;  // gradients captured for autograd.grad()
};
class C10_API Stream final   // c10::Stream: device-stream handle the engine uses for stream synchronization