Pattern Rewriting : Generic DAG-to-DAG Rewriting

Introduction

模式重写框架主要可以分解为两个部分:模式定义和模式应用。

Defining Patterns

模式是通过继承 RewritePattern 类来定义的。该类代表了 MLIR 中所有重写模式的基类,包括以下组成部分:

Benefit

  • 预期好处:应用一个模式(pattern)可以带来一定的优化效果,这个优化效果在模式创建时是固定的,但也可以在模式初始化时根据具体情况(例如目标架构)动态计算。

  • 静态 vs 动态:静态的好处是预先确定的,而动态的好处则可以在运行时根据具体情况计算。

  • 优化模式匹配:通过限制动态计算,可以让模式匹配更高效。研究表明,使用“匹配谓词”(简单条件判断)可以避免大部分情况下的动态计算。也就是说,我们可以为每种可能的情况预先创建一个模式,然后用简单的条件判断来选择合适的模式。

例子:将加法和乘法优化成乘法

假设我们有一个简单的表达式:a + a,我们希望将其优化成2 * a,因为乘法运算通常比加法运算更高效。

  • 创建模式:我们创建一个模式,识别出a + a的形式,并将其转换为2 * a。这个优化的好处是显而易见的,因为乘法比加法更高效。

  • 静态好处:在创建模式时,我们预先知道这个转换会带来优化,所以这是一个静态的好处。

  • 动态计算:如果我们针对不同的硬件架构进行优化,比如某些架构上的加法比乘法更快,我们可以在模式初始化时根据架构信息动态决定是否应用这个优化。

  • 匹配谓词:为了避免复杂的动态计算,我们可以创建多个版本的模式。例如,一个版本适用于某些架构,另一个版本适用于其他架构。使用简单的条件判断(匹配谓词)来选择哪个版本的模式。

// 初始代码
%result = add %a, %a

// 应用模式后的优化代码
%result = mul constant(2), %a

匹配谓词是什么?

“匹配谓词”(match predicate)是一种条件判断,用来决定某个模式是否应该被应用。在MLIR中,模式匹配是将一个特定的代码模式转换为更优化的形式,而匹配谓词就是用来判断这个代码模式是否符合某些条件,从而决定是否进行转换。

举例解释

假设我们有一个模式,用来优化某种数学表达式,比如将x * 1优化为x,因为乘以1不会改变值。

没有匹配谓词的情况

在最简单的情况下,我们可以直接定义一个模式:

pattern {
  match: "mul %x, 1"
  rewrite: "%x"
}

使用匹配谓词的情况

但是,有时候我们需要一些额外的条件来决定是否应用这个模式。比如,我们只有在某些特定情况下(比如x是一个特定类型的变量)才希望进行这个优化。这个时候就需要用到匹配谓词。

pattern {
  match: "mul %x, 1"
  predicate: "isSpecialType(%x)"
  rewrite: "%x"
}

Root Operation Name(根操作名称)

这是一个可选的参数,用于指明这个模式(pattern)要匹配的根操作的名称。如果指定了根操作名称,那么只有具有该名称的操作才会被提供给匹配和重写的实现代码。如果没有指定,那么任何类型的操作都可能被提供。提供根操作名称有助于在应用成本模型时简化模式分析。如果要匹配任何类型的操作,需要提供一个特殊的标签(MatchAnyOpTypeTag)来明确意图。

match and rewrite implementation(匹配和重写实现)

这是指匹配给定的根操作并重写IR(中间表示)的代码块。一个RewritePattern可以通过独立的match和rewrite方法,或通过一个结合的matchAndRewrite方法来指定其实现。当使用结合的matchAndRewrite方法时,在匹配成功之前不应进行任何IR的变动。结合的matchAndRewrite方法在匹配和重写阶段需要非平凡的可重新计算信息时特别有用。

class MyPattern : public RewritePattern {
public:
  // 构造一个只匹配名称为`MyOp`的操作的模式
  MyPattern(PatternBenefit benefit, MLIRContext *context)
      : RewritePattern(MyOp::getOperationName(), benefit, context) {}
  
  // 构造一个匹配任何类型操作的模式
  MyPattern(PatternBenefit benefit)
      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}

  // 使用独立的match和rewrite方法来实现匹配和重写
  LogicalResult match(Operation *op) const override {
    // 如果模式匹配,返回`success()`,否则返回`failure`
    // ... (具体匹配逻辑)
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) {
    // 使用提供的rewriter对IR进行变动
    // ... (具体重写逻辑)
  }

  // 使用结合的matchAndRewrite方法来实现匹配和重写
  LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) {
    // 这个方法同时进行匹配和变动
    // 注意在匹配成功之前不应进行IR变动
    // ... (具体逻辑)
  }
};

Restrictions

  • 匹配阶段:在这个阶段,不能对IR进行任何修改。也就是说,只能读数据,不能改数据。
  • 重写阶段:在这个阶段,可以对IR进行修改,但必须通过指定的PatternRewriter来操作。PatternRewriter类提供了执行各种可能的修改操作的接口。例如,如果要删除一个操作(operation),不能直接调用这个操作的删除方法,而是应该使用PatternRewriter提供的删除方法eraseOp。此外,根操作必须被就地更新、替换或删除。
struct MyPattern : public mlir::RewritePattern {
  MyPattern(mlir::MLIRContext *context)
      : mlir::RewritePattern("my_op", 1, context) {}

  // 匹配阶段
  mlir::LogicalResult match(mlir::Operation *op) const override {
    // 只能读取数据,不能修改op
    if (auto myOp = llvm::dyn_cast<MyOp>(op)) {
      return mlir::success();
    }
    return mlir::failure();
  }

  // 重写阶段
  void rewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {
    // 使用PatternRewriter进行IR的修改
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.create<NewOp>(op->getLoc(), op->getOperands());

    // 使用PatternRewriter来删除操作
    rewriter.eraseOp(op);
  }
};

递归应用(Recursion):

  • 递归在编程中是指一个函数调用自身。在模式重写中,一个模式可以应用在它自己产生的结果上。
  • 想象一下,你有一个模式,它每次运行都会从一个循环中去掉一层迭代。如果这个循环可以剥掉多层迭代,那么这个模式可能会被反复应用多次。
  • 问题是,这种反复应用可能会引起无限循环,导致程序无法停止运行。因此,系统默认假设所有模式都不能安全地递归,如果检测到递归就会停止。
  • 如果你确定某个模式可以安全地递归,你需要显式告诉系统,这样系统就不会阻止它。这可以通过调用 setHasBoundedRewriteRecursion 来完成。

剥离迭代

一种优化技术,主要用于从循环中提取出一个或几个单次迭代,使其单独处理。这样做可以帮助更好地进行代码优化,例如更好地并行化或者处理特殊情况。我们通过一个具体的例子来说明这个过程。
假设我们有一个简单的MLIR循环,如下所示:

func @example(%N: index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.for %i = %c0 to %N step %c1 {
    // 循环体
  }
  return
}

剥离单次迭代的具体步骤

  • 确定循环的范围和步长:首先,我们需要知道循环的下界、上界和步长。
  • 生成剥离的单次迭代:在原始循环之前创建一个新的循环,范围是从下界到下界加上步长。
  • 更新原始循环的范围:将原始循环的下界更新为新的下界,即下界加上步长。
    下面是用C++和MLIR的具体实现,展示如何进行这一步骤。

C++/MLIR实现
首先,我们定义一个模式来进行剥离操作:

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Builders.h"
#include "mlir/Dialect/SCF/SCFOps.h"

using namespace mlir;

struct PeelLoopPattern : public RewritePattern {
  explicit PeelLoopPattern(MLIRContext *context)
      : RewritePattern(scf::ForOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto forOp = cast<scf::ForOp>(op);

    // 假设我们只处理步长为1的情况
    if (!matchPattern(forOp.getStep(), m_One())) {
      return failure();
    }

    // 提取循环的范围
    Value lowerBound = forOp.getLowerBound();
    Value upperBound = forOp.getUpperBound();
    Value step = forOp.getStep();

    // 生成剥离的单次迭代
    rewriter.setInsertionPoint(forOp);
    Value peeledIter = rewriter.create<scf::ForOp>(
        forOp.getLoc(), lowerBound, rewriter.create<AddIOp>(forOp.getLoc(), lowerBound, step), step,
        forOp.getIterOperands());
    
    // 将循环体移动到新的单次迭代中
    rewriter.inlineRegionBefore(forOp.getRegion(), peeledIter.getRegion(),
                                peeledIter.getRegion().begin());

    // 更新原始循环的范围
    rewriter.setInsertionPointAfter(peeledIter);
    Value newLowerBound = rewriter.create<AddIOp>(forOp.getLoc(), lowerBound, step);
    rewriter.updateRootInPlace(forOp, [&]() {
      forOp.setLowerBound(newLowerBound);
    });

    return success();
  }
};

void registerPeelLoopPattern(RewritePatternSet &patterns) {
  patterns.add<PeelLoopPattern>(patterns.getContext());
}

然后,我们需要将这个模式注册到MLIR Pass中,并在Pass中应用它:

struct PeelLoopPass : public PassWrapper<PeelLoopPass, OperationPass<FuncOp>> {
  void runOnOperation() override {
    FuncOp func = getOperation();

    RewritePatternSet patterns(&getContext());
    registerPeelLoopPattern(patterns);

    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

std::unique_ptr<Pass> createPeelLoopPass() {
  return std::make_unique<PeelLoopPass>();
}

经过上述过程,原始的MLIR循环:

scf.for %i = %c0 to %N step %c1 {
  // 循环体
}

将被转换为:

scf.for %i = %c0 to %c1 step %c1 {
  // 剥离的单次迭代的循环体
}
scf.for %i = %c1 to %N step %c1 {
  // 剩余迭代的循环体
}

这样,我们就完成了剥离迭代的操作。剥离后的单次迭代可以单独优化或并行化处理。

调试名称和标签

在调试代码时,我们有时需要追踪特定的模式(相当于代码中的模板或者规则)。为了方便,我们可以给这些模式起一个调试名称(类似于给每个模式贴一个标签),这样在查看调试信息时,就能很容易地知道是哪个模式在起作用。此外,我们还可以给一组模式起一个共同的标签,这样可以方便地对这组模式进行过滤和分类。
假设我们有一个模式叫做 MyPattern,它是我们定义的一种重写规则。我们可以给它设置一个调试名称和标签:

class MyPattern : public RewritePattern {
public:
  using RewritePattern::RewritePattern;

  void initialize() {
    setDebugName("MyPattern");
    addDebugLabels("MyRewritePass");
  }
};

// 在某个地方,我们要把这些模式添加到一个集合中,并给它们设置一个公共标签:
void populateMyPatterns(RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.addWithLabel<MyPattern>("MyRewritePatterns", ctx);
}

初始化

有些模式在使用前需要进行特殊的初始化,比如如果一个模式会递归调用自身,那么我们需要明确地标记它可以处理这种递归。这种初始化可以在模式的构造函数中完成,也可以通过一个专门的初始化方法来完成。
仍然以 MyPattern 为例,如果这个模式需要处理递归调用,我们可以这样做:

class MyPattern : public RewritePattern {
public:
  using RewritePattern::RewritePattern;

  void initialize() {
    setHasBoundedRewriteRecursion();
  }
};

构造

为了确保模式在创建后被正确初始化并且可以正常使用,我们建议使用一种标准的创建方法。这种方法确保所有需要的初始化都已经完成。
假设我们需要创建一个 MyPattern 的实例并添加到模式集合中,我们可以这样做:

void populateMyPatterns(RewritePatternSet &patterns, MLIRContext *ctx) {
  // 使用 create<T> 方法来创建并初始化模式
  auto myPattern = RewritePattern::create<MyPattern>(ctx);
  patterns.add(std::move(myPattern));
}

Pattern Rewriter (模式重写器)

PatternRewriter 是一个特殊的类,允许模式(pattern)与模式应用的驱动程序进行通信。所有对中间表示(IR)的更改,包括创建,必须通过PatternRewriter类进行。这是因为底层的模式驱动程序可能有状态,如果直接进行更改会使这些状态无效。

下面是一些常见的PatternRewriter API示例,请参考类文档以获取最新的API列表:

  • 擦除操作:eraseOp

这个方法用来删除没有结果或者其结果没有被使用的操作。

  • 通知匹配失败的原因:notifyMatchFailure

这个方法允许在matchAndRewrite方法中提供一个诊断消息,说明为什么一个模式匹配失败。如何显示这个消息取决于具体的模式驱动程序。

  • 替换操作:replaceOp/replaceOpWithNewOp

这个方法用提供的一组值替换一个操作的结果,并擦除该操作。

  • 原地更新操作:(start|cancel|finalize)OpModification

这是一组方法,提供类似事务的API,用于在模式中原地更新操作的属性、位置、操作数或后继者。更新事务通过startOpModification开始,可以用cancelOpModification取消或用finalizeOpModification完成。一个方便的封装modifyOpInPlace可以在回调周围自动包裹开始和完成

模式应用

我们定义了一些优化或转换模式,然后将这些模式应用到某个程序或数据结构上,以优化其性能或改变其结构。

  • RewritePatternSet:这是一个用来存储所有模式的集合,就像一个模式的清单。

  • PatternRewriter:这是一个工具,用来实际执行模式中的变更。为了确保在执行变更时不会破坏系统的状态,我们需要定制这个工具。

  • PatternApplicator:这是一个负责实际应用模式的类。它使用一个成本模型来决定哪些模式最值得应用,并按照这个模型来应用模式。

  • 成本模型:这是一个用来评估每个模式收益的算法,帮助我们决定应该优先应用哪个模式。

以下是一个简单的MLIR示例,展示如何定义和应用一个模式来优化一个操作:

Step 1: 定义一个简单的MLIR操作

module {
  func @simple_op(%arg0: i32) -> i32 {
    %0 = "my_dialect.my_op"(%arg0) : (i32) -> i32
    return %0 : i32
  }
}

Step 2: 定义一个优化模式

我们将定义一个模式来将这个操作优化为另一个操作,例如将 my_dialect.my_op 转换为 my_dialect.optimized_op。

class MyPattern : public mlir::RewritePattern {
public:
  MyPattern(mlir::MLIRContext *context)
      : RewritePattern("my_dialect.my_op", 1, context) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {
    // 检查操作是否为目标操作
    if (op->getName().getStringRef() != "my_dialect.my_op")
      return mlir::failure();

    // 创建一个新的操作来替换旧的操作
    rewriter.replaceOpWithNewOp<mlir::Operation>(op, "my_dialect.optimized_op",
                                                 op->getResultTypes(), op->getOperands());
    return mlir::success();
  }
};

Step 3: 收集模式并应用

void applyMyPatternDriver(mlir::Operation *op, mlir::MLIRContext *context) {
  mlir::RewritePatternSet patterns(context);
  patterns.add<MyPattern>(context);

  mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns));

  mlir::PatternApplicator applicator(frozenPatterns);

  // 应用默认的成本模型
  applicator.applyDefaultCostModel();

  mlir::PatternRewriter rewriter(context);

  // 匹配并应用模式
  mlir::LogicalResult result = applicator.matchAndRewrite(op, rewriter);
  if (failed(result)) {
    // 没有应用任何模式
    llvm::errs() << "No patterns were applied.\\n";
  } else {
    // 成功应用了一个模式
    llvm::errs() << "A pattern was successfully applied.\\n";
  }
}

Dialect Conversion Driver(方言转换驱动器):

  • 该驱动器提供了一个框架,用于在方言之间及方言内部进行操作转换。使用“合法性”的概念,将不合法的操作转换为目标方言支持的操作。
  • 还支持类型转换。

Greedy Pattern Rewrite Driver(贪婪模式重写驱动器):

  • 该驱动器以工作列表的方式处理操作,并贪婪地应用在本地最有益的模式。
    模式的益处由模式自身的益处和模式列表中的相对顺序决定。
    该驱动器有两种形式:
  • Region-based driver(基于区域的驱动器):应用模式到指定区域内的所有操作。
  • Op-based driver(基于操作的驱动器):应用模式到指定的操作列表。
    驱动器通过GreedyRewriteConfig进行配置,可以选择自顶向下或自底向上的遍历方式。
  • 7
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值