[pytorch]torch.roll函数

torch中的roll函数可以用于张量的位置变换操作。
博客推荐

import torch    
import numpy as np   
import matplotlib.pyplot as plt

shift_size = 3
'''构造多维张量'''
x=np.arange(18).reshape(1,3,3,2)
x=torch.from_numpy(x)
print(x)
if shift_size > 0:
    shifted_x = torch.roll(x, shifts=(1,-1), dims=(1, 2))
    #shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    print("---------经过循环位移了---------")
    print(shifted_x)
tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],

         [[ 6,  7],
          [ 8,  9],
          [10, 11]],

         [[12, 13],
          [14, 15],
          [16, 17]]]], dtype=torch.int32)
---------经过循环位移了---------
tensor([[[[14, 15],
          [16, 17],
          [12, 13]],

         [[ 2,  3],
          [ 4,  5],
          [ 0,  1]],

         [[ 8,  9],
          [10, 11],
          [ 6,  7]]]], dtype=torch.int32)

从上面的例子可以看出移动只发生在指定的维度[1,2]上,在swin-transformer中就是图片的高、宽(其形式为[B,H,W,C]) dim=1的维度表示图片的H,也就是一个shape为[2,3]的张量,shift=1表示这个维度上向下移动一个元素,多出的补到开头,而dim=2维度为图片的W,shift=-1表示这个维度上向上移动一个元素。

ps:dim=1固定一个数后(例如x[0][1]),就是[2,3]的张量,也就是可以看成固定了某一行后,这行有两列,可以理解为两个patch,每个patch由3个元素来表示。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值