RepVGG论文阅读及pytorch实现解析
论文地址:https://paperswithcode.com/method/repvgg
1. 提出背景
- 以ResNet为代表的具有分支的网络虽然有利于网络训练但会降低内存使用效率和减慢模型推理速度。
如上图(A),以残差连接为例,其模块输入得保留直至cat操作完成,之间占用了两份内存空间,如果换成直通的架构,只需保留一份数据,节省了部分内存。
2. 核心创新点
灵感来源:如何保留多分支结构训练时的优点,同时改进其推理时的缺点?能不能在训练时使用多分枝结构,推理时采用等价的直通结构,实现训练与推理的解耦。关键在于等价转换。
2.1 整体结构
如上图A所示,借鉴ResNet的思路,分支结构如shortcut数学上可以看作 y = x + f(x),如果f(x)什么也没有学到,那么 y = x表明输入等于输出,这就保证了ResNet深模型的特征空间一定包括浅模型,一般来说能得到更好的结果。RepVGG训练时如上图B使用 3 x 3 , 1 x 1 , indentity三个分支来组成一个block, y = x + f(x) + g(x)。RepVGG推理时如上图C使用只含有3 x 3 (硬件优化好)+ relu的模块。关键在于等价转换。
2.2 等价转换 重参数化原理
https://blog.csdn.net/qq_33532713/article/details/108602463
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DORaHh6W-1667960933640)(C:\Users\Admin\AppData\Roaming\Typora\typora-user-images\image-20221028182855318.png)]
-
第一步:conv + bn —> fusedConv
C o n v ( X ) = W × X Conv(X) = W \times X Conv(X)=W×XB N ( X ) = γ × X − μ σ + β BN(X) = \gamma \times \frac{X-\mu}{\sigma} +\beta BN(X)=γ×σX−μ+β
f u s e d C o n v ( X ) = B N ( C o n v ( X ) ) = γ × W × X − μ σ + β = γ σ × W × X + ( − μ γ σ + β ) = W