本篇文章讲讲resnet50静态图优化加速的手段之一–算子融合,因为算子融合知识非常庞杂,故本文主要从resnet50中Pad和conv2D的算子融合切入,管中窥豹,理解算子融合在静态图层面的实现,读者事先需要了解resnet50的结构。本文首发于我的公众号“AI不止算法”,文章链接在此
背景
之前文章讲了一些算子优化,算子优化只是AI推理优化的手段之一,图优化也是非常重要的手段。
众所周知,TF1.x是一个性能强劲,易于部署的AI框架,这得益于静态图机制,在2020年前是最流行的AI框架,而PyTorch 1.x是一个性能偏弱,部署困难,但是对炼丹人们非常友好易于debug的AI框架,这得益于动态图机制。本文将基于TF python API讲解静态图优化之Pad+Conv2d的算子融合,以此折射出静态图优化的思想,所讲内容或许在PyTorch 2.x出世之后有点过时,但是依然具有学习价值,以前很多时候都会手写这种图优化策略,代表了静态图的常用图优化机制和思想,对于所有静态图优化都是受用的。
ResNet50静态图简览
ResNet50 Pad+Conv2d部分如下两张图所示:清晰描述了Pad节点和conv2d节点的在ResNet50图中的输入输出、名字和各属性(特别注意图1 pad节点中paddings这个输入节点,这是后续图优化处理的重点输入节点)。
图1
图2
注意:为了降低理解门槛,本文使用TensorFlow Python API展示resnet50静态图优化手段之Pad+Conv2d的融合,融合结果为Pad + conv2d => conv2d,即把Pad从图中去掉,前提是推理结果依然正确,由于写一个完成的融合策略,代码非常长,本文截取的代码仅仅是上层代码,下层具体的实现没有暴露出来,比如DFS遍历静态图解析图上所有节点信息,这个时候leetcode算法派上了用场,如果你刷过DFS以及学校里学过“图论”这门课程,我想理解起来很容易。
Pad+Conv2d的图优化融合代码
具体见注释即可
from tensorflow.python.framework import tensor_util