shuffle_channel
torch.nn.ChannelShuffle(groups)
将输入data的通道混洗重排,把所有通道分成group个组,并通过逐一从每个组中选择元素来组成新的顺序。
example:
input_data = torch.arange(0,16).view(1,4,2,2)
输入
tensor([[[[ 0, 1],
[ 2, 3]],
[[ 4, 5],
[ 6, 7]],
[[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15]]]])
channel_shuffle = nn.ChannelShuffle(2) #分成2组
output = channel_shuffle(input_data) #运行
结果
tensor([[[[ 0, 1],
[ 2, 3]],
[[ 8, 9],
[10, 11]],
[[ 4, 5],
[ 6, 7]],
[[12, 13],
[14, 15]]]])
依据这个功能,可以由几个op融合成shuffle_Channel
reshape1->transpose->reshape2
reshape1把channel分成对组
transpose组进行转置交换
reshape2组合并