torch.roll 函数的理解:
看这个函数的大多应该是从swin transform来的吧,废话不多说:
https://blog.csdn.net/weixin_42899627/article/details/116095067
我觉得直接看这个博客就很好了,不过还是总结一下吧:
torch.roll(input, shifts, dims=None)
torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
结果就是:
tensor([[6, 5],
[8, 7],
[2, 1],
[4, 3]])
dims=(0, 1) 这个意思就是第0行 1是第一列 然后去和shifts=(2, 1)对应一下, 就是第0行的下移2 第一列的右移动1 就和挤牙膏一样,首先[1, 2] 被动的从0行下移动到 [5, 6]的这个位置,剩下的 1 2 , 3 4,7 8, 都往下移动,然后在加上第一列的往右移动1,最后变成:
tensor([******,
[3, 4],
[1,2],
[7, 8]])
tensor([[6, 5],
[8, 7],
[2, 1],
[4, 3]])
在swin源码中:
reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
不得不说代码真难懂,数学能力真的好强啊、。。