Preface
Official documentation: https://mlir.llvm.org/docs/Tutorials/
For the Toy tutorial, see: https://blog.csdn.net/CatkinLX/article/details/125166135
1. Creating a Dialect (Foo)
Directory layout:
- mlir/include/mlir/Dialect/Foo (for public include files, including the ODS TableGen .td files)
- mlir/lib/Dialect/Foo (for sources)
- mlir/lib/Dialect/Foo/IR (for operations)
- mlir/lib/Dialect/Foo/Transforms (for transforms: rewrite rules, including DRR)
- mlir/test/Dialect/Foo (for tests)
Building
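The operations of a dialect are declared in ODS format in .td files, and the build registers them with add_mlir_dialect. The generated MLIRFooOpsIncGen target can then be used to declare dependencies on the generated headers.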
Building transformations:
set(LLVM_TARGET_DEFINITIONS FooTransforms.td)  # LLVM style
mlir_tablegen(FooTransforms.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRFooTransformsIncGen)  # generates another IncGen target
A dialect may be split into several libraries; each one is declared with add_mlir_dialect_library:
add_mlir_dialect_library(MLIRFoo
  DEPENDS  # dependencies on TableGen-generated headers
  MLIRFooOpsIncGen
  MLIRFooTransformsIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC  # dependencies on other dialects' libraries
  MLIRBar
  <some-other-library>
  )
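add_mlir_dialect_library is a wrapper around add_llvm_library that also records the library in the global property MLIR_DIALECT_LIBS. This is mainly useful for tools that should link every dialect (such as mlir-opt), which can retrieve the list with: get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)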
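Dialect conversions
A conversion from dialect X to dialect Y uses the following directories:
mlir/include/mlir/Conversion/XToY
mlir/lib/Conversion/XToY
mlir/test/Conversion/XToY
Build setup (example for a Bar-to-Foo conversion):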
add_mlir_conversion_library(MLIRBarToFoo
BarToFoo.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/BarToFoo
LINK_LIBS PUBLIC
MLIRBar
MLIRFoo
)
add_mlir_conversion_library is likewise a wrapper around add_llvm_library; it records the library in the global property MLIR_CONVERSION_LIBS.
2. Adding an MLIR graph rewrite
MLIR structure overview (Language Reference): https://mlir.llvm.org/docs/LangRef/
Operation Definition Specification (ODS):https://mlir.llvm.org/docs/OpDefinitions/
Table-driven Declarative Rewrite Rule (DRR):https://mlir.llvm.org/docs/DeclarativeRewrites/
Adding an operation
Defining an operation in TableGen (https://llvm.org/docs/TableGen/index.html) involves the following parts:
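- name: must be unique. Operations belong to a dialect, so the full name has the form dialect.operation; the dialect prefix is usually factored out into a common base class shared by all of the dialect's ops.
- traits: properties of the operation, used for verification.
- arguments: operands, i.e. runtime values produced by other ops (naming is optional), and attributes, i.e. values known at compile time (must be named).
- results: the values produced by the operation.
- documentation: a summary and description of the operation.
- dialect-specific information: extra information needed by particular drivers or tools.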
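As a rough illustration of the fields above, the following C++ sketch models an op record outside of MLIR. Everything in it (ToyValue, ToyOp, verifySameOperandsAndResultType, hasAttribute) is hypothetical and is not the ODS-generated API: operands are positional runtime values, attributes form a named compile-time dictionary, and a trait is modelled as a verification predicate, loosely in the spirit of MLIR's SameOperandsAndResultType trait.

```cpp
#include <cassert>
#include <initializer_list>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy model (not the MLIR/ODS API) of what an op definition describes.
struct ToyValue {
  std::string type;  // e.g. "f32"
};

struct ToyOp {
  std::string name;                          // "dialect.operation", e.g. "foo.leaky_relu"
  std::vector<ToyValue> operands;            // runtime inputs produced by other ops
  std::map<std::string, double> attributes;  // compile-time constants, always named
  std::vector<ToyValue> results;             // values produced by this op
};

// A "trait" modelled as a verification predicate: every operand and result
// must have the same type.
bool verifySameOperandsAndResultType(const ToyOp &op) {
  const std::string *expected = nullptr;
  for (const std::vector<ToyValue> *group : {&op.operands, &op.results}) {
    for (const ToyValue &v : *group) {
      if (expected == nullptr)
        expected = &v.type;  // the first value fixes the expected type
      else if (v.type != *expected)
        return false;        // type mismatch: verification fails
    }
  }
  return true;
}

// Attributes are looked up by name; a missing required attribute is an error.
bool hasAttribute(const ToyOp &op, const std::string &attrName) {
  return op.attributes.count(attrName) != 0;
}
```

For example, ToyOp{"foo.leaky_relu", {{"f32"}}, {{"alpha", 0.2}}, {{"f32"}}} passes the check, while changing the result type to i32 makes it fail. In MLIR itself, traits contribute their checks to the op's verifier automatically.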
Adding a pattern
TableGen
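The rule below converts TF's LeakyReluOp into its TFLite counterpart. Values bound in the source pattern (here $arg and $a) can be used directly in the result pattern.
def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>;
// If building the result needs extra C++ logic, wrap a helper function in a NativeCodeCall.
// $_builder is the current PatternRewriter; $0, $1, $2 are the arguments given in the result pattern.
def createTFLLeakyRelu : NativeCodeCall<"createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;
// $old_value binds the matched op's result, so $0.getDefiningOp() yields the matched op.
def : Pat<(TF_LeakyReluOp:$old_value $arg, F32Attr:$a), (createTFLLeakyRelu $old_value, $arg, $a)>;
// The C++ helper invoked by the NativeCodeCall above:
static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op, Value operand, Attribute attr) {
  return rewriter.create<mlir::TFL::LeakyReluOp>(op->getLoc(), operand.getType(), /*arg=*/operand,
                                                 /*alpha=*/attr.cast<FloatAttr>());
}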
Registration (CMake):
set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>)
mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters)
add_public_tablegen_target(<name-of-the-cmake-target>)
The generated file contains a populateWithGenerated(RewritePatternSet &patterns) function that adds all of the generated patterns to the given set.
C++-style patterns (matchAndRewrite)
Simplified form: a plain function
static LogicalResult convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(op, op->getResult(0).getType(), op->getOperand(0),
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
return success();
}
void populateRewrites(RewritePatternSet &patternSet) {
patternSet.add(convertTFLeakyRelu);
}
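The register-then-apply flow of the function-style pattern can be imitated without MLIR. This is a toy under stated assumptions: the "IR" is a flat list of ops identified by name, and ToyPatternSet, the toy convertTFLeakyRelu and applyGreedily are hypothetical stand-ins for RewritePatternSet, the pattern above and MLIR's greedy driver (e.g. applyPatternsAndFoldGreedily), which is considerably more sophisticated.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR: an op is just its name plus one float attribute.
struct ToyOp {
  std::string name;
  float alpha = 0.0f;
};

// A function-style pattern: returns true ("success") if it rewrote the op.
using ToyPattern = std::function<bool(ToyOp &)>;

struct ToyPatternSet {
  std::vector<ToyPattern> patterns;
  void add(ToyPattern p) { patterns.push_back(std::move(p)); }
};

// Toy mirror of convertTFLeakyRelu: rename tf.LeakyRelu to tfl.leaky_relu, keeping alpha.
bool convertTFLeakyRelu(ToyOp &op) {
  if (op.name != "tf.LeakyRelu") return false;  // "failure": pattern does not apply
  op.name = "tfl.leaky_relu";                   // alpha is carried over unchanged
  return true;
}

void populateRewrites(ToyPatternSet &set) { set.add(convertTFLeakyRelu); }

// Sweep over all ops until no pattern succeeds anywhere (a fixpoint).
// Returns the number of successful rewrites.
int applyGreedily(std::vector<ToyOp> &ops, const ToyPatternSet &set) {
  int rewrites = 0;
  bool changed = true;
  while (changed) {
    changed = false;
    for (ToyOp &op : ops)
      for (const ToyPattern &p : set.patterns)
        if (p(op)) {
          ++rewrites;
          changed = true;
        }
  }
  return rewrites;
}
```

Running applyGreedily on {tf.LeakyRelu(alpha=0.2), tf.Add} rewrites only the first op and returns 1; the second sweep finds nothing to do, which is how the loop detects the fixpoint.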
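ODS can also generate a hook for a function-style canonicalization pattern:
// Enabled by setting "let hasCanonicalizeMethod = 1;" in the op's TableGen definition.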
LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
auto inputs = op.inputs();
APInt value;
// mul(x, c) -> shl(x, log2(c)), where c is a power of two.
if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) && value.isPowerOf2()) {
auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),value.exactLogBase2());
auto shlOp = rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(), ArrayRef<Value>(shlOp));
return success();
}
return failure();
}
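The power-of-two rewrite in this canonicalizer can be checked on a tiny expression tree, independent of CIRCT. Expr, canonicalizeMul and print are hypothetical; the power-of-two test uses the usual bit trick c != 0 && (c & (c - 1)) == 0, and unlike the variadic CIRCT MulOp this toy multiply is strictly binary with the constant on the right.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy expression node: a variable, an integer constant, or a binary op.
struct Expr {
  enum Kind { Var, Const, Mul, Shl };
  Kind kind;
  std::string name;                // Var: variable name
  uint64_t value = 0;              // Const: constant value
  std::shared_ptr<Expr> lhs, rhs;  // Mul / Shl: operands
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr makeVar(const std::string &n) {
  return std::make_shared<Expr>(Expr{Expr::Var, n, 0, nullptr, nullptr});
}
ExprPtr makeConst(uint64_t v) {
  return std::make_shared<Expr>(Expr{Expr::Const, "", v, nullptr, nullptr});
}
ExprPtr makeBin(Expr::Kind k, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{k, "", 0, std::move(a), std::move(b)});
}

bool isPowerOf2(uint64_t c) { return c != 0 && (c & (c - 1)) == 0; }

// log2 of an exact power of two, by counting right shifts.
unsigned exactLog2(uint64_t c) {
  unsigned log = 0;
  while (c > 1) {
    c >>= 1;
    ++log;
  }
  return log;
}

// mul(x, c) -> shl(x, log2(c)) when c is a power of two.
// Returns the replacement node, or nullptr (the toy "failure()") if nothing matches.
ExprPtr canonicalizeMul(const ExprPtr &e) {
  if (e->kind != Expr::Mul || e->rhs->kind != Expr::Const || !isPowerOf2(e->rhs->value))
    return nullptr;
  return makeBin(Expr::Shl, e->lhs, makeConst(exactLog2(e->rhs->value)));
}

std::string print(const ExprPtr &e) {
  switch (e->kind) {
  case Expr::Var: return e->name;
  case Expr::Const: return std::to_string(e->value);
  case Expr::Mul: return "mul(" + print(e->lhs) + ", " + print(e->rhs) + ")";
  case Expr::Shl: return "shl(" + print(e->lhs) + ", " + print(e->rhs) + ")";
  }
  return "";
}
```

canonicalizeMul(mul(x, 8)) yields shl(x, 3), while mul(x, 6) is left alone.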
General form: subclassing RewritePattern
/// Variant 1: separate "match" and "rewrite" steps
struct ConvertTFLeakyRelu : public RewritePattern {
ConvertTFLeakyRelu(MLIRContext *context)
: RewritePattern("tf.LeakyRelu", 1, context) {}
LogicalResult match(Operation *op) const override {
return success();
}
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
op, op->getResult(0).getType(), op->getOperand(0),
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
}
};
/// Variant 2: a single "matchAndRewrite" step (an alternative definition of the same pattern; use one or the other)
struct ConvertTFLeakyRelu : public RewritePattern {
ConvertTFLeakyRelu(MLIRContext *context)
: RewritePattern("tf.LeakyRelu", 1, context) {}
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
op, op->getResult(0).getType(), op->getOperand(0),
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
return success();
}
};
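The RewritePattern constructor above takes a root operation name ("tf.LeakyRelu") and a benefit (1). As a simplified model of how these are used, the sketch below only tries patterns whose root matches the op and prefers higher-benefit ones, and its matchAndRewrite is literally match() followed by rewrite(). ToyRewritePattern and applyBestPattern are hypothetical; MLIR's actual pattern applicator has more elaborate ordering and cost rules.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
};

// Hypothetical toy counterpart of RewritePattern: root name + benefit + match/rewrite.
struct ToyRewritePattern {
  std::string rootName;  // only ops with this name are considered
  unsigned benefit;      // higher benefit is tried first
  std::function<bool(const ToyOp &)> match;
  std::function<void(ToyOp &)> rewrite;

  // Combined form: match() and, on success, rewrite().
  bool matchAndRewrite(ToyOp &op) const {
    if (!match(op)) return false;
    rewrite(op);
    return true;
  }
};

// Try the applicable patterns in order of decreasing benefit and stop at the
// first one that succeeds. Returns true if the op was rewritten.
bool applyBestPattern(ToyOp &op, std::vector<ToyRewritePattern> patterns) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyRewritePattern &a, const ToyRewritePattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyRewritePattern &p : patterns)
    if (p.rootName == op.name && p.matchAndRewrite(op)) return true;
  return false;
}
```

With two patterns rooted at tf.LeakyRelu (benefits 1 and 2), the benefit-2 pattern wins; an op whose name matches no root is left untouched.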
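3. Understanding the IR structure
High-level structure of MLIR: https://mlir.llvm.org/docs/LangRef/
Traversing the IR Nesting
An operation can contain zero or more regions; each region is a list of blocks, and each block is a list of operations.
This recursive nesting maps onto three mutually recursive functions: printOperation, printRegion and printBlock. (printIndent and pushIndent are small indentation helpers from the tutorial's test pass and are not shown here.)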
// Print the operation itself, then recurse into its regions.
void printOperation(Operation *op) {
  // Print the operation name and its operand/result counts.
  printIndent() << "visiting op: '" << op->getName() << "' with "
                << op->getNumOperands() << " operands and "
                << op->getNumResults() << " results\n";
  // Print the operation attributes.
  // NamedAttribute exposes getName()/getValue() in current MLIR (older releases used .first/.second).
  if (!op->getAttrs().empty()) {
    printIndent() << op->getAttrs().size() << " attributes:\n";
    for (NamedAttribute attr : op->getAttrs())
      printIndent() << " - '" << attr.getName().getValue() << "' : '"
                    << attr.getValue() << "'\n";
  }
  // Recurse into each of the regions attached to the operation.
  printIndent() << " " << op->getNumRegions() << " nested regions:\n";
  auto indent = pushIndent();
  for (Region &region : op->getRegions())
    printRegion(region);
}
void printRegion(Region &region) {
// A region does not hold anything by itself other than a list of blocks.
printIndent() << "Region with " << region.getBlocks().size() << " blocks:\n";
auto indent = pushIndent();
for (Block &block : region.getBlocks())
printBlock(block);
}
void printBlock(Block &block) {
// Print the block's intrinsic properties (basically: argument list)
printIndent()
<< "Block with " << block.getNumArguments() << " arguments, "
<< block.getNumSuccessors()
<< " successors, and "
// Note, this `.size()` is traversing a linked-list and is O(n).
<< block.getOperations().size() << " operations\n";
// A block's main role is to hold a list of Operations: let's recurse into
// printing each operation.
auto indent = pushIndent();
for (Operation &op : block.getOperations())
printOperation(&op);
}
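The Operation -> Region -> Block -> Operation recursion can be reproduced with plain structs. This is a hypothetical toy mirror of the printing functions above (no operands, attributes or block arguments), with the indentation passed explicitly instead of through printIndent/pushIndent. It relies on C++17, which allows a std::vector member whose element type is still incomplete at that point.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy mirror of MLIR's nesting: Operation -> Region -> Block -> Operation.
struct Operation;
struct Block { std::vector<Operation> operations; };
struct Region { std::vector<Block> blocks; };
struct Operation {
  std::string name;
  std::vector<Region> regions;  // zero or more nested regions
};

void printOperation(const Operation &op, int indent, std::string &out);

// A block is just a list of operations.
void printBlock(const Block &block, int indent, std::string &out) {
  out += std::string(indent, ' ') + "Block with " +
         std::to_string(block.operations.size()) + " operations\n";
  for (const Operation &op : block.operations) printOperation(op, indent + 2, out);
}

// A region is just a list of blocks.
void printRegion(const Region &region, int indent, std::string &out) {
  out += std::string(indent, ' ') + "Region with " +
         std::to_string(region.blocks.size()) + " blocks\n";
  for (const Block &block : region.blocks) printBlock(block, indent + 2, out);
}

// An operation owns its regions; recurse into each of them.
void printOperation(const Operation &op, int indent, std::string &out) {
  out += std::string(indent, ' ') + "visiting op: '" + op.name + "' with " +
         std::to_string(op.regions.size()) + " nested regions\n";
  for (const Region &region : op.regions) printRegion(region, indent + 2, out);
}
```

Printing a builtin.module that holds one func.func (whose single block contains arith.addi and func.return) produces an indented tree in which each level of nesting adds two spaces.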
Filtered Iterator
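getOps<OpTy>(): a filtered iterator over the operations of type OpTy. It is available on both Block and Region; on a Region it visits the operations in all of the region's blocks (without descending into nested regions).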
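A filtered iterator such as getOps<OpTy>() yields only the operations of one type. The toy below approximates the idea with dynamic_cast over a polymorphic op hierarchy. Op, ConstantOp, AddOp and Block::getOps here are hypothetical; unlike MLIR's lazy filtered range, this version eagerly collects the matches into a vector for simplicity, and MLIR itself relies on its own isa/dyn_cast machinery rather than C++ RTTI.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical polymorphic op hierarchy (not the MLIR API).
struct Op {
  virtual ~Op() = default;
  virtual std::string name() const = 0;
};
struct ConstantOp : Op {
  int value;
  explicit ConstantOp(int v) : value(v) {}
  std::string name() const override { return "arith.constant"; }
};
struct AddOp : Op {
  std::string name() const override { return "arith.addi"; }
};

struct Block {
  std::vector<std::unique_ptr<Op>> operations;

  // Toy analogue of Block::getOps<OpTy>(): only the ops whose dynamic type is OpTy.
  template <typename OpTy>
  std::vector<OpTy *> getOps() const {
    std::vector<OpTy *> result;
    for (const auto &op : operations)
      if (auto *typed = dynamic_cast<OpTy *>(op.get())) result.push_back(typed);
    return result;
  }
};
```

For a block holding constant(1), addi, constant(2), getOps<ConstantOp>() returns the two constants in program order.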
walk: calls a callback on every operation nested under an Operation, Block or Region, recursively and in post-order by default.
getFunction().walk([&](mlir::Operation *op) {
// process Operation `op`.
});
// Or restrict the callback to ops of a specific type (here, ops implementing LinalgOp)
getFunction().walk([](LinalgOp linalgOp) {
  // process LinalgOp `linalgOp`.
  return WalkResult::interrupt(); // stop the walk
});
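The post-order behaviour of walk and the effect of WalkResult::interrupt() can be seen on a toy tree. The types and the walk function below are hypothetical stand-ins for the MLIR API, and filtering by op type (as in the LinalgOp example) is left out.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy nesting (not the MLIR API).
struct Operation;
struct Block { std::vector<Operation> operations; };
struct Region { std::vector<Block> blocks; };
struct Operation {
  std::string name;
  std::vector<Region> regions;
};

enum class WalkResult { Advance, Interrupt };

// Post-order: nested ops are visited before the op that contains them.
// Propagates Interrupt upward as soon as the callback asks to stop.
WalkResult walk(Operation &op, const std::function<WalkResult(Operation &)> &callback) {
  for (Region &region : op.regions)
    for (Block &block : region.blocks)
      for (Operation &nested : block.operations)
        if (walk(nested, callback) == WalkResult::Interrupt)
          return WalkResult::Interrupt;
  return callback(op);
}
```

On a module containing func.func { arith.addi; func.return }, an always-advancing callback visits arith.addi, func.return, func.func and builtin.module in that order, while a callback that interrupts on func.return stops after two visits.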
Traversing the def-use chains
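In the IR, a Value is one of two kinds: a BlockArgument, or a result of an Operation. Each result of an operation is a separate Value. Values are consumed by operations: every operand of an operation is a Value.
// Print where each operand value comes from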
for (Value operand : op->getOperands()) {
if (Operation *producer = operand.getDefiningOp()) {
// The value is the result of another operation.
llvm::outs() << " - Operand produced by operation '" << producer->getName() << "'\n";
} else {
// If there is no defining op, the Value is necessarily a Block argument.
auto blockArg = operand.cast<BlockArgument>();
llvm::outs() << " - Operand produced by Block argument, number " << blockArg.getArgNumber() << "\n";
}
}
// Now inspect the op's results and their users.
llvm::outs() << "Has " << op->getNumResults() << " results:\n";
for (auto indexedResult : llvm::enumerate(op->getResults())) {
Value result = indexedResult.value();
llvm::outs() << " - Result " << indexedResult.index();
if (result.use_empty()) {
llvm::outs() << " has no uses\n";
continue;
}
if (result.hasOneUse()) {
llvm::outs() << " has a single use: ";
} else {
llvm::outs() << " has "
<< std::distance(result.getUses().begin(),
result.getUses().end())
<< " uses:\n";
}
for (Operation *userOp : result.getUsers()) {
llvm::outs() << " - " << userOp->getName() << "\n";
}
}
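The def-use relationships printed above can be modelled with a couple of structs. In this hypothetical toy (ToyValue, ToyOperation, addOperand and describeOperand are not MLIR API), a value either has a defining op or is a block argument, and every use is recorded on the value itself so that its users can be enumerated, much like getDefiningOp(), BlockArgument and getUsers() in MLIR.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct ToyOperation;

// A value is either an op result (definingOp != nullptr) or a block argument.
struct ToyValue {
  ToyOperation *definingOp = nullptr;  // null means "block argument"
  unsigned argNumber = 0;              // meaningful only for block arguments
  std::vector<ToyOperation *> users;   // ops that use this value as an operand
};

struct ToyOperation {
  std::string name;
  std::vector<ToyValue *> operands;
  std::vector<ToyValue> results;
};

// Record that `user` consumes `value`, keeping the value's use list in sync.
void addOperand(ToyOperation &user, ToyValue &value) {
  user.operands.push_back(&value);
  value.users.push_back(&user);
}

// Mirrors the getDefiningOp() / BlockArgument distinction from the code above.
std::string describeOperand(const ToyValue &v) {
  if (v.definingOp) return "produced by operation '" + v.definingOp->name + "'";
  return "produced by block argument, number " + std::to_string(v.argNumber);
}
```

Keeping the users list in sync inside addOperand is the toy analogue of MLIR's use-lists: it lets a question like "does this value have any uses?" be answered from the value alone, without scanning the whole IR.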
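The figure below shows how blocks, operations and values relate to one another. From it we can see that:
- a block holds a list of operations
- an operation holds its operands and results
- each value keeps track of the operations that use it (its use-list)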