// Excerpt of Edge: identifies a target Node (`function`) together with the
// index of the input of that Node the edge feeds (`input_nr`).
Edge{
/// The function this `Edge` points to.
std::shared_ptr<Node> function;
/// The identifier of a particular input to the function.
uint32_t input_nr;
}
// Excerpt of Node showing only the two members these notes rely on.
Node{
// Sequence number used to correlate backward nodes with forward ops in the
// profiler and provide determinism in the engine.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
const uint64_t sequence_nr_;
// Edges leaving this node (an edge_list).
// NOTE(review): presumably what next_edge(i) indexes into — confirm.
edge_list next_edges_;
}
//Backward runs in the opposite direction to forward, so what is called "input" and "output" here is reversed compared with the forward pass.
//Abridged excerpt: the bare identifiers below (init_local_ready_queue, compute_dependencies, ...) stand for calls whose arguments are elided,
//and names such as not_reentrant_backward_call, total_depth, graph_root and input_buffer are defined in code not shown here.
auto Engine::execute(const edge_list& roots,//the roots: the outputs of the forward pass, e.g. the loss
const variable_list& inputs,//the tensors that need gradients (per the original note)
bool keep_graph,//(not described in these notes)
bool create_graph,//whether the graph used to compute the gradients should itself be saved
bool accumulate_grad,//true when called through backward()
const edge_list& outputs//the edges of the forward inputs
)
init_local_ready_queue
auto graph_task = std::make_shared<GraphTask>(
/* keep_graph */ keep_graph,
/* create_graph */ create_graph,
/* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
/* cpu_ready_queue */ local_ready_queue);
compute_min_topological_nr//find the minimum topological number; nodes past it do not need to be computed (presumably the min_topo_nr checked in init_to_execute — confirm)
compute_dependencies//count the number of inputs (dependencies) of each Node
init_to_execute//use exec_info to mark the nodes that need to be computed
execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
}
// Seeds the backward run: wraps the graph root and its input buffer in a
// NodeTask, pushes it onto the ready queue chosen for the buffer's device,
// then runs thread_main on the graph task.
// NOTE(review): abridged excerpt — the return statement for the declared
// Future result is not shown here.
c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
const std::shared_ptr<GraphTask>& graph_task,
std::shared_ptr<Node> graph_root,
InputBuffer&& input_buffer) {
auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());//graph_task->cpu_ready_queue_ is the local_ready_queue passed to the GraphTask constructor in Engine::execute
queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
thread_main(graph_task)
}
// Worker loop: repeatedly pops a NodeTask from local_ready_queue and evaluates
// its function. The loop keeps running while graph_task is null, or until
// graph_task's future_result_ reports completed().
// Abridged excerpt: the closing braces do not balance with the code shown.
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
while (graph_task == nullptr || !graph_task->future_result_->completed()) {
// Owning reference to the popped task's GraphTask for this iteration.
std::shared_ptr<GraphTask> local_graph_task;
{
NodeTask task = local_ready_queue->pop();//per the original note, local_ready_queue is set up via reentrant_thread_init() / init_local_ready_queue (in thread_init()) — confirm
if (!(local_graph_task = task.base_.lock())) {//lock the weak_ptr to get a shared_ptr; empty if the GraphTask is gone
// GraphTask for function is no longer valid, skipping further
// execution.
continue;
}
// Run the node's function on its accumulated inputs.
evaluate_function(
local_graph_task,
task.fn_.get(),
task.inputs_,
local_graph_task->cpu_ready_queue_);
}
}}}
// Calls `func` on `inputs`, then hands each produced output to the input
// buffer of the node on the corresponding outgoing edge.
// NOTE(review): abridged excerpt — `fn`, `not_ready_it`, `opt_parent_stream`
// and `opt_next_stream` are used below but their definitions are not shown;
// `fn` presumably refers to `*func` — confirm.
void Engine::evaluate_function(//the hook handling is left out of this excerpt
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputs,
const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
auto outputs = call_function(graph_task, func, inputs);
//If the dependency count of the next Node drops to 0, remove that Node from graph_task's dependencies and mark it as ready; where it is placed then depends on whether it is ready.
int num_outputs = outputs.size();
for (const auto i : c10::irange(num_outputs)) {
auto& output = outputs[i];
// Edge i tells us which node, and which of its inputs, receives output i.
const auto& next = fn.next_edge(i);
auto &input_buffer = not_ready_it->second;//used if an entry already exists; otherwise one is initialized by InputBuffer input_buffer(next.function->num_inputs());
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
}
}
// Invokes the node's function on its inputs and yields the resulting outputs.
// NOTE(review): abridged excerpt — the body uses `fn`, `inputs` and `outputs`,
// while the parameters are named `func` and `inputBuffer`; the statements that
// introduce those names (and the return) are not shown.
static variable_list call_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputBuffer) {
outputs = fn(std::move(inputs))
}
// Notes on init_to_execute: it marks ("signs") which nodes are needed.
// Two equivalent formulations follow: a recursive pseudocode version and the
// iterative, explicit-stack version.
init_to_execute
sign
method1: recursion
// is_needed = {fn: True for fn in outputs} # (0)
// seen = {}
// def compute_is_needed(fn):
// for next_edge in fn.next_edges:
// child_fn = next_edge.fn
// if child_fn in seen and is_needed[child_fn]: # (1)
// is_needed[fn] = true
// else:
// seen.add(child_fn)
// if compute_is_needed(child_fn):
// is_needed[fn] = true # (2)
// # (3) exit for-loop
// return is_needed[fn]
// compute_is_needed(graph_root)
method2:
// Same traversal with an explicit stack instead of recursion. The (1)/(2)/(3)
// markers below correspond to the markers in the pseudocode above.
// Abridged excerpt: the declarations of `stack`, `seen` and `min_topo_nr`, and
// the enclosing function header, are not shown.
while (!stack.empty()) {
auto &frame = stack.back();
const auto fn = frame.fn_;
Node *child_fn = nullptr;
// Skip over children already in `seen` (emplace(...).second is false for
// an existing element), propagating their `needed` status to `fn`.
while((child_fn = frame.get_next_fn()) && !seen.emplace(child_fn).second) {
// (1) next child exists AND has already been seen
if (nodeShouldExecute(child_fn)) {
exec_info_[fn].needed_ = true;
}
}
if (child_fn) {
// (2) next child exists but has not been seen
if (child_fn->topological_nr() < min_topo_nr) {
// child created before the first output means this child cannot have
// an edge to output
continue;
}
// Descend into the unseen child.
stack.emplace_back(child_fn);
} else {
// (3) no next child exists for `fn` means its `needed` has already been
// finalized. pop stack and update parent
stack.pop_back();
if (nodeShouldExecute(fn) && !stack.empty()) {
exec_info_[stack.back().fn_].needed_ = true;
}
}
}
}
// Abridged declarations: only the names and one member are shown.
struct InputMetadata//describes the data: its size, device, and so on
struct InputBuffer //accumulates a list of Variables for use by a function
// Storage for the accumulated Variables.
std::vector<Variable> buffer
// One unit of work on a ready queue: a function to run, the inputs
// accumulated for it, and the GraphTask it belongs to.
struct NodeTask{//excerpt: only the data members are shown
// Non-owning reference to the owning GraphTask; thread_main locks it and
// skips the task if the GraphTask no longer exists.
std::weak_ptr<GraphTask> base_;
// The node whose function is to be evaluated.
std::shared_ptr<Node> fn_;
// Inputs gathered for fn_.
InputBuffer inputs_;
}
// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask: std::enable_shared_from_this<GraphTask> {
std::unordered_map<Node*, ExecInfo> exec_info_;
std::vector<Variable> captured_vars_;
class C10_API Stream final {