1. nn.Conv2d
Here is my personal understanding. In torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias), Conv2d first initializes the corresponding weight and bias tensors from the configured in_channels and out_channels. But look closely at the function that is actually called at convolution time: it runs the conv2d computation on weight and input directly, so the dimension alignment that really matters is determined by weight and bias.
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    if self.padding_mode != 'zeros':
        # Non-zero padding modes are applied with F.pad first; conv2d then runs unpadded
        return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                        weight, bias, self.stride,
                        _pair(0), self.dilation, self.groups)
    return F.conv2d(input, weight, bias, self.stride,
                    self.padding, self.dilation, self.groups)

def forward(self, input: Tensor) -> Tensor:
    # The actual computation is F.conv2d over (input, self.weight, self.bias)
    return self._conv_forward(input, self.weight, self.bias)
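To see that the weight tensor alone fixes the channel alignment, you can call F.conv2d directly with a hand-built weight, bypassing nn.Conv2d entirely. A minimal sketch (shapes chosen arbitrarily for illustration):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)         # N=1, C_in=3
weight = torch.randn(16, 3, 3, 3)   # (out_channels, in_channels, kH, kW)
bias = torch.randn(16)              # one bias per output channel

# F.conv2d infers the channel layout entirely from weight's shape
y = F.conv2d(x, weight, bias, stride=1, padding=1)
print(y.shape)                      # torch.Size([1, 16, 8, 8])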
2. Adding and removing modules inside nn.Module
This part comes up mainly after the collapse operation; it is used to delete and replace structures inside a model. self.__delattr__ is simply the standard way to delete an attribute from within a class, and nn.Module overrides it so that deleting a submodule attribute also unregisters that submodule's parameters from the module.
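A tiny sketch of that mechanism on a bare nn.Module, before the full block below (the submodule name conv is just for illustration):

import torch.nn as nn

m = nn.Module()
m.conv = nn.Conv2d(3, 3, 3)          # registered into m._modules by __setattr__
print(list(m.state_dict().keys()))   # ['conv.weight', 'conv.bias']

m.__delattr__('conv')                # nn.Module.__delattr__ unregisters it
print(list(m.state_dict().keys()))   # []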
class RBRepSR_three(nn.Module):
    """Residual block without BN and without ReLU.

    It has a style of:
        -----Conv3----Conv3----Conv3-----+-
          |______________________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False, deploy_flag=False):
        super(RBRepSR_three, self).__init__()
        self.res_scale = res_scale
        self.deploy = deploy_flag
        self.num_feat = num_feat
        if self.deploy:
            # Deploy mode: the whole block collapses into a single 7x7 convolution
            self.rbr_reparam = nn.Conv2d(num_feat, num_feat, 7, 1, 7 // 2, bias=True)
        else:
            self.padding = nn.ZeroPad2d(3)
            self.conv3x3_1 = nn.Conv2d(num_feat, num_feat, 3, 1, 0, bias=True)
            self.conv3x3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 0, bias=True)
            self.conv3x3_3 = nn.Conv2d(num_feat, num_feat, 3, 1, 0, bias=True)
            # self.relu = nn.ReLU(inplace=True)
            if not pytorch_init:
                default_init_weights([self.conv3x3_1, self.conv3x3_2, self.conv3x3_3], 0.1)

    def forward(self, x):
        if hasattr(self, 'rbr_reparam'):
            return self.rbr_reparam(x)
        identity = x
        out = self.conv3x3_3(self.conv3x3_2(self.conv3x3_1(self.padding(x))))
        return identity + out * self.res_scale

    def get_equivalent_kernel_bias(self):
        # Fuse the three stacked 3x3 convs into one conv (3x3 -> 5x5 -> 7x7 receptive field)
        temp = reparameter_33(self.conv3x3_1, self.conv3x3_2)
        fused = reparameter_33(temp, self.conv3x3_3)
        # The residual identity is an extra 7x7 kernel with a 1 at the center of channel i -> i
        kernel_identity = torch.zeros((self.num_feat, self.num_feat, 7, 7))
        for i in range(self.num_feat):
            kernel_identity[i, i, 3, 3] = 1
        self.rbr_reparam = nn.Conv2d(self.num_feat, self.num_feat, 7, 1, 7 // 2, bias=True)
        self.rbr_reparam.weight.data = fused.weight.data * self.res_scale + kernel_identity
        self.rbr_reparam.bias.data = fused.bias.data * self.res_scale

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        self.get_equivalent_kernel_bias()
        for para in self.parameters():
            para.detach_()
        # nn.Module.__delattr__ also drops the deleted submodules' parameters
        self.__delattr__('padding')
        self.__delattr__('conv3x3_1')
        self.__delattr__('conv3x3_2')
        self.__delattr__('conv3x3_3')
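The helper reparameter_33 is not shown in the snippet. Under this block's setting (stride 1, padding 0, bias=True on both convs), fusing two stacked convolutions y = conv_b(conv_a(x)) amounts to convolving the two kernels with each other. A minimal sketch of such a helper; the implementation details below are my own assumption, not the original code:

import torch
import torch.nn as nn
import torch.nn.functional as F

def reparameter_33(conv_a, conv_b):
    # Sketch: fuse y = conv_b(conv_a(x)), both stride 1, padding 0, bias=True.
    k1, b1 = conv_a.weight.data, conv_a.bias.data   # (D, C, k1, k1), (D,)
    k2, b2 = conv_b.weight.data, conv_b.bias.data   # (E, D, k2, k2), (E,)
    k2_size = k2.shape[-1]
    # Composing two cross-correlations is a true convolution of the kernels,
    # hence the spatial flip of k2; the merged kernel has size k1 + k2 - 1.
    merged_w = F.conv2d(k1.permute(1, 0, 2, 3), k2.flip(2, 3),
                        padding=k2_size - 1).permute(1, 0, 2, 3)  # (E, C, k1+k2-1, k1+k2-1)
    # conv_a's bias passes through conv_b as a constant per-channel input
    merged_b = (k2 * b1.reshape(1, -1, 1, 1)).sum(dim=(1, 2, 3)) + b2
    fused = nn.Conv2d(merged_w.shape[1], merged_w.shape[0],
                      merged_w.shape[-1], 1, 0, bias=True)
    fused.weight.data = merged_w
    fused.bias.data = merged_b
    return fused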
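With a fusion helper in scope, the training/deploy equivalence is easy to sanity-check numerically. A sketch, using pytorch_init=True so that default_init_weights is not needed:

block = RBRepSR_three(num_feat=8, pytorch_init=True)
block.eval()

x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    y_train = block(x)
    block.switch_to_deploy()     # collapses the block into the single rbr_reparam conv
    y_deploy = block(x)

# The three convs and the padding layer are gone from the module,
# and the two outputs match up to floating-point error.
print(hasattr(block, 'conv3x3_1'))           # False
print((y_train - y_deploy).abs().max())      # ~1e-6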