说明
\qquad
identity 分支就是直连分支,简单来说就是把输入原样复制到输出。在实际推理时,为了省掉这一操作、以减少访存等带来的时间消耗,可以将它与其他分支的卷积核进行融合。
\qquad
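融合的基础是:identity 可以等价地转化为 $1\times1$ 卷积核,也可以等价地转化为 $3\times3$ 卷积核。其中 $1\times1$ 卷积核的权重在(输出通道, 输入通道)上取单位阵;$3\times3$ 卷积核只在对应通道的中心位置取 1、其余为 0,并设置 padding=1 以保持尺寸不变;两者的偏置都为 0。具体见如下的 PyTorch 代码。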
代码
import torch
import torch.nn as nn
# Random test tensor in the (N, C, H, W) layout nn.Conv2d expects:
# one sample, 3 channels, 512x512 spatial size.
# NOTE(review): the name shadows the builtin `input`; kept because the
# conversion sections below refer to it by this name.
input = torch.randn((1, 3, 512, 512), dtype=torch.float32)
print(input.size())
"""
identity 等价转化为 1×1卷积
"""
# 构建 1×1的卷积核,保证输出和输入一致
conv1_1 = nn.Conv2d(3, 3, 1, 1)
conv_weight = torch.zeros(conv1_1.state_dict()['weight'].shape, dtype=torch.float32)
conv_bias = torch.zeros(conv1_1.state_dict()['bias'].shape, dtype=torch.float32)
num_filters = conv_weight.shape[0]
for idx_filters in range(num_filters):
conv_weight[idx_filters, idx_filters, :] = 1.0
# show conv kernels
print(conv_weight, conv_bias)
conv1_1.weight.data = conv_weight
conv1_1.bias.data = conv_bias
output = conv1_1(input)
print('------------- input------------------------------')
print(input[0, 0, :, :])
print('------------- result : identity to 1 × 1 convolution --------------')
print(output.shape)
print(output[0, 0, :, :])
"""
identity 等价转化为 3×3卷积
"""
# 构建 3×3 的卷积核,保证输出和输入一致
conv3_3 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
conv_weight = torch.zeros(conv3_3.weight.data.shape, dtype=torch.float32)
conv_bias = torch.zeros(conv3_3.bias.data.shape, dtype=torch.float32)
num_filters = conv_weight.shape[0]
for idx_filters in range(num_filters):
conv_weight[idx_filters, idx_filters, 1, 1] = 1.0
# show conv kernels
print(conv_weight, conv_bias)
conv3_3.weight.data = conv_weight
conv3_3.bias.data = conv_bias
output = conv3_3(input)
print('------------- input------------------------------')
print(input[0, 0, :, :])
print('------------- result : identity to 3 × 3 convolution --------------')
print(output.shape)
print(output[0, 0, :, :])