torch.roll()函数使用方法

官方文档在这里,说的比较清楚,但是举的例子不是很直观。我们再详细解释一下:

torch.roll(input, shifts, dims=None) → Tensor

  • input:输入的tensor
  • shifts滚动的方向和长度,若为正,则向索引大的方向滚动;若为负,则向索引减小的方向滚动。可是一个整数也可以是一个元组
  • dims:tensor滚动的维度;要和shifts设置的数量对齐。

这里要特别指出,如果移动的位置已经超出本身维度的大小,就补到反方向去。即不会丢弃数值,也不会凭空补齐数值。会循环滚动。

# 二维tensor举例
x = torch.tensor([[0, 0, 0, 0],
                  [0, 1, 1, 0],
                  [0, 1, 1, 0],
                  [0, 0, 0, 0]])
                  
# 第0维向索引大的方向滚动1个位置,即整体向下移动1个像素,第1行的元素由第4行元素补齐
y = torch.roll(x, shifts=(1), dims=(0))
tensor([[0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 1, 1, 0],
        [0, 1, 1, 0]])

# 第0维向索引小的方向滚动1个位置,即整体向上移动1个像素,第4行的元素由第1行元素补齐
x = torch.roll(x, shifts=(-1), dims=(0))
tensor([[0, 1, 1, 0],
        [0, 1, 1, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]])
        
# 第1维向索引大的方向滚动2个位置,即整体向右移动2个像素,第1、2列的元素由第3、4列元素补齐        
x = torch.roll(x, shifts=(0, 2), dims=(0, 1))
tensor([[0, 0, 0, 0],
        [1, 0, 0, 1],
        [1, 0, 0, 1],
        [0, 0, 0, 0]])
        
# 第0、1维向索引小的方向滚动2像素,即整体向下、右移动2个像素
x = torch.roll(x, shifts=(2, 2), dims=(0, 1))
tensor([[1, 0, 0, 1],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [1, 0, 0, 1]])
        

torch.roll()的使用比较简单,在实际中的应用也比较多,比如在swin-transformer中,利用torch.roll()进行MSA的计算,具体原理就不讲解了:

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值