Swin Transformer中torch.roll()详解

torch.roll()这个函数看官方解释很懵,直接对照可视化来理解
参考:torch.roll 函数的理解

torch.roll(x, shifts=(40, 40), dims=(1, 2))
在这里插入图片描述
这里img的shape是[1,56,56,96],即[B,H,W,C]格式。
dim=1,shift=40指的就是数据沿着H维度,将数据朝正反向滚动40,超出部分循环回到图像中
dim=2,shift=40指的就是数据沿着W维度,将数据朝正反向滚动40,超出部分循环回到图像中
这里的原点是左上角,H的正方向向下,W正方向向右
可视化代码:

import torch    
import numpy as np   
import matplotlib.pyplot as plt

shift_size = 3
'''构造多维张量'''
x=np.arange(301056).reshape(1,56,56,96)
x=torch.from_numpy(x)

if shift_size > 0:
    shifted_x = torch.roll(x, shifts=(40, 40), dims=(1, 2))
    #shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    print("---------经过循环位移了---------")
else:
    shifted_x = x
   
'''可视化部分'''
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x[0,:,:,0])
plt.title("orgin_img")
plt.subplot(1,2,2)
plt.imshow(shifted_x[0,:,:,0])
if torch.equal(shifted_x, x):
    plt.title("non_shifted")
else:
    plt.title("shifted_img")
plt.show()
plt.pause(5)
plt.close()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值