TorchScript error: "Cannot input a tensor of dimension other than 0 as a scalar argument"
Offending code (错误代码):
# The statement the error above points at; the same line appears inside
# channel_shuffle below. TorchScript types unannotated parameters as Tensor,
# so `groups` is presumably a Tensor here. The int // Tensor division then
# needs an implicit Tensor-to-int conversion, which raises the error above
# when that tensor is not 0-dim. Annotating `groups: int` removes the
# conversion.
channels_per_group = num_channels // groups
@torch.jit.script_method
def channel_shuffle(self, x, groups):
batchsize, num_channels, height, width = x.size()
# assert (num_channels % groups == 0)
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups, channels_per_group, height, width)
# transpose
# - contiguous() required if transpose() is used before view().
# See https://github.com/pytorch/pytorch/issues/764
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x