Shuffle Net V1源代码地址:GitHub - megvii-model/ShuffleNet-Series
代码中ShuffleNetV1/blocks.py中channel_shuffle源代码如下:
def channel_shuffle(self, x):
batchsize, num_channels, height, width = x.data.size()
assert num_channels % self.group == 0
group_channels = num_channels // self.group
x = x.reshape(batchsize, group_channels, self.group, height, width)
x = x.permute(0, 2, 1, 3, 4)
x = x.reshape(batchsize, num_channels, height, width)
return x
代码中第一个reshape应该改成一下代码:
x = x.reshape(batchsize, self.group,group_channels, height, width)
原因
上图是B站大佬同济子豪兄 在将ShuffleNetV1里面Channel Shuffle里面通道重排的过程,这里来给大家解释一下。比如说一个特征图大小为1*12*3*3大小,分别是bacth_size,channel,w,h。我们这里将特征图大小3*3等效看作1*1,然后就联系上图理解一下。这里是将特征图通过分组卷积进行处理,这里分为了3组。这样每组里就有4个数字,也就是这里的Reshape成3行4列矩阵。然后再进行转置,再通过一个reshape处理图中写的是Flatten,程序里是reshape又变回了原来的1*12*1*1这就是整个Channel Shuffle的流程,下面代码是另一个大佬写的,然后忘了地址了。
def channel_shuffle1(x, groups):
batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
print(channels_per_group)
# reshape
# b, c, h, w =======> b, g, c_per, h, w
x = x.view(batch_size, groups, channels_per_group, height,width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batch_size, -1, height, width)
return x
这里的groups取3就行了,这两个代码功能其实是一样的(上面修改过后的)。
来举几个极端例子说明一下,首先我们生成一个1*6*1*1的张量。
如上图所示,我们自己按照上面子豪兄的图自己算一下,在groups=3的情况下,应该是如下图这样。
下面是源代码输出结果
上面的结果是将上面6个数字,分成了2组,每组3个的结果,大家可以算一下。
下面是另一个大佬写的代码输出结果
与上面输出结果相符合。
ps:去网上看了一圈好像大家都没提出来,这个部分可能对网络性能没啥影响吧。这里就纠正一下,不然大家跟着源代码推理出来就对不上。