Paddle build_cinn_pass_test源码阅读(fluid目录下)

代码位置在 paddle\fluid\framework\paddle2cinn\build_cinn_pass_test.cc ,因为paddle CINN和PIR部分依旧在高频更新,所以各位看到的可能和我的不一样

inline bool CheckNodeExisted(const std::unordered_set<Node*>& nodes,
                             const std::string& op_name) {
  return std::find_if(nodes.begin(), nodes.end(), [&op_name](const Node* node) {
           return node->Name() == op_name;
         }) != nodes.end();
}

用一个内联函数, 去看一个 unordered_set (一系列节点) 中是否有某个 node 的名字是 op_name,用 std::find_if 去实现, 第三个参数传入的是匿名函数。[&op_name] 闭包被定义在Lambda表达式声明中的方括号[]内. 这个机制允许这些变量被按值或按引用捕获.

函数匿名函数的闭包可以参考这篇文章: https://www.cnblogs.com/pzhfei/archive/2013/01/14/lambda_expression.html

接下来就是返回名字为 op_namenode 数量

inline int CountNode(const std::unordered_set<Node*>& nodes,
                     const std::string& op_name) {
  return std::count_if(
      nodes.begin(), nodes.end(), [&op_name](const Node* node) {
        return node->Name() == op_name;
      });
}

接下来是返回节点名字是 op_name 的 节点,注意 std::find_if 前面为啥有 * 呢,因为 find_if 返回一个迭代器, *迭代器 可以返回一个 Node*

inline Node* GetNode(const std::unordered_set<Node*>& nodes,
                     const std::string& op_name) {
  return *std::find_if(
      nodes.begin(), nodes.end(), [&op_name](const Node* node) {
        return node->Name().find(op_name) != std::string::npos;
      });
}

CheckGraphIndependence 内部定义了一个 check_node_ok 匿名函数,匿名函数中 n1n2 都是节点 Node 的指针,
( 说明一下,Paddle PIR之前的节点,节点既有 Op, 也有 Var )
只有 n1n2 一个为 OP, 一个为 Var 才有可能返回 true;

inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
  auto check_node_ok = [&nodes](Node* n1, Node* n2) -> bool {
    if (n1->IsOp() && !n2->IsVar()) {
      return false;
    }
    if (n1->IsVar() && !n2->IsOp()) {
      return false;
    }
    if (nodes.count(n2) == 0) {
      return false;
    }
    return true;
  };

  for (auto node : nodes) {
    for (auto in : node->inputs) {
      if (!check_node_ok(node, in)) {
        return false;
      }
    }
    for (auto out : node->outputs) {
      if (!check_node_ok(node, out)) {
        return false;
      }
    }
  }
  return true;
}

这里需要说明一下,由于 Paddle pir之前 Op 和 Var 都是node, 所以这样定义

var1 -> op1 -> var2
op3-> var3 -> op4

op1的输入是 var1,输出是 var2,而下边那一行是
va3 的输入是 op3,var3 的输出是 op4 , 这样写有点儿诡异,不过确实是这样定义的

所以 CheckGraphIndependence 的用法就是,首先检查是不是 op->varvar->op 的关系,其次就是看当前 op/var 在不在当前 Graph 的 unordered_set<Node*>

可以看到之后的调用就是将计算图的节点 g->Nodes() 传入 CheckGraphIndependence,如果返回值不为 True 则报错

  ASSERT_TRUE(CheckGraphIndependence(g->Nodes()));

这个函数主要是将 kCinnLaunchOpoperators::kCompilationKey 属性取出来扔到 compilation_keys这个 vector 中, 目前暂时未知有什么用

// Get compilation_key values
std::vector<int64_t> GetCompilationKeys(const Graph& graph) {
  std::vector<int64_t> compilation_keys;
  for (auto& node : graph.Nodes()) {
    if (node->IsOp() && node->Name() == kCinnLaunchOp) {
      compilation_keys.emplace_back(PADDLE_GET_CONST(
          int64_t, node->Op()->GetAttr(operators::kCompilationKey)));
    }
  }
  return compilation_keys;
}

接下来创建一个CINN子图,创建一个空图 Graph, 之后依次添加 op 和 var

std::unique_ptr<Graph> BuildNoCinnSubgraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);
  // var1 --
  //        | --> fake1 --> var3 --> fake2 --> var4
  // var2 --

  // *Desc 是之后用来创建 OpNode 和 VarNode 的类
  OpDesc fake1_op;
  fake1_op.SetType("fake1");
  OpDesc fake2_op;
  fake2_op.SetType("fake2");

  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");
  
  // 之后用 graph 的 Create*Node 来创建对应的 ir::Node
  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
  ir::Node* fake2 = g->CreateOpNode(&fake2_op);

  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);
  
  // ----------- 创建完 node 之后, 把 op/var 串起来
  // fill op node
  fake1->inputs = {v1, v2};
  fake1->outputs = {v3};
  fake2->inputs = {v3};
  fake2->outputs = {v4};

  // fill variable node
  v1->outputs = {fake1};
  v2->outputs = {fake1};

  v3->inputs = {fake1};
  v3->outputs = {fake2};

  v4->inputs = {fake2};

  return g;
}

接下来出现第一个单测

TEST(BuildCinnPassTest, NoCinnSubgraph) {
  auto g = BuildNoCinnSubgraph();    // 调用上边的函数建计算图
  auto previous_nodes = g->Nodes();  // 取出计算图的节点
  
  // 创建 pass 这个应该是旧IR的pass
  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  // g.get() 返回的是图的指针, g是个 unique_ptr 的智能指针
  pass->Apply(g.get());

  // After search, origin graph should no change
  // 注释的意思是, pass search 之后, 原来的计算图不应当修改
  ASSERT_EQ(previous_nodes, g->Nodes());
  ASSERT_TRUE(CheckGraphIndependence(g->Nodes())); // 接下来看计算图是否合法且不依赖其他计算图

  // After search, there should be no cinn subgraph
  ASSERT_TRUE(GetCompilationKeys(*g).empty());  // pass search之后没有 cinn subgraph 子图怎么理解
}

接下来依旧是 BuildAllOpSupportCinnGraph 与上一个建图的函数没啥太大区别

  • 图更加复杂
  • op 的 type 从 fake2 变成了 elementwise_add | mul | relu
std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);

  // v1 --
  //      | --> mul --> v3 --
  // v2 --                   | --> add --> v5 --> relu --> v6
  //                    v4 --

  OpDesc add_op;
  add_op.SetType("elementwise_add");
  OpDesc mul_op;
  mul_op.SetType("mul");
  OpDesc relu_op;
  relu_op.SetType("relu");

  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");
  VarDesc var5("var5");
  VarDesc var6("var6");

  ir::Node* add = g->CreateOpNode(&add_op);
  ir::Node* mul = g->CreateOpNode(&mul_op);
  ir::Node* relu = g->CreateOpNode(&relu_op);

  ir::Node* v0 = g->CreateEmptyNode("var0", Node::Type::kVariable);     // 创建空节点用意是?
  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);
  ir::Node* v5 = g->CreateVarNode(&var5);
  ir::Node* v6 = g->CreateVarNode(&var6);
  ir::Node* v7 = g->CreateControlDepVar();

  // fill op node
  mul->inputs = {v0, v1, v2};
  mul->outputs = {v3};
  add->inputs = {v3, v4};
  add->outputs = {v5};
  relu->inputs = {v5};
  relu->outputs = {v6, v7};

  // fill variable node
  v0->outputs = {mul};
  v1->outputs = {mul};
  v2->outputs = {mul};

  v3->inputs = {mul};
  v3->outputs = {add};

  v4->outputs = {add};

  v5->inputs = {add};
  v5->outputs = {relu};

  v6->inputs = {relu};
  v7->inputs = {relu};

  return g;
}

上边这个注释有点儿问题:

  // v1 --
  //      | --> mul --> v3 --
  // v2 --                   | --> add --> v5 --> relu --> v6
  //                    v4 --

应该改成:

  // v0 --|
  // v1 --|                  
  // v2 --| --> mul  --> v3 --|
  //                 --> v4 --| --> add  --> v5 --> relu  --> v6
  //                                                      --> v7

接下来的 TEST 和之前的一样,只不过由于图结构变化,pass 之后图结构都变化为 kCinnLaunchOp

TEST(BuildCinnPassTest, AllOpSupportCinn) {
  auto g = BuildAllOpSupportCinnGraph();

  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  pass->Apply(g.get());

  // After search, the graph should as following
  // v0 --|
  // v1 --|                   |--> v6
  // v2 --| --> kCinnLaunchOp |--> v7
  // v4 --|
  const auto& nodes = g->Nodes();
  ASSERT_EQ(nodes.size(), static_cast<size_t>(7));      // 节点数为 7, 4个输入, 2个输出 和 1 个 Op 节点
  ASSERT_TRUE(CheckGraphIndependence(nodes));           // 检测该图是否独立,是否会依赖其他图

  // A new op named kCinnLaunchOp should be added
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));  // kCinnLaunchOp 是个常量字符串, 检测节点 vector 中有无 kCinnLaunchOp 
  auto* cinn_op = GetNode(nodes, kCinnLaunchOp);
  auto* v0 = GetNode(nodes, "var0");
  auto* v1 = GetNode(nodes, "var1");                    // 依次获取对应的 var Node 指针
  auto* v2 = GetNode(nodes, "var2");
  auto* v4 = GetNode(nodes, "var4");
  auto* v6 = GetNode(nodes, "var6");
  auto* v7 = GetNode(nodes, Node::kControlDepVarName);
  
  // 查看 cinn_op 的输入输出是否与 `v0, v1, v2, v4` 和 `v6, v7` 对应
  ASSERT_EQ(
      std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()),
      std::unordered_set<Node*>({v0, v1, v2, v4}));
  ASSERT_EQ(std::unordered_set<Node*>(cinn_op->outputs.begin(),
                                      cinn_op->outputs.end()),
            std::unordered_set<Node*>({v6, v7}));
  
  // 查看 var 节点的输入输出是否是 cinn_op 
  ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op}));
  ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op}));

  // previous op (mul, add, relu) should all removed
  // 由于 mul/elementwise_add/relu 被整体合并为 cinn_op 所以图中不应该被搜索到
  ASSERT_FALSE(CheckNodeExisted(nodes, "mul"));
  ASSERT_FALSE(CheckNodeExisted(nodes, "elementwise_add"));
  ASSERT_FALSE(CheckNodeExisted(nodes, "relu"));

  // After search, there should has just one cinn subgraph
  // feed --> v1 --
  //               | --> mul --> v3 --
  // feed --> v2 --                   | --> add --> v5 --> relu --> v6 --> fetch
  //                    feed --> v4 --
  
  // 获取编译完毕之后的 key, 之后会根据 key 去取对应的 subgraph 
  auto compilation_keys = GetCompilationKeys(*g);
  ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));  // 因为只有一个 kCinnLaunchOp 所以 key 的数量也为 1 
  auto* cinn_compiler = CinnCompiler::GetInstance();
  const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);  // 根据 key 拿对应的子图

  const auto& subnodes = subgraph.Nodes();             // 拿子图的节点set
  ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
  ASSERT_TRUE(CheckGraphIndependence(subnodes));

  // 该 cinn op 就是这三 mul | elementwise_add | relu 的合体
  ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
  ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
  ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
  ASSERT_EQ(CountNode(subnodes, "feed"), 3);   // 上边注释有 3个feed Op
  ASSERT_EQ(CountNode(subnodes, "fetch"), 1);  // 1 个 fetch Op
  
  // 在 kCinnLaunchOp 中有参和无参的 node 都应当有 feed Op 
  // No-parameter input should has feed op
  auto new_v1 = GetNode(subnodes, "var1");
  ASSERT_EQ(new_v1->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v1->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
  ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");

  // Parameter input should also have the feed op
  auto new_v2 = GetNode(subnodes, "var2");
  ASSERT_EQ(new_v2->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v2->inputs[0]->Name(), "feed");
  ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");

  // kCinnLaunchOp 输出中应当有 fetch Op
  // output should has fetch op
  auto new_v6 = GetNode(subnodes, "var6");
  ASSERT_EQ(new_v6->inputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v6->outputs.size(), static_cast<size_t>(1));
  ASSERT_EQ(new_v6->inputs[0]->Name(), "relu");
  ASSERT_EQ(new_v6->outputs[0]->Name(), "fetch");
}

第一个单测是只有 fake Op 没办法 pass 优化,第二个单测是所有Op 都支持 CINN Pass, 那下一个就是一半是 fake Op,另一半是 只是 CINN Pass 的 OP

std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);

  // fake1 --> v1 --
  //                | --> mul --> v3 --> relu --> v4 --> fake2
  //           v2 --

  OpDesc fake1_op;
  fake1_op.SetType("fake1");
  OpDesc mul_op;
  mul_op.SetType("mul");
  OpDesc relu_op;
  relu_op.SetType("relu");
  OpDesc fake2_op;
  fake2_op.SetType("fake2");

  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");

  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
  ir::Node* mul = g->CreateOpNode(&mul_op);
  ir::Node* relu = g->CreateOpNode(&relu_op);
  ir::Node* fake2 = g->CreateOpNode(&fake2_op);

  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);

  // fill op node
  fake1->outputs = {v1};
  mul->inputs = {v2, v1};
  mul->outputs = {v3};
  relu->inputs = {v3};
  relu->outputs = {v4};
  fake2->inputs = {v4};

  // fill variable node
  v2->outputs = {mul};

  v1->inputs = {fake1};
  v1->outputs = {mul};

  v3->inputs = {mul};
  v3->outputs = {relu};

  v4->inputs = {relu};
  v4->outputs = {fake2};

  return g;
}

上边的函数就是建立了一个这样的一个图

  // fake1 --> v1 --
  //                | --> mul --> v3 --> relu --> v4 --> fake2
  //           v2 --

通过 cinn pass 之后这个图的节点变成下边儿这样:

  // fake1 --> v1 --
  //                | --> kCinnLaunchOp --> v4 --> fake2
  //           v2 --

只有一个 kCinnLaunchOp 其子图为,有9个节点

  // feed --> v1 --
  //               | --> mul --> v3 --> relu --> v4 --> fetch
  // feed --> v2 --

之前的图是单个 cinn op,下一个单测是多个 cinn op 的情况:

std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() {
  ProgramDesc prog;
  auto g = std::make_unique<Graph>(prog);

  // fake1 --> v1 --
  //                | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
  //           v2 --

  OpDesc fake1_op;
  fake1_op.SetType("fake1");
  OpDesc mul_op;
  mul_op.SetType("mul");
  OpDesc relu_op;
  relu_op.SetType("relu");
  OpDesc fake2_op;
  fake2_op.SetType("fake2");
  OpDesc fake3_op;
  fake3_op.SetType("fake3");

  VarDesc var1("var1");
  VarDesc var2("var2");
  var2.SetPersistable(true);
  var2.SetIsParameter(true);
  VarDesc var3("var3");
  VarDesc var4("var4");
  VarDesc var5("var5");

  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
  ir::Node* mul = g->CreateOpNode(&mul_op);
  ir::Node* relu = g->CreateOpNode(&relu_op);
  ir::Node* fake2 = g->CreateOpNode(&fake2_op);
  ir::Node* fake3 = g->CreateOpNode(&fake3_op);

  ir::Node* v1 = g->CreateVarNode(&var1);
  ir::Node* v2 = g->CreateVarNode(&var2);
  ir::Node* v3 = g->CreateVarNode(&var3);
  ir::Node* v4 = g->CreateVarNode(&var4);
  ir::Node* v5 = g->CreateVarNode(&var5);

  // fill op node
  fake1->outputs = {v1};
  mul->inputs = {v2, v1};
  mul->outputs = {v3};
  fake2->inputs = {v3};
  fake2->outputs = {v4};
  relu->inputs = {v4};
  relu->outputs = {v5};
  fake3->inputs = {v5};

  // fill variable node
  v2->outputs = {mul};

  v1->inputs = {fake1};
  v1->outputs = {mul};

  v3->inputs = {mul};
  v3->outputs = {fake2};

  v4->inputs = {fake2};
  v4->outputs = {relu};

  v5->inputs = {relu};
  v5->outputs = {fake3};

  return g;
}

以上代码建立一个这样的图:

  // fake1 --> v1 --
  //                | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
  //           v2 --

fake2 op 为界,可以建立两个 cinn op pass

  // fake1 -> v1 -
  //              | -> CinnOp -> v3 -> fake2 -> v4 -> CinnOp ->v5 -> fake3
  //          v2 -

cinn pass 就两句代码:

  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  pass->Apply(g.get());

此处是检验有两个 cinn pass Op 的代码:

  // A new op named kCinnLaunchOp should be added
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
  ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 2);

最后的编译结果是 cinn pass 之后有两个 子图:

  // subgraph1:
  // feed --> v4 --> relu --> v5 --> fetch
  // subgraph2:
  // feed --> v1 --
  //               | --> mul --> v3 --> fetch
  //          v2 --

BuildGraphWithNoNeedBufferInput 就是建立一个这样的子图:

  // fake1 --> v1 --                 --> v4 --> relu_grad --> v6
  //           v2 -- | --> add_grad |
  //           v3 --                 --> v5 --> fake2

BuildGraphWithNoNeedBufferInput 与之前不同的是,add_grad_op 使用了设置输入的 API SetInput

  OpDesc add_grad_op;
  add_grad_op.SetType("elementwise_add_grad");
  add_grad_op.SetInput(::paddle::framework::GradVarName("Out"), {"var1"});
  add_grad_op.SetInput("X", {"var2"});
  add_grad_op.SetInput("Y", {"var3"});

之后的单测写了,no_need_buffer_x 不知道什么意思.

  // A new op named kCinnLaunchOp should be added and
  // its input arguments are set correctly
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
  ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 1);
  auto* cinn_op_node = GetNode(nodes, kCinnLaunchOp);
  ASSERT_EQ(cinn_op_node->Op()->Input(operators::kX),
            std::vector<std::string>({"var1"}));
  auto& no_need_buffer_x = cinn_op_node->Op()->Input(operators::kNoNeedBufferX);
  ASSERT_EQ(std::unordered_set<std::string>(no_need_buffer_x.begin(),
                                            no_need_buffer_x.end()),
            std::unordered_set<std::string>({"var2", "var3"}));

这里的 no_need_buffer_feeds 什么意思??

  ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add_grad"));
  ASSERT_TRUE(CheckNodeExisted(subnodes, "relu_grad"));
  ASSERT_EQ(CountNode(subnodes, "feed"), 3);
  ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
  const auto& no_need_buffer_feeds =
      subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
  ASSERT_EQ(no_need_buffer_feeds.size(), 2);
  ASSERT_EQ(no_need_buffer_feeds,
            std::unordered_set<std::string>({"var2", "var3"}));

  // check the attributes of variable lists are saved correctly
  ASSERT_TRUE(subgraph.Has(kInputVars));
  EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInputVars),
            std::vector<std::string>({"var1"}));
  ASSERT_TRUE(subgraph.Has(kInternalVars));
  EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInternalVars),
            std::vector<std::string>({"var4"}));
  ASSERT_TRUE(subgraph.Has(kOutputVars));
  const auto& output_vars = subgraph.Get<std::vector<std::string>>(kOutputVars);
  EXPECT_EQ(
      std::unordered_set<std::string>(output_vars.begin(), output_vars.end()),
      std::unordered_set<std::string>({"var5", "var6"}));
TEST(BuildCinnPassTest, TestSkipGcVars){
  auto g = BuildGraphWithOneCinnSubgraph();
  
  // 这里什么意思????
  std::unordered_set<std::string> all_skip_gc_vars = {"var1", "var3"};
  g->SetNotOwned(kSkipGcVarNames, &all_skip_gc_vars);

  auto pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  pass->Apply(g.get());

  // After search, the graph should as following
  // fake1 --> v1 --
  //                | --> kCinnLaunchOp --> v4 --> fake2
  //           v2 --
  const auto& nodes = g->Nodes();
  ASSERT_EQ(nodes.size(), static_cast<size_t>(7));  // 这里为啥变成了 7
  ASSERT_TRUE(CheckGraphIndependence(nodes));

  // A new op named kCinnLaunchOp should be added
  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));

  // After search, there should has just one cinn subgraph
  // Note v3 has fetched because of v3 in kSkipGcVarNames
  // And v1 is a feed var so v1 no need fetched though it in kSkipGcVarNames
  // feed --> v1 --
  //               | --> mul --> v3 --> relu --> v4 --> fetch
  // feed --> v2 --                 --> fetch
  auto compilation_keys = GetCompilationKeys(*g);
  ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
  auto* cinn_compiler = CinnCompiler::GetInstance();
  const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);

  const auto& subnodes = subgraph.Nodes();
  ASSERT_EQ(subnodes.size(), static_cast<size_t>(10));
  ASSERT_TRUE(CheckGraphIndependence(subnodes));

  ASSERT_EQ(CountNode(subnodes, "feed"), 2);
  // var3 and var4 should has fetch op
  ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
}

最后两个 TEST 没看懂,留下问题

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值