Partial Lowering to Lower-Level Dialects for Optimization


toy部分lowering到affine,从toy变成混合的mlir
使用DialectConversion框架需要提供两个东西(和一个可选的第三个)

  • 转换目标
    不合法的操作需要重写变成合法的
  • 一组重写模式
    按照这组规则将非法的操作变成合法的

Conversion Target

void ToyToAffineLoweringPass::runOnOperation() {
  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  mlir::ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering. In our case, we are lowering to a combination of the
  // `Affine`, `Arith`, `Func`, and `MemRef` dialects.
  target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
                         func::FuncDialect, memref::MemRefDialect>();

  // We also define the Toy dialect as Illegal so that the conversion will fail
  // if any of these operations are *not* converted. Given that we actually want
  // a partial lowering, we explicitly mark the Toy operations that don't want
  // to lower, `toy.print`, as *legal*. `toy.print` will still need its operands
  // to be updated though (as we convert from TensorType to MemRefType), so we
  // only treat it as `legal` if its operands are legal.
  target.addIllegalDialect<ToyDialect>();
  target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
    return llvm::none_of(op->getOperandTypes(),
                         [](Type type) { return type.isa<TensorType>(); });
  });
  ...
}

在MLIR(Multi-Level Intermediate Representation)框架中,转换目标(Conversion Target)定义了在转换过程中哪些操作是合法的,哪些操作是非法的,以及在什么条件下某些操作可以被视为合法。具体来说,mlir::ConversionTarget target(getContext());这行代码的意思是在当前上下文(getContext())中创建一个转换目标对象target,用于指导后续的转换过程。

mlir::ConversionTarget target(getContext());
  • 这行代码实例化了一个ConversionTarget对象。这个对象用于定义转换过程中操作的合法性规则。这是转换过程的起点。
  • getContext()函数返回当前的MLIR上下文。上下文包含了所有与MLIR相关的全局信息,如已注册的方言(Dialects)、类型(Types)等。在转换过程中,上下文提供了所需的所有全局状态和元数据。
target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
                       func::FuncDialect, memref::MemRefDialect>();

将Affine、Arith、Func和MemRef方言中的操作标记为合法

target.addIllegalDialect<ToyDialect>();

将Toy方言中的操作标记为非法。这意味着在转换过程中,如果任何Toy方言中的操作没有被转换,这将导致转换失败。

target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
  return llvm::none_of(op->getOperandTypes(),
                       [](Type type) { return type.isa<TensorType>(); });
});

为toy::PrintOp操作设置了动态合法性条件。如果toy::PrintOp的所有操作数类型都不是TensorType,那么它就是合法的。这是为了支持部分转换,使得某些特定操作在特定条件下可以被保留.具体来说,如果toy::PrintOp操作的操作数(operands)类型中没有TensorType类型,那么它就是合法的。

Conversion Patterns

/// Lower the `toy.transpose` operation to an affine loop nest.
struct TransposeOpLowering : public mlir::ConversionPattern {
  TransposeOpLowering(mlir::MLIRContext *ctx)
      : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}

  /// Match and rewrite the given `toy.transpose` operation, with the given
  /// operands that have been remapped from `tensor<...>` to `memref<...>`.
  llvm::LogicalResult
  matchAndRewrite(mlir::Operation *op, ArrayRef<mlir::Value> operands,
                  mlir::ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();

    // Call to a helper function that will lower the current operation to a set
    // of affine loops. We provide a functor that operates on the remapped
    // operands, as well as the loop induction variables for the inner most
    // loop body.
    lowerOpToLoops(
        op, operands, rewriter,
        [loc](mlir::PatternRewriter &rewriter,
              ArrayRef<mlir::Value> memRefOperands,
              ArrayRef<mlir::Value> loopIvs) {
          // Generate an adaptor for the remapped operands of the TransposeOp.
          // This allows for using the nice named accessors that are generated
          // by the ODS. This adaptor is automatically provided by the ODS
          // framework.
          TransposeOpAdaptor transposeAdaptor(memRefOperands);
          mlir::Value input = transposeAdaptor.input();

          // Transpose the elements by generating a load from the reverse
          // indices.
          SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
          return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
        });
    return success();
  }

如何匹配到操作?

    1. 构造函数中的操作名匹配
TransposeOpLowering(mlir::MLIRContext *ctx)
    : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {}

在构造函数中,TransposeOpLowering 调用基类 mlir::ConversionPattern 的构造函数,并传递了 TransposeOp::getOperationName() 作为操作名。这实际上是告诉 ConversionPattern,这个模式是专门用来匹配 TransposeOp 操作的。

    1. TransposeOp::getOperationName()
      TransposeOp::getOperationName() 返回的是 toy.transpose 操作的名称。这个名称是在定义 TransposeOp 操作时指定的,例如在 .td(TableGen)文件中。
    1. 模式匹配过程
      ConversionPattern 使用这个操作名来匹配 IR 中的操作。具体的匹配过程如下:
  • 注册模式:在某个地方(通常是在转换 pass 中),会将这个模式(TransposeOpLowering)注册到 MLIR 的模式列表中。

  • 模式应用:当转换 pass 运行时,它会遍历 IR 中的每个操作,并尝试应用已注册的模式。ConversionPattern 会检查每个操作的名称,如果操作名与 TransposeOp::getOperationName() 返回的名称匹配,就会调用 matchAndRewrite 方法。

matchAndRewrite

  • 矩阵转置
    比如2*3的矩阵转3*2
A = [ [a, b, c],
     [d, e, f] ]
B = [ [a, d],
     [b, e],
     [c, f] ]

A转B写成循环应该是

for (int i = 0; i < 2; ++i) {
 for (int j = 0; j < 3; ++j) {
   B[j][i] = A[i][j];
 }
}

在 MLIR 中,我们需要使用操作如 AffineLoadOp 和 AffineStoreOp 来实现数据的加载和存储。考虑以下简化的伪代码,展示如何通过 MLIR 实现这一点:

// 假设 A 是原始 2x3 矩阵,B 是目标 3x2 矩阵
memref<A> : memref<2x3xf32>
memref<B> : memref<3x2xf32>

for i = 0 to 2 {
 for j = 0 to 3 {
   // 从 A 中加载值
   value = affine.load A[i, j]
   // 将值存储到 B 中
   affine.store value, B[j, i]
 }
}

lowerOpToLoops

 lowerOpToLoops(
        op, operands, rewriter,
        [loc](mlir::PatternRewriter &rewriter,
              ArrayRef<mlir::Value> memRefOperands,
              ArrayRef<mlir::Value> loopIvs) {
          // Generate an adaptor for the remapped operands of the TransposeOp.
          // This allows for using the nice named accessors that are generated
          // by the ODS. This adaptor is automatically provided by the ODS
          // framework.
          TransposeOpAdaptor transposeAdaptor(memRefOperands);
          mlir::Value input = transposeAdaptor.input();

          // Transpose the elements by generating a load from the reverse
          // indices.
          SmallVector<mlir::Value, 2> reverseIvs(llvm::reverse(loopIvs));
          return rewriter.create<mlir::AffineLoadOp>(loc, input, reverseIvs);
        });

函数参数

  • op:这是需要被降低的操作符。

  • operands:这是操作符的操作数。

  • rewriter:这是用于模式匹配和替换的工具。

  • callback:这是一个lambda函数,用于定义如何将操作符的具体行为映射到基础的循环和加载/存储操作中。
    Lambda函数

  • loc:位置信息,用于在创建新的MLIR操作时保留源代码的位置信息。

  • transposeAdaptor:这是一个自动生成的适配器,用于方便地访问TransposeOp的操作数。

  • input:这是TransposeOp的输入数据。

  • reverseIvs:这是一个小向量,包含反转后的循环索引,用于生成反向加载操作。
    简单理解这里生成了循环,lambda里面包含需要循环的操作。

Partial Lowering

def PrintOp : Toy_Op<"print"> {
  ...

  // The print operation takes an input tensor to print.
  // We also allow a F64MemRef to enable interop during partial lowering.
  let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
}

这里意思说,让print能过够接受F64MemRef的参数,加上前面有表明什么情况下保留print,最后我们就得到混合的mlir

func.func @main() {
  %cst = arith.constant 1.000000e+00 : f64
  %cst_0 = arith.constant 2.000000e+00 : f64
  %cst_1 = arith.constant 3.000000e+00 : f64
  %cst_2 = arith.constant 4.000000e+00 : f64
  %cst_3 = arith.constant 5.000000e+00 : f64
  %cst_4 = arith.constant 6.000000e+00 : f64

  // Allocating buffers for the inputs and outputs.
  %0 = memref.alloc() : memref<3x2xf64>
  %1 = memref.alloc() : memref<3x2xf64>
  %2 = memref.alloc() : memref<2x3xf64>

  // Initialize the input buffer with the constant values.
  affine.store %cst, %2[0, 0] : memref<2x3xf64>
  affine.store %cst_0, %2[0, 1] : memref<2x3xf64>
  affine.store %cst_1, %2[0, 2] : memref<2x3xf64>
  affine.store %cst_2, %2[1, 0] : memref<2x3xf64>
  affine.store %cst_3, %2[1, 1] : memref<2x3xf64>
  affine.store %cst_4, %2[1, 2] : memref<2x3xf64>

  // Load the transpose value from the input buffer and store it into the
  // next input buffer.
  affine.for %arg0 = 0 to 3 {
    affine.for %arg1 = 0 to 2 {
      %3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64>
      affine.store %3, %1[%arg0, %arg1] : memref<3x2xf64>
    }
  }

  // Multiply and store into the output buffer.
  affine.for %arg0 = 0 to 3 {
    affine.for %arg1 = 0 to 2 {
      %3 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
      %4 = affine.load %1[%arg0, %arg1] : memref<3x2xf64>
      %5 = arith.mulf %3, %4 : f64
      affine.store %5, %0[%arg0, %arg1] : memref<3x2xf64>
    }
  }

  // Print the value held by the buffer.
  toy.print %0 : memref<3x2xf64>
  memref.dealloc %2 : memref<2x3xf64>
  memref.dealloc %1 : memref<3x2xf64>
  memref.dealloc %0 : memref<3x2xf64>
  return
}
  • 21
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值