MLIR Tutorials

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

官网文档:https://mlir.llvm.org/docs/Tutorials/
Toy的见:https://blog.csdn.net/CatkinLX/article/details/125166135


一、创建Dialect(Foo)

文件目录:

  • mlir/include/mlir/Dialect/Foo (for public include files),ODS TableGen的文件
  • mlir/lib/Dialect/Foo (for sources)
  • mlir/lib/Dialect/Foo/IR (for operations)
  • mlir/lib/Dialect/Foo/Transforms (for transforms),rewrite rules, DDR
  • mlir/test/Dialect/Foo (for tests)

编译

在Dialect的operations都是用ODS格式声明的,在.td文件中,编译时用add_mlir_dialect声明。生成的MLIRFooOpsIncGen可以用于声明依赖。
transformation的编译:

set(LLVM_TARGET_DEFINITIONS FooTransforms.td) // LLVM风格
mlir_tablegen(FooTransforms.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRFooTransformIncGen) // 生成另一个IncGen

如果Dialect有许多库的时候,使用add_mlir_dialect_library声明:

add_mlir_dialect_library(MLIRFoo
  DEPENDS // TableGen生成的头文件的依赖
  MLIRFooOpsIncGen
  MLIRFooTransformsIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC // 对其他的Dialect的库的依赖
  MLIRBar
  <some-other-library>
  )

这里的add_mlir_dialect_library是add_llvm_library的一个wrapper,主要用于link tools(mlir-opt这种),链接了的库由MLIR_DIALECT_LIBS全局属性得到:get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

Dialect转换

从X转成Y,涉及的目录为:
mlir/include/mlir/Conversion/XToY
mlir/lib/Conversion/XToY
mlir/test/Conversion/XToY
编译方式:

add_mlir_conversion_library(MLIRBarToFoo
  BarToFoo.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/BarToFoo

  LINK_LIBS PUBLIC
  MLIRBar
  MLIRFoo
  )

add_mlir_conversion_library是add_llvm_library的包装,包括了的库有MLIR_CONVERSION_LIBS

二、新增MLIR graph rewrite

MLIR的结构介绍:https://mlir.llvm.org/docs/LangRef/
Operation Definition Specification (ODS):https://mlir.llvm.org/docs/OpDefinitions/
Table-driven Declarative Rewrite Rule (DRR):https://mlir.llvm.org/docs/DeclarativeRewrites/

新增operation

在TableGen(https://llvm.org/docs/TableGen/index.html)中定义一个operation,需要:

  • name:unique,一般operation是在Dialect里的,所以命名为Dialect.operation,但是dialect的命名空间一般会都抽象出一个基类。
  • traits。特点,用于做校验
  • arguments。运行时由其他op得到的input operands(可以命名),编译阶段就知道的attribute(必须命名)。
  • results。
  • documentation
  • dialect。特殊驱动需要的额外信息,

新增Pattern

TableGen

将TF的LeakyReluOp,转成TFLite的。source Pattern匹配了,就直接可以用到result Pattern中。

def : Pat<(TF_LeakyReluOp $arg, F32Attr:$a), (TFL_LeakyReluOp $arg, $a)>

// 如果result Pattern需要额外的方法来解析参数
def createTFLLeakyRelu : NativeCodeCall< "createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;
def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a), (createTFLLeakyRelu $old_value, $arg, $a)>;

static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op, Value operand, Attribute attr) {
  return rewriter.create<mlir::TFL::LeakyReluOp>(op->getLoc(), operands[0].getType(), /*arg=*/operands[0],
      /*alpha=*/attrs[0].cast<FloatAttr>());
}

注册:

set(LLVM_TARGET_DEFINITIONS <name-of-the-td-file>)
mlir_tablegen(<name-of-the-generated-inc-file> -gen-rewriters)
add_public_tablegen_target(<name-of-the-cmake-target>)

生成的文件中有个opulateWithGenerated( RewritePatternSet &patterns)方法,可以获取所有Pattern

c++风格matchAndRewrite

简化的

static LogicalResult convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(op, op->getResult(0).getType(), op->getOperand(0),
      /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
  return success();
}

void populateRewrites(RewritePatternSet &patternSet) {
  patternSet.add(convertTFLeakyRelu);
}

ODS 也提供了了函数式的转化

// TableGen中let hasCanonicalizeMethod = 1
LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
  auto inputs = op.inputs();
  APInt value;

  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
  if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) && value.isPowerOf2()) {
    auto shift = rewriter.create<rtl::ConstantOp>(op.getLoc(), op.getType(),value.exactLogBase2());
    auto shlOp = rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift);
    rewriter.replaceOpWithNewOp<MulOp>(op, op.getType(), ArrayRef<Value>(shlOp));
    return success();
  }
  return failure();
}

通用的

/// "match" and "rewrite" 分步
struct ConvertTFLeakyRelu : public RewritePattern {
 ConvertTFLeakyRelu(MLIRContext *context)
     : RewritePattern("tf.LeakyRelu", 1, context) {}

 LogicalResult match(Operation *op) const override {
   return success();
 }

 void rewrite(Operation *op, PatternRewriter &rewriter) const override {
   rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
       op, op->getResult(0).getType(), op->getOperand(0),
       /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
 }
};

/// 一步 "matchAndRewrite"
struct ConvertTFLeakyRelu : public RewritePattern {
 ConvertTFLeakyRelu(MLIRContext *context)
     : RewritePattern("tf.LeakyRelu", 1, context) {}

 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
   rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
       op, op->getResult(0).getType(), op->getOperand(0),
       /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
   return success();
 }
};

三、理解IR结构

high level的MLIR:https://mlir.llvm.org/docs/LangRef/

Traversing the IR Nesting

一个operation可以属于一个或多个region,每个region是一堆blocks,每个blocks是一堆operations。
对应三个方法:printOperation, printRegion, printBlock

// 打印operation层面的属性
 void printOperation(Operation *op) {
   printIndent() << "visiting op: '" << op->getName() << "' with " << op->getNumOperands() << " operands and " << op->getNumResults() << " results\n";
   // Print the operation attributes
   if (!op->getAttrs().empty()) {
     printIndent() << op->getAttrs().size() << " attributes:\n";
     for (NamedAttribute attr : op->getAttrs())
       printIndent() << " - '" << attr.first << "' : '" << attr.second << "'\n";
   }

   // Recurse into each of the regions attached to the operation.
   printIndent() << " " << op->getNumRegions() << " nested regions:\n";
   auto indent = pushIndent();
   for (Region &region : op->getRegions())
     printRegion(region);
 }

void printRegion(Region &region) {
  // A region does not hold anything by itself other than a list of blocks.
  printIndent() << "Region with " << region.getBlocks().size() << " blocks:\n";
  auto indent = pushIndent();
  for (Block &block : region.getBlocks())
    printBlock(block);
}

 void printBlock(Block &block) {
   // Print the block intrinsics properties (basically: argument list)
   printIndent()
       << "Block with " << block.getNumArguments() << " arguments, "
       << block.getNumSuccessors()
       << " successors, and "
       // Note, this `.size()` is traversing a linked-list and is O(n).
       << block.getOperations().size() << " operations\n";

   // A block main role is to hold a list of Operations: let's recurse into
   // printing each operation.
   auto indent = pushIndent();
   for (Operation &op : block.getOperations())
     printOperation(&op);
 }

Filtered Iterator

getOps<OpTy>(),block返回的是operation的iterator,region返回的是block的iterator
walk,对block或者region里的operation的callback操作

 getFunction().walk([&](mlir::Operation *op) {
   // process Operation `op`.
 });

// 或者指定特殊类型的op
getFunction().walk([](LinalgOp linalgOp) {
  // process LinalgOp `linalgOp`.
  return WalkResult::interrupt(); // 退出walk
});

Traversing the def-use chains

IR中的代表Value,有两种类型:BlockArgument和Operation的结果。operation的输出,分别都是一个value。value的使用者是operation,每一个operation里面的参数都是一个value。

// 打印value对应的来源类型
for (Value operand : op->getOperands()) {
  if (Operation *producer = operand.getDefiningOp()) {
  	// 这种是operation的结果
    llvm::outs() << "  - Operand produced by operation '" << producer->getName() << "'\n";
  } else {
    // If there is no defining op, the Value is necessarily a Block argument.
    auto blockArg = operand.cast<BlockArgument>();
    llvm::outs() << "  - Operand produced by Block argument, number " << blockArg.getArgNumber() << "\n";
  }
}

// 这里的value是每个op的输出
llvm::outs() << "Has " << op->getNumResults() << " results:\n";
for (auto indexedResult : llvm::enumerate(op->getResults())) {
  Value result = indexedResult.value();
  llvm::outs() << "  - Result " << indexedResult.index();
  if (result.use_empty()) {
    llvm::outs() << " has no uses\n";
    continue;
  }
  if (result.hasOneUse()) {
    llvm::outs() << " has a single use: ";
  } else {
    llvm::outs() << " has "
                 << std::distance(result.getUses().begin(),
                                  result.getUses().end())
                 << " uses:\n";
  }
  for (Operation *userOp : result.getUsers()) {
    llvm::outs() << "    - " << userOp->getName() << "\n";
  }
}

下图是block和operation之间的关系,其中可以看到:

  • block中存着一堆operation
  • operation中存着结果和运算
  • value会标记谁用他
    block和operation之间的关系

四、DataFlow


总结

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值