pytorch中几个难理解的方法整理--gather&squeeze&unsqueeze

6 篇文章 0 订阅
3 篇文章 0 订阅
本文详细介绍了PyTorch中torch.gather函数的使用,通过多个维度的例子解释了其工作原理,以及如何理解官方给出的公式。同时,文章还探讨了torch.squeeze和unsqueeze这两个函数,分别用于移除和插入尺寸为1的维度,并提供了相关操作的实例解析。
摘要由CSDN通过智能技术生成

gather

pytorch中gather源码形式:torch.gather(input, dim, index, *, sparse_grad = False, out = None)

然后在pytorch官方文档中,写了这样的一个例子,这个例子是三维的

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

刚开始比较难理解,不知道什么意思,于是试了几个例子

一维:

>>> array1 = torch.tensor([1,2,3])
>>> torch.gather(array1, 0, torch.tensor([0,1]))

tensor([1, 2])

上述例子中,array1的矩阵形式为array1 = [1,2,3], 按维度0取值(对于一维的情况,顶多也为0), 将[array1[0],array1[1]]作为输出结果,也就是[1,2]

二维

>>> array2 = torch.tensor([[1,2,3],[4,5,6]])
>>> torch.gather(array2, 0, torch.tensor([[0, 1]]))

tensor([[1, 5]])

在上述二维的例子中,array2的形式为array2 = [[1,2,3],[4,5,6]], 按维度0取值, 将[array2[0][0], array2[1][1]]为输出结果,也就是[1,5]

看了上面两个例子,我们根据torch.gather(input, dim, index, *, sparse_grad = False, out = None)
看下公式:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

明确一点的是,输出的size是和index的size是一样的。

对于一维的,假设index的大小为n, 那么输出结果为 [input[index[0]], input[index[1]], … , input[index[n]]], 也就是我们上个例子中的[1,2]
对于二维的,如果dim为0,假设index的大小为m*n, 那么输出结果为

[[input[index[0][0]][0], input[index[0][1]][1], … , input[index[0][n]][n],
[input[index[1][0]][0], input[index[1][1]][1], … , input[index[1][n]][n],

[input[index[m][0]][0], input[index[m][1]][1], …, input[index[m][n]][n]

所以上面的例子我们输出为[1,5]

如果dim为1呢, 同样假设index的大小为m*n,那么输出结果为:

[[input[0][index[0][0]], input[1][index[0][1]], … , input[n][index[0][n]],
[input[0][index[1][0]], input[1][index[1][1]], … , input[n][index[1][n]],

[input[0][index[m][0]], input[1][index[m][1]], …, input[n][index[m][n]],

看到这里那是不是就对官网给出的公式有点理解了呢?
再来看几个3维的例子

array = torch.tensor(np.arange(24)).view(2,3,4)
array
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]], dtype=torch.int32)
torch.gather(array, 0, torch.tensor([[[0,1],[1,1]]]))

输出结果为

[[[array[0][0][0],array[1][0][1],
  [array[1][1][0],array[1][1][1]
]]
tensor([[[ 0, 13],
         [16, 17]]], dtype=torch.int32)

看到这里是不是有点理解了,不理解的平时多试试就比较清楚了。

squeeze

torch.squeeze中函数形式
torch.squeeze(input, dim = None, *, out = None) -> Tensor,默认dim参数为None
官网也描述了它的作用:Returns a tensor with all the dimensions of input of size 1 removed.
就是移除所有size为1的维度,比如说输入一个array,它的shape为(1,2,3,1,2),那么他的output的size为(2,3,2)
具体看一下例子:

>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array).size()

torch.Size([2, 3, 2])

如果某个维度的size不为1,那就不移除。

另外还有一种写法可以移除特定size为1的维度,写法torch.squeeze(array, dim)
例如:

>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array,0).size()

torch.Size([2, 3, 1, 2])

这里第四维度的1就没有移除掉。
下面我们再来看下unsqueeze方法

unsqueeze

torch官网中描述的方法 torch.unsqueeze(input, dim)->Tensor

作用是返回一个在特定维度插入size为1的tensor

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])

这个dim可以为多少呢?官方也是做出了解释A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.

就是说,dim的输入只能在[-input.dim()-1, input.dim() + 1]范围内。在上面的例子中,维度限制在[-2, 1]之间。
如果是负数怎么处理呢? dim = dim + input.dim() + 1, 也就是说,如果输入-2, 那么应该输出dim = 0,
其实从这个公式,和list中里面的选取元素差不多,
例如list = [1,2,3,4]; list[0]= 1, list[-1] = 4,相当于 dim为-1 就是在最高维插入size为1, 而当dim为**-input.dim() - 1**相当于在维度0处插入size为1。

例子

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, -2).size()
torch.Size([1, 4])
>>> torch.unsqueeze(x, -1).size()
torch.Size([4, 1])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])

看到这里是不是就理解了呢?

(其他方法后续跟新)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值