TVM Relay Pass探究

本文探讨TVM中的 Relay Pass,关键组件Relay用于深度学习模型转换和优化。Pass分为TIR和Relay层,涉及模块、函数和顺序级别优化。通过PassContext配置执行策略,如FuseOps、FoldConstant等。文章介绍了Pass的添加方法,包括Python装饰器和C++实现,并详细解析了Pass的底层调用逻辑,帮助读者理解Pass的工作原理。
摘要由CSDN通过智能技术生成

引言

 

Relay 是 TVM 中十分重要的基础组件之一,用于对接不同格式的深度学习模型以及进行模型的 transform。深度学习编译器的核心功能就是进行各种各样的 transform 变换,这个变换过程部分是由 Pass 来实现。当需要遍历计算图时,底层究竟是如何执行的?本文打算一探究竟。


1. 简介

Pass 两层设计:

  • TIR 层,基于 target 的优化,主要涉及 lower 到 target 时采用的优化策略,包括:VectorizeLoop、UnrollLoop、RemoveNoOp、SkipAssert、ThreadSync 等;此部分 Pass 有时可以直接复用底层编译器的 pass,如 LLVM/CUDA C 等编译器。TVM 主要关注和 ML 相关且底层编译器未考虑到的场景。

  • Relay 层:基于 计算图 的优化,主要通过对 AST 分析,进行 node 的修改来实现。

Pass 功能上可分为三类:

  • module level:tvm.transform.ModulePass,利用全局信息进行优化,可以增加或删除 module 内的 function;

    • 例如 FlattenNestedTuples, RemoveUnusedFunctions, PartitionGraph, InferType, dead code elimination, A-normal form conversion, lambda lifting;

  • function level:tvm.relay.transform.FunctionPasstvm.tir.transform.PrimFuncPass,对 IRModule 内的单个或多个 function 进行改写,TVM 中绝大部分 Pass 都是这类;

    • 例如 comm subexpression elimination, vectorizition;

  • sequential level:tvm.transform.Sequential,是一个 container,可以装载多个 Pass,顺序执行;可以认为是前两个的一个封装而已。

TVM 中有较多的 Pass,运行我们在调用时可以创建一个 PassContext 上下文环境,调用优先级:disabled_pass > required_pass > opt_level。

首先检查该 Pass 是否被用户 disable,然后检查该 pass 是 required,最后检查 Pass 的 opt_level 是否低于 pass context 中的 opt_level。如上均满足条件后,该 pass 即为 enabled。对应代码如下:

bool PassContext::PassEnabled(const PassInfo& info) const {
  if (PassArrayContains(operator->()->disabled_pass, info->name)) {
    return false;
  }

  if (PassArrayContains(operator->()->required_pass, info->name)) {
    return true;
  }

  return operator->()->opt_level >= info->opt_level;
}

常用 Pass 的 opt_level见下面列表,可以看到 FuseOps 作为推理性能强相关的 Pass,其优先级默认设置为了最高(0,数字越小,优先级越高),FoldConstant 常量折叠的优先级也被设置为了 2。更为激进的性能优化 Pass 如 CombineParallelConv2d、FastMath、DenseToSparse 和 Conv2dToSparse2 等都被设置为了 4 和 5。由于 TVM demo 中大部分都是设置 opt_level=3,上面提到的更为激进的性能优化 Pass 并没有 enable。因此,我们可以设置更高的 opt_level,同时在 required_pass 参数列表中加入所需 pass,则可以进一步提升模型的推理性能哦。10 行代码改动,性能提升 10% 不是梦。

  • FuseOps:0

  • DeadCodeElimination:1

  • FoldConstant:2

  • ConvertLayout:3

  • EliminateCommonSubexpr:3

  • CombineParallelConv2d:4<

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Linux基金会AI&Data基金会

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值