The project has the following requirement: the backend wants operators that match specific patterns to be fused into a CustomCall and passed down.
For example, in the HLO computation graph below, the operators between the elided parts need to be fused into a CustomCall.
Analysis
Background of the graph merge
The frontend has an X86 backend simulator (hereafter, the simulator). Frontend graphs are run on the simulator to validate the accuracy of the real chip backend.
With the fusion requirement above in place, a graph carrying the CustomCall obviously cannot run on the simulator directly, unless a dedicated expansion pass is added for this particular CustomCall. That means a lot of duplicated work, and any similar CustomCall requirement in the future would hit the same problem.
Before fusion
// Main graph
HloModule MainGraph, entry_computation_layout=...
%AddComputation.14 (x.15: f32[], y.16: f32[]) -> f32[] {
%x.15 = f32[] parameter(0)
%y.16 = f32[] parameter(1)
ROOT %add.17 = f32[] add(f32[] %x.15, f32[] %y.16)
}
// entry_computation of the main graph
ENTRY %MainGraph ... {
// elided...
%p0.2 = f32[10]{0} parameter(4)
%p1.10 = f32[10]{0} parameter(7)
%add = f32[10]{0} add(f32[10]{0} %p0.2, f32[10]{0} %p1.10)
%constant.12 = f32[] constant(0)
%constant.13 = f32[10,5]{1,0} constant(1)
%reduce.18 = f32[10]{0} reduce(f32[10,5]{1,0} %constant.13, f32[] %constant.12), dimensions={1}, to_apply=%AddComputation.14
%multiply.40 = f32[10]{0} multiply(f32[10]{0} %add, f32[10]{0} %reduce.18)
// elided...
}
After fusion
After being fused into a CustomCall, the graph takes the following form:
// Main graph
HloModule SyncTensorsGraph.10, entry_computation_layout=...
// entry_computation of the main graph
ENTRY %SyncTensorsGraph.10 ... {
// elided...
%custom=f32[10]{0} custom-call(%constant.89, %constant.92), custom_call_target="MyCustomCall", backend_config=""
// elided...
}
Approach
When fusing into a CustomCall on the frontend, abstract the operators being fused into a single HloModule and pass it out as a string through a CustomCall attribute; the backend_config attribute is used for this. For example, the operators above are abstracted into the HloModule below, where the HloModule's parameters are the CustomCall's operands and the HloModule's return value is the CustomCall's return value:
HloModule SyncTensorsGraph.42, entry_computation_layout={(f32[10]{0},f32[10]{0})->f32[10]{0}}
%AddComputation.14 (x.15: f32[], y.16: f32[]) -> f32[] {
%x.15 = f32[] parameter(0)
%y.16 = f32[] parameter(1)
ROOT %add.17 = f32[] add(f32[] %x.15, f32[] %y.16)
}
ENTRY %SyncTensorsGraph.42 (p0.2: f32[10], p1.10: f32[10]) -> f32[10] {
%p0.2 = f32[10]{0} parameter(0)
%p1.10 = f32[10]{0} parameter(1)
%add = f32[10]{0} add(f32[10]{0} %p0.2, f32[10]{0} %p1.10)
%constant.12 = f32[] constant(0)
%constant.13 = f32[10,5]{1,0} constant(1)
%reduce.18 = f32[10]{0} reduce(f32[10,5]{1,0} %constant.13, f32[] %constant.12), dimensions={1}, to_apply=%AddComputation.14
ROOT %multiply.40 = f32[10]{0} multiply(f32[10]{0} %add, f32[10]{0} %reduce.18)
}
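For completeness, here is a minimal sketch of what the fusion side might look like when it emits such a CustomCall, assuming the abstracted HloModule has already been built and the usual XLA headers are available; the helper name MakeFusedCustomCall is hypothetical:
// Hypothetical fusion-side helper: wrap the fused region in a CustomCall and
// carry the abstracted HloModule as HLO text in the opaque string, which is
// printed as backend_config="..." in the HLO dump.
std::unique_ptr<HloInstruction> MakeFusedCustomCall(
    const Shape& result_shape, absl::Span<HloInstruction* const> operands,
    const HloModule& fused_module) {
  return HloInstruction::CreateCustomCall(
      result_shape, operands,
      /*custom_call_target=*/"MyCustomCall",
      /*opaque=*/fused_module.ToString());
}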
Graph before the merge
Reflected in the main graph, it looks as follows:
HloModule Test
ENTRY Test {
%p0.1 = f32[10]{0} parameter(0)
%constant.89 = f32[10]{0} constant(1)
%constant.92 = f32[10]{0} constant(2)
%custom = f32[10]{0} custom-call(%constant.89, %constant.92), custom_call_target="MyCustomCall", backend_config="
HloModule SyncTensorsGraph.42, entry_computation_layout={(f32[10]{0},f32[10]{0})->f32[10]{0}}
%AddComputation.14 (x.15: f32[], y.16: f32[]) -> f32[] {
%x.15 = f32[] parameter(0)
%y.16 = f32[] parameter(1)
ROOT %add.17 = f32[] add(f32[] %x.15, f32[] %y.16)
}
ENTRY %SyncTensorsGraph.42 (p0.2: f32[10], p1.10: f32[10]) -> f32[10] {
%p0.2 = f32[10]{0} parameter(0)
%p1.10 = f32[10]{0} parameter(1)
%add = f32[10]{0} add(f32[10]{0} %p0.2, f32[10]{0} %p1.10)
%constant.12 = f32[] constant(0)
%constant.13 = f32[10,5]{1,0} constant(1)
%reduce.18 = f32[10]{0} reduce(f32[10,5]{1,0} %constant.13, f32[] %constant.12), dimensions={1}, to_apply=%AddComputation.14
ROOT %multiply.40 = f32[10]{0} multiply(f32[10]{0} %add, f32[10]{0} %reduce.18)
}
"
ROOT add = f32[10]{0} add(%custom, %constant.92)
}
Graph after the merge
Note two things:
- The CustomCall in the original graph has been re-expanded into a combination of multiple HloInstructions.
- The non-entry_computation of the original graph's "subgraph" has been cloned into the main graph and is invoked normally.
HloModule Test, entry_computation_layout={(f32[10]{0})->f32[10]{0}}
%AddComputation.14.clone (x.15: f32[], y.16: f32[]) -> f32[] {
%x.15 = f32[] parameter(0)
%y.16 = f32[] parameter(1)
ROOT %add.17 = f32[] add(f32[] %x.15, f32[] %y.16)
}
ENTRY %Test (p0.1: f32[10]) -> f32[10] {
%p0.1 = f32[10]{0} parameter(0)
%constant.89 = f32[10]{0} constant({1, 0, 0, 0, 0, 0, 0, 0, 0, 0})
%constant.92 = f32[10]{0} constant({2, 0, 0, 0, 0, 0, 0, 0, 0, 0})
%add.1 = f32[10]{0} add(f32[10]{0} %constant.89, f32[10]{0} %constant.92)
%constant.1 = f32[10,5]{1,0} constant({...})
%constant = f32[] constant(0)
%reduce = f32[10]{0} reduce(f32[10,5]{1,0} %constant.1, f32[] %constant), dimensions={1}, to_apply=%AddComputation.14.clone
%multiply = f32[10]{0} multiply(f32[10]{0} %add.1, f32[10]{0} %reduce)
ROOT %add = f32[10]{0} add(f32[10]{0} %multiply, f32[10]{0} %constant.92)
}
Implementation
- The overall algorithm is recursive, inspired by the copy-a-binary-tree algorithm; what is actually being copied here is a graph.
- The subgraph's root instruction and its non-root instructions need different handling: the root instruction can be substituted into the main graph directly as a unique_ptr, so std::variant is used, with the root returning std::unique_ptr<HloInstruction> and non-root instructions returning HloInstruction*.
- Globally replacing HloInstruction* with the HloInstructionPtr alias broke CopyGraph's second parameter: const HloInstructionPtr means HloInstruction* const (a const pointer), not a pointer to const, which is why the parameter is spelled HloInstruction const*. This brought home the "Some Remarks About Programming Style" section of "C++ Templates: The Complete Guide, 2nd Edition", on where to write const relative to the type and why. A short illustration follows this list.
- The different kinds of HloInstruction are largely covered: ordinary binary operators, Constant, Parameter, and reduce (which involves a subgraph call).
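A minimal, standalone illustration of that const pitfall; HloInstruction here is just a stand-in type, not the real XLA class:
#include <type_traits>

struct HloInstruction {};  // stand-in for xla::HloInstruction
using HloInstructionPtr = HloInstruction*;

int main() {
  // "const HloInstructionPtr" is "HloInstruction* const": the pointer itself
  // is const, the pointee stays mutable.
  static_assert(std::is_same_v<const HloInstructionPtr, HloInstruction* const>);
  // A pointer-to-const cannot be expressed through the alias at all; it has
  // to be spelled out, which is why CopyGraph takes "HloInstruction const*".
  static_assert(!std::is_same_v<const HloInstructionPtr, HloInstruction const*>);
  // Writing const to the right of the type, as the book recommends, makes
  // the substitution rule visually obvious:
  static_assert(std::is_same_v<HloInstructionPtr const, HloInstruction* const>);
}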
#include "tensorflow/compiler/plugin/ipu/driver/passes/ipu_custom_call_expander_pass.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"
// Use HloParser to parse the backend_config string back into an HloModule.
// The deps in the bazel BUILD file need the matching dependency: "//tensorflow/compiler/xla/service:hlo_parser"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
namespace xla {
namespace {
using HloInstructionUniquePtr = std::unique_ptr<HloInstruction>;
using HloInstructionPtr = HloInstruction*;
std::variant<HloInstructionPtr, HloInstructionUniquePtr> CopyGraph(
    HloComputation* parent, HloInstruction const* root,
    const HloInstruction::InstructionVector& cc_operands,
    std::vector<HloComputation*>* const non_entry_comps,
    bool is_root_instr = false) {
  auto opcode = root->opcode();
  VLOG(2) << "opcode " << opcode;
  HloInstructionUniquePtr cloned_instr;
  switch (opcode) {
    case HloOpcode::kMultiply:
    case HloOpcode::kAdd: {
      cloned_instr = HloInstruction::CreateBinary(
          root->shape(), opcode,
          std::get<HloInstructionPtr>(
              CopyGraph(parent, root->operand(0), cc_operands, non_entry_comps)),
          std::get<HloInstructionPtr>(
              CopyGraph(parent, root->operand(1), cc_operands, non_entry_comps)));
      // HloInstruction::HloInstruction has protected access, so it cannot be
      // constructed directly here:
      // new_instr = new HloInstruction(opcode, root->shape());
      // new_instr->AppendOperand(const_cast<HloInstructionPtr>(root->operand(0)));
      // new_instr->AppendOperand(const_cast<HloInstructionPtr>(root->operand(1)));
      break;
    }
    // The graph's parameters are exactly the CustomCall's operands,
    // matched one-to-one by position.
    case HloOpcode::kParameter:
      return cc_operands[root->parameter_number()];
    case HloOpcode::kConstant: {
      auto& literal = const_cast<Literal&>(root->literal());
      cloned_instr = HloInstruction::CreateConstant(std::move(literal));
      break;
    }
    case HloOpcode::kReduce: {
      // tensorflow/compiler/xla/service/hlo_instruction.cc:1569#HloInstruction::CreateReduce
      // TODO: handle the multi-operand / multi-init_value case.
      auto* root_nc = const_cast<HloInstructionPtr>(root);
      auto* reduce = dynamic_cast<HloReduceInstruction*>(root_nc);
      HloComputation* comp = nullptr;
      VLOG(1) << non_entry_comps->size();
      for (auto* cp : *non_entry_comps) {
        // Fetch the cloned HloComputation with the matching name.
        if (cp->name() == reduce->to_apply()->name() + ".clone") {
          VLOG(3) << cp->name() << " vs " << reduce->to_apply()->name();
          comp = cp;
        }
      }
      VLOG(1) << comp->ToString();
      cloned_instr = HloInstruction::CreateReduce(
          root->shape(),
          std::get<HloInstructionPtr>(
              CopyGraph(parent, root->operand(0), cc_operands, non_entry_comps)),
          std::get<HloInstructionPtr>(
              CopyGraph(parent, root->operand(1), cc_operands, non_entry_comps)),
          // operand(1) and init_values()[0] are both fine here:
          // std::get<HloInstructionPtr>(CopyGraph(parent, reduce->init_values()[0], cc_operands, non_entry_comps)),
          reduce->dimensions(),
          comp);
      break;
    }
    default:
      VLOG(1) << "unsupported opcode: " << opcode;
      break;
  }
  CHECK(cloned_instr) << "cloned_instr is not assigned";
  VLOG(1) << cloned_instr->parent() << ", " << cloned_instr->ToString();
  if (!is_root_instr) {
    // Grab the raw pointer first; the unique_ptr is emptied by the move below.
    auto raw = cloned_instr.get();
    // Non-root instructions must be added to the enclosing HloComputation.
    parent->AddInstruction(std::move(cloned_instr));
    return raw;
  }
  return std::move(cloned_instr);
}
} // namespace
namespace ipu {
StatusOr<bool> IpuCustomCallExpanderPass::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(2, "IpuCustomCallExpanderPass::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      if (instruction->IsDead()) {
        continue;
      }
      if (HloOpcode::kParameter == instruction->opcode()) {
        VLOG(1) << "instruction->parameter_number() " << instruction->parameter_number();
      }
      if (HloOpcode::kCustomCall == instruction->opcode()) {
        HloCustomCallInstruction* cc = xla::Cast<xla::HloCustomCallInstruction>(instruction);
        // Parse the custom-call's backend_config string back into an HloModule.
        TF_ASSIGN_OR_RETURN(auto cm, xla::ParseAndReturnUnverifiedModule(cc->opaque()));
        HloComputation* target_comp = nullptr;
        // Per CustomCall.
        std::vector<HloComputation*> non_entry_comps;
        // Collect all non-entry computations and clone them into the outer module.
        for (auto* comp : cm->MakeNonfusionComputations()) {
          if (!comp->IsEntryComputation()) {
            auto cp = comp->Clone();
            non_entry_comps.push_back(cp.get());
            module->AddEmbeddedComputation(std::move(cp));
          } else {
            target_comp = comp;
          }
        }
        // The parsed HloModule has the same parameters as the custom-call,
        // so the custom-call's operands replace those parameters.
        auto new_root = std::get<HloInstructionUniquePtr>(
            CopyGraph(computation, target_comp->root_instruction(), cc->operands(),
                      &non_entry_comps, /*is_root_instr=*/true));
        TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(instruction, std::move(new_root)));
        changed = true;
      }
    }
  }
  XLA_VLOG_LINES(2, "IpuCustomCallExpanderPass::Run(), after:\n" + module->ToString());
  return changed;
}
} // namespace ipu
} // namespace xla
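A sketch of how the pass might be wired into a pipeline before a module is handed to the simulator; the pipeline name and the surrounding function are illustrative, not part of the pass itself:
#include "tensorflow/compiler/plugin/ipu/driver/passes/ipu_custom_call_expander_pass.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

xla::Status ExpandCustomCallsForSimulator(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("ipu-custom-call-expander");
  // Expand MyCustomCall back into plain HLO so the simulator can run it.
  pipeline.AddPass<xla::ipu::IpuCustomCallExpanderPass>();
  TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
  VLOG(1) << "custom-call expansion changed module: " << changed;
  return xla::Status::OK();
}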
TODO
Per-operator support: only a few typical operators are covered so far.