小白学Pytorch系列–Torch.nn API Shuffle Layers(16)
方法 | 注释 |
---|
nn.ChannelShuffle | 将形状为
(
∗
,
C
,
H
,
W
)
(*,C,H,W)
(∗,C,H,W)的张量中的通道划分为g组,并将它们重新排列为
(
∗
,
C
g
,
g
,
H
,
W
)
(*,C^g,g,H,W)
(∗,Cg,g,H,W),同时保持原始张量形状。 |
nn.ChannelShuffle
>>> channel_shuffle = nn.ChannelShuffle(2)
>>> input = torch.randn(1, 4, 2, 2)
>>> print(input)
[[[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]],
[[9, 10],
[11, 12]],
[[13, 14],
[15, 16]],
]]
>>> output = channel_shuffle(input)
>>> print(output)
[[[[1, 2],
[3, 4]],
[[9, 10],
[11, 12]],
[[5, 6],
[7, 8]],
[[13, 14],
[15, 16]],
]]