参考:https://blog.csdn.net/YXD0514/article/details/132466512
- RepConv简介
RepConv是一种模型重参化技术,它可以在推理阶段将多个计算模块合并为一个,提高模型的效率和性能。它最初是用于VGG网络的,但后来也被应用到其他网络结构,如ResNet和DenseNet。RepConv的主要思想是在训练时使用多分支的卷积层,然后在推理时将分支的参数重参数化到主分支上,从而减少计算量和内存消耗。RepConv在目标检测等任务上取得了很好的效果。
import torch
import torch.nn as nn
class repconv3x3(nn.Module):
    """Re-parameterizable block: parallel 3x3 and 1x1 convolutions.

    Training-time ``forward`` sums the two branch outputs; ``forward_fuse``
    folds both branches into the single equivalent 3x3 convolution
    ``conv_fuse`` (the 1x1 kernel is zero-padded to 3x3 and added, the
    biases are summed) and runs that instead.
    """

    def __init__(self, c1, c2):
        super().__init__()
        # Two parallel branches over the same input.
        self.conv1 = nn.Conv2d(c1, c2, 3, 1, 1)
        self.conv2 = nn.Conv2d(c1, c2, 1, 1, 0)
        # Target convolution that receives the merged parameters.
        self.conv_fuse = nn.Conv2d(c1, c2, 3, 1, 1)

    def fuse_1x1conv_3x3conv(self, conv1, conv2):
        """Return (weight, bias) of the 3x3 conv equivalent to conv1 + conv2.

        conv1 is the 3x3 branch, conv2 the 1x1 branch.
        """
        # Zero-pad the 1x1 kernel to 3x3 so the kernels add element-wise.
        padded_1x1 = nn.functional.pad(conv2.weight, [1, 1, 1, 1])
        fused_weight = padded_1x1 + conv1.weight
        fused_bias = conv1.bias + conv2.bias
        return fused_weight, fused_bias

    def forward(self, x):
        # Multi-branch (training-time) path.
        return self.conv1(x) + self.conv2(x)

    def forward_fuse(self, x):
        # Re-parameterized (inference-time) path: load the merged
        # parameters into conv_fuse, then run the single convolution.
        weight, bias = self.fuse_1x1conv_3x3conv(self.conv1, self.conv2)
        self.conv_fuse.weight.data, self.conv_fuse.bias.data = weight, bias
        return self.conv_fuse(x)
# Sanity check: the fused single-conv path must match the multi-branch path.
inputs = torch.rand((1,1,3,3))
# Important: switch the model to eval mode before comparing the two paths.
model = repconv3x3(1,2).eval()
out1 = model.forward(inputs)
out2 = model.forward_fuse(inputs)
# Squared difference should be ~0 (only float rounding error).
print("difference:",((out2-out1)**2).sum().item())
difference: 2.4980018054066022e-15
import torch
import torch.nn as nn
class repconv3x3(nn.Module):
    """Re-parameterizable block: parallel 1x1, 1x3 and 3x1 convolutions.

    Training-time ``forward`` sums the three branch outputs; ``forward_fuse``
    zero-pads each branch kernel to 3x3, adds them (and the biases) into the
    single equivalent convolution ``conv_fuse``, and runs that instead.
    """

    def __init__(self, c1, c2):
        super().__init__()
        # All three branches are applied to the block input in forward(),
        # so all must take c1 input channels. (Bug fix: conv1x3/conv3x1
        # previously used c2 in-channels, which only worked when c1 == c2.)
        self.conv1x1 = nn.Conv2d(c1, c2, 1, 1, 0)
        self.conv1x3 = nn.Conv2d(c1, c2, (1, 3), 1, (0, 1))
        self.conv3x1 = nn.Conv2d(c1, c2, (3, 1), 1, (1, 0))
        # Target convolution that receives the merged parameters.
        self.conv_fuse = nn.Conv2d(c1, c2, 3, 1, 1)

    def fuse_1x1conv_1x3conv_3x1conv(self, conv1, conv2, conv3):
        """Return (weight, bias) of the 3x3 conv equivalent to the sum of
        conv1 (1x1), conv2 (1x3) and conv3 (3x1).

        F.pad's tuple is (left, right, top, bottom), i.e. last dim first:
        each kernel is zero-padded up to 3x3 before the element-wise add.
        """
        weight = (
            nn.functional.pad(conv1.weight.data, (1, 1, 1, 1))   # 1x1 -> 3x3
            + nn.functional.pad(conv2.weight.data, (0, 0, 1, 1))  # 1x3 -> 3x3
            + nn.functional.pad(conv3.weight.data, (1, 1, 0, 0))  # 3x1 -> 3x3
        )
        bias = conv1.bias.data + conv2.bias.data + conv3.bias.data
        return weight, bias

    def forward(self, x):
        # Multi-branch (training-time) path.
        return self.conv3x1(x) + self.conv1x3(x) + self.conv1x1(x)

    def forward_fuse(self, x):
        # Re-parameterized (inference-time) path.
        self.conv_fuse.weight.data, self.conv_fuse.bias.data = (
            self.fuse_1x1conv_1x3conv_3x1conv(self.conv1x1, self.conv1x3, self.conv3x1)
        )
        return self.conv_fuse(x)
# Sanity check: the fused single-conv path must match the three-branch path.
inputs = torch.rand((1,2,3,3))
# Important: switch the model to eval mode before comparing the two paths.
model = repconv3x3(2,2).eval()
out1 = model.forward(inputs)
out2 = model.forward_fuse(inputs)
# Squared difference should be ~0 (only float rounding error).
print("difference:",((out2-out1)**2).sum().item())
difference: 7.327471962526033e-15
- AvgPooling 转换 Conv
池化层是针对各个输入通道的(对单层特征图进行池化操作),而卷积层会将所有输入通道的结果相加。平均池化层可以等价一个固定权重的卷积层:假设池化核共有 K 个元素(例如 3×3 池化核对应 K=9),那么可以将卷积核权重全部设为 1/K。另外要注意的是,卷积层会将所有输入通道的结果相加,所以只对与输出通道对应的那个输入通道设置上述固定权重,对其他通道的权重设置为 0。
import torch
import torch.nn as nn
class repconv3x3(nn.Module):
    """Express a 3x3 average pooling as a fixed-weight 3x3 convolution.

    Pooling acts per channel, while a convolution sums over input channels,
    so the equivalent conv kernel is 1/(kh*kw) on the "identity" input
    channel and 0 on every other channel.
    """

    def __init__(self, c1):
        super().__init__()
        self.avg = nn.AvgPool2d(3, 1, 1)
        # Fixed-weight replacement conv; no bias needed for pooling.
        self.conv_fuse = nn.Conv2d(c1, c1, 3, 1, 1, bias=False)

    def fuse_avg(self):
        """Load conv_fuse with weights equivalent to the average pooling."""
        kh, kw = self.conv_fuse.kernel_size
        weight = torch.zeros_like(self.conv_fuse.weight)
        for channel in range(self.conv_fuse.in_channels):
            # Only the matching input channel contributes, uniformly.
            weight[channel, channel] = 1 / (kh * kw)
        self.conv_fuse.weight.data = weight

    def forward(self, x):
        # Reference (pooling) path.
        return self.avg(x)

    def forward_fuse(self, x):
        # Re-parameterized (convolution) path.
        self.fuse_avg()
        return self.conv_fuse(x)
# Sanity check: the fixed-weight conv must reproduce the average pooling.
inputs = torch.rand((1,2,3,3))
# Important: switch the model to eval mode before comparing the two paths.
model = repconv3x3(2).eval()
out1 = model.forward(inputs)
out2 = model.forward_fuse(inputs)
# Squared difference should be ~0 (only float rounding error).
print("difference:",((out2-out1)**2).sum().item())
difference: 1.0658141036401503e-14
2.RepConv优缺点
https://mp.weixin.qq.com/s?__biz=MzU1NjEwMTY0Mw==&mid=2247558125&idx=1&sn=4992c109ea00d4b87db8dcdc4404c02b&chksm=fbc99a89ccbe139f1bee2b622c5529c3374f2b92a74a57979c3bde9356b9215bde413201e1e0&scene=27