pytorch中对Tensor进行分块的操作: split()
tensor.split(selection_or_int,dim),返回的是一个列表,第一个参数是Int, 则按照参数的值在dim上均分(最后一组保留)。如果是List, 则List中包含了每块的大小。
reshape_bottleneck = torch.randn((8, 8, 16, 26, 26))
t_fea_forward, _ = reshape_bottleneck.split([8 - 1, 1], dim=1) # n, t-1, c//r, h, w
diff= reshape_bottleneck[:, :-1] - t_fea_forward
print(diff)