Tensorflow源码分析-内存优化:memory_optimizer.cc

FindCandidateRecomputeNodes函数:

目的:找到可能需要重计算的候选节点

就是找到那些计算成本低,输出会投入到目标节点,又不依赖目标节点输入的节点,这些节点是可能需要重计算的候选子图。

作为候选重计算节点的条件:

  1. 满足重计算条件(计算简单等,传入函数is_candidate实现)
  2. 该节点的输出节点中包含目标节点(传入函数is_target实现)
  3. 该节点的输入节点不依赖于目标节点

后续的GetOpGroupsToRecompute函数会基于这些候选节点,寻找完整的需要重计算的子图。

// Find recomputable ops which feed into target nodes.
std::unordered_set<const NodeDef*> FindCandidateRecomputeNodes(
    const NodeMap& node_map, const GraphDef* graph,
    const std::function<bool(const NodeDef&)>& is_candidate,
    const std::function<bool(const NodeDef&)>& is_target) {
  std::unordered_set<const NodeDef*> candidate_recompute_nodes;
  //遍历图形中的每个节点
  for (const auto& node : graph->node()) {
    //1判断是否符合重计算条件(通过传入的is_candidate函数判断)
    if (!is_candidate(node)) {
      continue;
    }
    bool has_target_output = false;
    //2查看该节点是否有输出会投入到目标节点(通过is_target函数判断),如果没有就忽略
    //获得该节点的所有输出节点->判断输出节点是否为目标节点
    for (const NodeDef* output : node_map.GetOutputs(node.name())) {
      // It only makes sense to recompute this if it feeds into a target
      // node. We expand this to dependencies in GetOpGroupsToRecompute.
      if (is_target(*output)) {
        has_target_output = true;
        break;
      }
    }
    if (!has_target_output) {
      continue;
    }
    bool has_target_input = false;
    //3判断该节点是否依赖于目标节点的输入,如果是就忽略
    for (const string& input_name : node.input()) {
      // Don't recompute nodes which depend on target nodes.
      const NodeDef* input_node = node_map.GetNode(input_name);
      if (is_target(*input_node)) {
        has_target_input = true;
        break;
      }
    }
    if (has_target_input) {
      continue;
    }
    //如果满足作为候选重计算节点条件就添加到candidate_recompute_nodes集合中
    candidate_recompute_nodes.insert(&node);
  }
  return candidate_recompute_nodes;
}

connected_subgraph函数:

根据种子节点扩展子图(扩展子图的作用是什么?)

将expanded_nodes放入to_visit队列中,遍历to_visit队列,如果已访问过则跳过,根据传入的参数决定是否收集输入、输出节点。

// 以广度优先搜索的方式扩展子图,得到所有直接和间接依赖的节点。
// 你可以设置collect_inputs和collect_outputs来决定是否收集输入、输出节点
void connected_subgraph(const NodeMap& node_map, bool collect_inputs,
                        bool collect_outputs,
                        const std::function<bool(const NodeDef&)>& is_candidate,
                        std::unordered_set<const NodeDef*>* expanded_nodes) {
  std::queue<const NodeDef*> to_visit;
  for (const NodeDef* starting_node : *expanded_nodes) {
    to_visit.push(starting_node);
  }
  expanded_nodes->clear();
  while (!to_visit.empty()) {
    const NodeDef* current_node = to_visit.front();
    to_visit.pop();
    if (!expanded_nodes->insert(current_node).second) {
      // We already visited this node
      continue;
    }
    if (collect_inputs) {
      // Add inputs and outputs to this subgraph if they are candidates
      for (const string& input_name_raw : current_node->input()) {
        const NodeDef* input_node = node_map.GetNode(input_name_raw);
        // 如果当前节点的输入节点未被访问并且满足候选条件加入队列等待访问
        if (expanded_nodes->count(input_node) == 0 &&
            is_candidate(*input_node)) {
          to_visit.push(input_node);
        }
      }
    }
    if (collect_outputs) {
      for (const NodeDef* output : node_map.GetOutputs(current_node->name())) {
        // 若当前节点的输出节点未被访问且满足候选条件,加入队列等待访问
        if (expanded_nodes->count(output) == 0 && is_candidate(*output)) {
          to_visit.push(output);
        }
      }
    }
  }
}

GetOpGroupsToRecompute函数

函数作用:根据候选重计算节点,找到需要真正重计算的算子子图组。

原理:只有子图的输出连接到目标节点(如梯度节点),才有效进行重计算。

所以这个函数在扩展每个候选节点的子图后,检查这个子图是否包含输出连接到目标节点:

  • 如果包含,即子图影响到目标节点,那么这个子图就有可能通过重计算真正节省内存
  • 将这个子图的信息(RecomputedSubGraph)加入结果中,作为后续优化参考

函数过程:

调用FindCandidateRecomputeNodes函数获取重计算候选节点

遍历候选节点作为种子节点,为每个节点建立一个RecomputedSubGraph(RSG)对象(包含重计算节点集合和目标节点集合)

调用connected_subgraph函数,基于种子节点递归扩展子图

对扩展后的节点集合进行遍历,对于当前种子节点的RSG对象,如果扩展后的节点的输出是目标节点,则将输出加入RSG的目标节点集合(获取直接输出至目标节点的节点),并且将记录这些节点作为需要重计算的起始节点(放在RSG的recomputed_source_nodes中)

再次调用connected_subgraph,只收集起始节点前驱组成的子图

如果子图包含目标节点输出,则将此RSG加入结果中

struct RecomputedSubGraph {
  std::unordered_set<const NodeDef*> recomputed_source_nodes;
  std::unordered_set<NodeDef*> target_nodes;
};

// Find groups of ops to recompute together based on `should_recompute`.
// 根据候选重计算节点,找到需要真正重计算的算子子图组。
std::vector<RecomputedSubGraph> GetOpGroupsToRecompute(
    const GraphDef* graph, const NodeMap& node_map,
    const std::function<bool(const NodeDef&)>& should_recompute,
    const std::function<bool(const NodeDef&)>& is_target) {
  std::unordered_set<const NodeDef*> visited_nodes;
  std::vector<RecomputedSubGraph> subgraphs_to_recompute;
  std::unordered_set<const NodeDef*> candidate_recompute_nodes =
      FindCandidateRecomputeNodes(node_map, graph, should_recompute, is_target);
  // 获取并遍历候选节点
  for (const NodeDef* recompute_node : candidate_recompute_nodes) {
    // count = 1访问过则忽略  0未访问过
    if (visited_nodes.count(recompute_node) > 0) {
      continue;
    }
    // RecomputedSubGraph子图包含重计算节点和目标节点
    RecomputedSubGraph current_recomputation;
    // Build out recomputation groups by expanding to inexpensive-to-recompute
    // nodes which do not feed target nodes. The goal is to capture some
    // intermediate activations within this graph.
    std::unordered_set<const NodeDef*> unpruned_recompute_nodes;
    unpruned_recompute_nodes.insert(recompute_node);
    connected_subgraph(node_map,
                       true,  // Collect inputs
                       true,  // Collect outputs
                       should_recompute, &unpruned_recompute_nodes);
    visited_nodes.insert(unpruned_recompute_nodes.begin(),
                         unpruned_recompute_nodes.end());
    for (const NodeDef* unpruned_recompute_node : unpruned_recompute_nodes) {
      bool inserted_feed = false;
      for (NodeDef* output :
           node_map.GetOutputs(unpruned_recompute_node->name())) {
        if (is_target(*output)) {
          current_recomputation.target_nodes.insert(output);
          if (!inserted_feed) {
            // Keep track of nodes which feed directly into a target node. These
            // and nodes which feed into them will define the recomputed
            // subgraph.
            current_recomputation.recomputed_source_nodes.insert(
                unpruned_recompute_node);
            inserted_feed = true;
          }
        }
      }
    }
    // Recompute only nodes which eventually feed into a target node.
    // 仅重新计算最终馈送到目标节点的节点。
    connected_subgraph(
        node_map,
        true,   // Collect inputs
        false,  // Collect outputs
        [&unpruned_recompute_nodes](const NodeDef& node) {
          return unpruned_recompute_nodes.count(&node) != 0;
        },
        &current_recomputation.recomputed_source_nodes);
    if (current_recomputation.target_nodes.empty()) {
      continue;
    }
    subgraphs_to_recompute.push_back(current_recomputation);
  }
  return subgraphs_to_recompute;
}

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值