MLIR官方Tutorials学习笔记(四)

        本章主要学习interface和pass的作用和使用。

        interface作用是定义一个operation属性(其中包括方法),其他operations如果在自己属性中引入了该interface,就要在operations的具体定义中实现该interface的方法。个人感觉很像C++里的多态性,不同子类通过重写父类函数来实现调用时的不同行为,而这个是不同operations通过在自己的定义里重写interface的方法,来实现使用operations时的不同行为。

        pass作用个人理解是一个代码检查员,他有自己的代码优化逻辑,在生成MLIR IR之后,他会按照自己的逻辑来检查该IR有没有需要优化的地方。比如内联pass检查到了一个函数需要内联,那就执行内联,从而实现优化。

        这章首先介绍了如何添加内联pass,我们先来看没有添加内联pass时的MLIR IR:

        再来看看添加内联pass之后的IR:
        可以看到,multiply_transpose()函数被inline到了main函数里,这样就达到了一种优化。

        那具体如何实现呢?首先因为mlir中自带DialectInlinerInterface,这个接口是应用在Dialect粒度上的内联接口,也就是说如果我们继承这个接口重写的内联类也是应用在整个Dialect上的。我们定义ToyInlinerInterface来实现我们的内联接口:

struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// 这里我们让toy中所有的call operations都可以内联
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// toy中所有operations都可以内联
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }

  //toy中所有func都可以内联
  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
    return true;
  }


  /// Handle the given inlined terminator(toy.return) by replacing it with a new
  /// operation as necessary.
  void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }

  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};

        之后我们在Dialect的初始化中添加这个接口:

void ToyDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  addInterfaces<ToyInlinerInterface>();
}

        此时我们的Dialect中有了这个接口,但是还没有应用pass来调用这个接口,我们在toyc.cpp中的passmanager添加MLIR自带的createInlinerPass(),这样pass就可以检查整个dialect中有没有可以应用我们已定义的inline接口的地方并应用。

pm.addPass(mlir::createInlinerPass());

          在ToyInlineInterface类中有一个细节,就是有这么一个函数:

Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }

        它的作用是将input的类型转换为resultType的类型,为什么要这么做呢,因为如果要内联一个函数,那么该函数定义时所需参数类型和调用该函数时传入的参数类型可能不一致。如我们定义multiply_transpose函数时声明的参数类型是tensor<*xf64>,而我们内联时要传入该函数的类型是tensor<2×3×f64>,因此在调用之前我们要进行一个类型转换(cast)。

        我们在Ops.td和Dialect.cpp中定义CastOp,用作类型转换,暂且不提。

        这章第二部分是讲解怎么添加shape推断pass。首先我们定义shape推断interface,可以使用ODS框架定义:

def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}

        有了ShapeInference定义之后,就可以在Ops.td中的需要类型推断的operations属性里添加这个接口了。比如我们在castOp中要用到类型推断,则在castOp定义里的属性里添加DeclareOpInterfaceMethods<ShapeInferenceOpinterface>:

def CastOp : Toy_Op<"cast", [
     DeclareOpInterfaceMethods<CastOpInterface>,
     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
     Pure,
     SameOperandsAndResultShape
  ]>

        接着我们在castOp的具体定义里实现inferShapes的方法:

void CastOp::inferShapes() { getResult().setType(getInput().getType()); }

        意义:将结果的类型设置为输入的类型。

        最后我们定义一个类型推断的pass:

struct ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>>

        并在该类中的runOnOperation()里实现检查一个Func中的Op有没有类型推断的属性,如果有,则执行该Op的inferShapes方法:

if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
        shapeOp.inferShapes();

        最最后,我们把定义好的pass指针add到passmanager:

optPM.addPass(mlir::toy::createShapeInferencePass());

        我们来看添加了shapeInferencePass后的IR:

        相信你已经看出了和之前的区别。

  • 7
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值