一、torch.roll()函数参数
torch.roll(input, shifts, dims=None)
功能:按照指定的维度滚动tensor, 如果元素超出了维度,则回归到最初的位置。
input:输入的tensor
shifts:可以为int,也可以是int型的元组。张量的元素移位的位数。如果移位是一个元组,则dim必须是相同大小的元组,并且每个维度将按相应的值滚动。
dims:roll的维度,沿着那个维度滚动。
二、使用示例
代码如下(示例):
>>>import torch
>>>x = torch.arange(1, 17).view(4, 4)
>>>x
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])
>>>y=torch.roll(x,shifts=1,dims=0)#按第0维,移位为1(即顺序),滚动
>>>y
tensor([[13, 14, 15, 16],
[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
>>>y1=torch.roll(x,shifts=-1,dims=0)#按第0维,移位为-1(即逆序),滚动
>>>y1
tensor([[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16],
[ 1, 2, 3, 4]])
>>>y2=torch.roll(x,shifts=2,dims=0)#按第0维,移位为2,滚动
>>>y2
tensor([[ 9, 10, 11, 12],
[13, 14, 15, 16],
[ 1, 2, 3, 4],
[ 5, 6, 7, 8]])
>>>z=torch.roll(x,shifts=1,dims=1)#按第1维,移位为1,滚动
>>>z
tensor([[ 4, 1, 2, 3],
[ 8, 5, 6, 7],
[12, 9, 10, 11],
[16, 13, 14, 15]])
>>>z1=torch.roll(x,shifts=-1,dims=1)#按第1维,移位为-1,滚动
>>>z1
tensor([[ 2, 3, 4, 1],
[ 6, 7, 8, 5],
[10, 11, 12, 9],
[14, 15, 16, 13]])
>>>z2=torch.roll(x,shifts=2,dims=1)#按第1维,移位为2,滚动
>>>z2
tensor([[ 3, 4, 1, 2],
[ 7, 8, 5, 6],
[11, 12, 9, 10],
[15, 16, 13, 14]])
>>>o=torch.roll(x,shifts=(2,1),dims=(0,1))#第0维,移位为2滚动;第1维,移位为1滚动
>>>o
tensor([[12, 9, 10, 11],
[16, 13, 14, 15],
[ 4, 1, 2, 3],
[ 8, 5, 6, 7]])
总结
总结了torch.roll()函数的用法。