视频链接: 16、PyTorch中进行卷积残差模块算子融合_哔哩哔哩_bilibili
Conv2D官方API:Conv2d — PyTorch 2.0 documentation
原始论文:RepVGG: Making VGG-style ConvNets Great Again (arxiv.org)
理论
论文中提及如何将一个训练时的多分支模块转换为单一的 $3\times 3$ 卷积,从而达到加速的目的。如下图所示:
代码验证
视频介绍的代码中并没有考虑BN层。
原生写法
对应上图中(A)的第1幅小图:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
# Native (unfused) residual block: 3x3 conv + 1x1 conv + identity shortcut.
in_c = 2
out_c = 2
k = 3  # kernel_size
w = h = 9
x = torch.ones(1, in_c, h, w)  # input image [batch_size, channels, h, w]

# in_channels must equal out_channels; otherwise the branch outputs and the
# input have different shapes and cannot be added together.
conv_2d = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_pointwise = nn.Conv2d(in_c, out_c, 1)
result1 = conv_2d(x) + conv_2d_pointwise(x) + x

# Parameter shapes of the 3x3 conv layer (the original printed the 1x1 layer
# here by mistake, which contradicted the documented output).
print(conv_2d.weight.size())
print(conv_2d.bias.size())
# Parameter shapes of the 1x1 (point-wise) conv layer.
print(conv_2d_pointwise.weight.size())
print(conv_2d_pointwise.bias.size())
参数维度的输出结果如下:
torch.Size([2, 2, 3, 3]) # 3*3 conv weight [out_c, in_c, k, k]
torch.Size([2]) # 3*3 conv bias [out_c]
torch.Size([2, 2, 1, 1]) # 1*1 conv
torch.Size([2])
算法融合
1.改造
代码中需要使用 torch.nn.functional.pad,查看官网上的例子就差不多明白了。
pad
官网API:torch.nn.functional.pad — PyTorch 2.0 documentation
对应上图中(A)的第2幅小图,卷积核参数对应上图(B)的第2幅小图。
首先是将 $1\times 1$ 的卷积变成 $3\times 3$ 卷积:
# Turn the 1x1 convolution into an equivalent 3x3 convolution.
# conv_2d_pointwise.weight has size [2, 2, 1, 1]; a 3x3 kernel needs
# [2, 2, 3, 3]. Zero-padding the last two (h, w) dims by one on each side
# leaves the original weight at the centre of the 3x3 window.
conv_2d_for_pointwise = nn.Conv2d(
    in_channels=in_c, out_channels=out_c, kernel_size=k, padding="same"
)
pointwise_to_conv_weight = F.pad(
    conv_2d_pointwise.weight,
    (1, 1, 1, 1),  # last dim padded by (1, 1), second-to-last by (1, 1)
    mode="constant",
    value=0.0,
)
# Overwrite the freshly initialised parameters with the converted ones.
conv_2d_for_pointwise.weight = nn.Parameter(pointwise_to_conv_weight)
conv_2d_for_pointwise.bias = conv_2d_pointwise.bias
接着是将恒等映射转变成 $3\times 3$ 的卷积。恒等映射的输出只取决于同一通道、同一位置的输入,所以我们不需要考虑相邻像素点以及通道之间的关联性。因此,对于这样一个卷积层,它的 weight.size() 首先肯定是 2*2*3*3 的大小。
# Building blocks of the identity kernel: neighbouring pixels and other
# channels must not contribute to the output.
# All-zero k x k slice: this input channel is ignored entirely.
zeros = torch.zeros(k, k).unsqueeze(0)  # [1, 3, 3]
# Single 1 at the centre of a 3x3 slice: only the pixel itself is passed on.
stars = F.pad(torch.ones(1, 1), (1, 1, 1, 1)).unsqueeze(0)  # [1, 3, 3]
stars和zeros的效果如下图:
# One kernel per output channel: concatenate the per-input-channel slices
# along dim 0, then add a leading dim -> [1, 2, 3, 3].
stars_zeros = torch.cat((stars, zeros), dim=0).unsqueeze(0)  # output channel 0 copies input channel 0
zeros_stars = torch.cat((zeros, stars), dim=0).unsqueeze(0)  # output channel 1 copies input channel 1
# Stack both output-channel kernels into the full weight -> [2, 2, 3, 3].
identity_to_conv_weight = torch.cat((stars_zeros, zeros_stars), dim=0)
identity_to_conv_bias = torch.zeros(out_c)

conv_2d_for_identity = nn.Conv2d(
    in_channels=in_c, out_channels=out_c, kernel_size=k, padding="same"
)
conv_2d_for_identity.weight = nn.Parameter(identity_to_conv_weight)
conv_2d_for_identity.bias = nn.Parameter(identity_to_conv_bias)

# All three branches are now 3x3 convolutions.
result2 = conv_2d(x) + conv_2d_for_pointwise(x) + conv_2d_for_identity(x)
# Element-wise closeness check against the native result, reduced to one flag.
print(torch.isclose(result1, result2).all())
因为是两个浮点矩阵,不能直接用 torch.equal 去比较,只能通过 torch.isclose 方法比较 result1 和 result2。由于 torch.isclose 返回的是逐元素比较的布尔张量,再加 torch.all 将它们统一判断一下,输出结果如下:
tensor(True)
2.融合
最后将这 3 个 $3\times 3$ 的卷积融合起来。对应上图中(A)的第3幅小图,根据(B)中的参数示意图,将所有卷积层的权重和偏置各自相加。
# Fuse the three 3x3 branches into a single convolution: because convolution
# is linear, adding the kernels and adding the biases gives the same output
# as adding the three branch outputs.
fused_weight = conv_2d.weight.data + conv_2d_for_pointwise.weight.data
fused_weight = fused_weight + conv_2d_for_identity.weight.data
fused_bias = conv_2d.bias.data + conv_2d_for_pointwise.bias.data
fused_bias = fused_bias + conv_2d_for_identity.bias.data

conv_2d_for_fusion = nn.Conv2d(
    in_channels=in_c, out_channels=out_c, kernel_size=k, padding="same"
)
conv_2d_for_fusion.weight = nn.Parameter(fused_weight)
conv_2d_for_fusion.bias = nn.Parameter(fused_bias)

result3 = conv_2d_for_fusion(x)
# The fused layer must reproduce the three-branch result.
print(torch.isclose(result3, result2).all())
判断result3和result2是否相等,输出结果如下:
tensor(True)
对比耗时
导入 time 库,用它的计时函数(如 time.time())来计算不同方法之间的耗时。全部代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
def pad_kernel_to(weight, kernel_size):
    """Zero-pad the two spatial dims of a conv weight to kernel_size x kernel_size.

    Args:
        weight: tensor of shape [out_c, in_c, kh, kw].
        kernel_size: target spatial size; ``kernel_size - kh`` and
            ``kernel_size - kw`` must be even so the original kernel stays
            centred (e.g. 1x1 -> 3x3).

    Returns:
        Tensor of shape [out_c, in_c, kernel_size, kernel_size].
    """
    pad_h = (kernel_size - weight.shape[-2]) // 2
    pad_w = (kernel_size - weight.shape[-1]) // 2
    # F.pad lists pads starting from the last dim: (left, right, top, bottom).
    return F.pad(weight, [pad_w, pad_w, pad_h, pad_h])


def identity_as_conv_weight(channels, kernel_size):
    """Return the conv weight that implements the identity mapping.

    The identity is a 1x1 convolution whose weight is the identity matrix over
    channels (output channel i copies input channel i only); padding it with
    zeros gives the equivalent kernel_size x kernel_size kernel. Works for any
    channel count, not just 2.
    """
    eye = torch.eye(channels).reshape(channels, channels, 1, 1)
    return pad_kernel_to(eye, kernel_size)


def mean_runtime(fn, repeats=50):
    """Return the average wall-clock seconds of one ``fn()`` call.

    One untimed warm-up call absorbs one-off start-up costs, and the result is
    averaged over ``repeats`` calls so a single noisy measurement does not
    decide the comparison. Runs under ``torch.no_grad()`` so only the forward
    computation is measured.
    """
    with torch.no_grad():
        fn()  # warm-up, not timed
        start = time.perf_counter()
        for _ in range(repeats):
            fn()
        return (time.perf_counter() - start) / repeats


in_c = 2
out_c = 2  # must equal in_c, otherwise the shortcut cannot be added
k = 3  # kernel_size
w = h = 9
x = torch.ones(1, in_c, h, w)  # input image [batch_size, channels, h, w]

# res_block = 3*3 conv + 1*1 conv + input
# Method 1: native multi-branch implementation.
conv_2d = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_pointwise = nn.Conv2d(in_c, out_c, 1)
result1 = conv_2d(x) + conv_2d_pointwise(x) + x
# Time only the forward pass; layer construction is not part of the operator.
native_time = mean_runtime(lambda: conv_2d(x) + conv_2d_pointwise(x) + x)
print(conv_2d.weight.size())
print(conv_2d.bias.size())
print(conv_2d_pointwise.weight.size())
print(conv_2d_pointwise.bias.size())
print(native_time)

# Method 2: operator fusion.
# Rewrite both the point-wise conv and the identity shortcut as 3*3 convs,
# then merge the three convs into one.
# 1) Conversion
pointwise_to_conv_weight = pad_kernel_to(conv_2d_pointwise.weight.detach(), k)
conv_2d_for_pointwise = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_for_pointwise.weight = nn.Parameter(pointwise_to_conv_weight)
# Copy the bias so the two layers do not share one Parameter object.
conv_2d_for_pointwise.bias = nn.Parameter(conv_2d_pointwise.bias.detach().clone())

# The identity ignores neighbouring pixels and other channels, so its kernel
# is a single 1 at the centre of the matching channel and 0 everywhere else.
identity_to_conv_weight = identity_as_conv_weight(in_c, k)
identity_to_conv_bias = torch.zeros([out_c])
conv_2d_for_identity = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_for_identity.weight = nn.Parameter(identity_to_conv_weight)
conv_2d_for_identity.bias = nn.Parameter(identity_to_conv_bias)

result2 = conv_2d(x) + conv_2d_for_pointwise(x) + conv_2d_for_identity(x)
print(torch.all(torch.isclose(result1, result2)))  # same as the native result?

# 2) Fusion: convolution is linear, so summing the kernels and the biases is
# equivalent to summing the branch outputs.
conv_2d_for_fusion = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_for_fusion.weight = nn.Parameter(
    conv_2d.weight.data
    + conv_2d_for_pointwise.weight.data
    + conv_2d_for_identity.weight.data
)
conv_2d_for_fusion.bias = nn.Parameter(
    conv_2d.bias.data
    + conv_2d_for_pointwise.bias.data
    + conv_2d_for_identity.bias.data
)
result3 = conv_2d_for_fusion(x)
fused_time = mean_runtime(lambda: conv_2d_for_fusion(x))
print("原生写法耗时: ", native_time, "\n算子融合写法耗时: ", fused_time)
print(torch.all(torch.isclose(result3, result2)))  # fused == three-branch?
输出结果为:
原生写法耗时: 0.0029582977294921875
算子融合写法耗时: 0.0009975433349609375
可以看到,在这次运行中使用算子融合确实耗时更少。需要注意,上面的数值只来自一次运行,会随机器和运行环境波动,仅供参考;多次运行取平均会更可靠。