torch中的roll函数可以用于张量的位置变换操作。
博客推荐
import torch
import numpy as np
import matplotlib.pyplot as plt
shift_size = 3
'''构造多维张量'''
x=np.arange(18).reshape(1,3,3,2)
x=torch.from_numpy(x)
print(x)
if shift_size > 0:
shifted_x = torch.roll(x, shifts=(1,-1), dims=(1, 2))
#shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
print("---------经过循环位移了---------")
print(shifted_x)
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]]]], dtype=torch.int32)
---------经过循环位移了---------
tensor([[[[14, 15],
[16, 17],
[12, 13]],
[[ 2, 3],
[ 4, 5],
[ 0, 1]],
[[ 8, 9],
[10, 11],
[ 6, 7]]]], dtype=torch.int32)
从上面的例子可以看出移动只发生在指定的维度[1,2]
上,在swin-transformer
中就是图片的高、宽(其形式为[B,H,W,C]
) dim=1
的维度表示图片的H,也就是一个shape为[2,3]
的张量,shift=1表示这个维度上向下移动一个元素,多出的补到开头,而dim=2
维度为图片的W,shift=-1表示这个维度上向上移动一个元素。