pytorch中的gather函数_理解pytorch几个高级选择函数(如gather)

1. 引言

最近在刷开源的Pytorch版动手学深度学习,里面谈到几个高级选择函数,如index_select,masked_select,gather等。这些函数大多很容易理解,但是对于gather函数,确实有些难理解,官方文档开始也看得一脸懵,感觉不太直观。下面谈谈我对这几个函数的一些理解。

2. 维度的理解

对于numpy和pytorch,其数组在做维度运算上刚开始可能会给人一种直观上的误解,以numpy求矩阵某个维度的最大值为例(pytorch的理解也是一样的)

import numpy as np

a = np.arange(1, 13).reshape(3, 4)

"""

result:

a = [[1, 2, 3, 4],

[5, 6, 7, 8,],

[9, 10, 11, 12]]

"""

# 对a维度0求最大值

a.max(axis = 0)

"""

result:

[9, 10, 11, 12]

"""

# 对a维度1求最大值

a.max(axis = 1)

"""

result:

[4, 8, 12]

"""

如果对a矩阵在维度0上找最大值,根据我们直观上的经验应该是[4, 8, 12]。即从[1, 2, 3, 4]找到4,从[5, 6, 7, 8]找到8,从[9, 10, 11, 12]找到12。但是从上面结果来看,numpy运算却给了我们直观上认为是列最大值的结果[9, 10, 11, 12]。

实际numpy(pytorch)运算应该理解为往给定的维度进行移动运算。还是以维度0为例,维度0上有3个向量,分别为[1, 2, 3, 4],[5, 6, 7, 8]和[9, 10, 11, 12]。往维度0移动,即[1, 2, 3, 4]和[5, 6, 7, 8]逐元素计算最大值,得到[5, 6, 7, 8],再和[9, 10, 11, 12]运算得到结果[9, 10, 11, 12]。

db49a94cbd817256a01a75af2c5abab2.png

另外,对于维度为3的数组,在numpy和pytorch中,应该把维度0理解为通道数,维度1和维度2才是对应高和宽。如果是3维数组对应着用于多输入通道和单输出通道的卷积核(维度为U x V x D),那么4维数组就对应着用于多输入通道和多输出通道的卷积核(维度为U x V x D x P),此时,维度0则为多通道卷积核数量的方向,维度1为通道数,维度2和3才是分别对应高和宽。

9a9096bbcb584c05a7483e7b4d941f10.png

3. gather函数

pytorch和numpy中许多函数都涉及维度运算,gather也不例外,但是它相对于其他函数更难理解。依然先来看一个例子

import torch

a = torch.arange(1, 16).reshape(5, 3)

"""

result:

a = [[1, 2, 3],

[4, 5, 6],

[7, 8, 9],

[10, 11, 12],

[13, 14, 15]]

"""

# 定义两个index

b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])

c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]])

# axis=0

output1 = a.gather(0, b)

"""

result:

[[1, 5, 9],

[7, 11, 15],

[1, 8, 15]]

"""

# axis=1

output2 = a.gather(1, c)

"""

result:

[[2, 3, 1, 3, 2],

[5, 6, 5, 4, 4]]

"""

上面的例子看起来可能有点复杂,我们来一步步的分析它,先从gather维度为0开始讲起。

a.gather(0, b)分为3个部分,a是需要被提取元素的矩阵,0代表的是提取的维度为0,b是提取元素的索引

其中规定b和a是同维张量,即a是2维张量,b也必须是2维张量

0除了代表往维度0的方向提取元素外,还有一个特权---提取结果output可以在这个维度上的长度与a不同。打个比方,a现在的shape为(5, 3),那么提取结果output1的shape可以是(1,3),(2, 3),甚至(n, 3)。具体维度0的长度到底为多少由b来决定。

根据0的特权,导致了给定的b张量除了维度0外,其他的维度大小必须和a一样。其中张量b实际上包含以下两个信息

b可以利用除用于gather的维度(此处为维度0)外的维度来定位出唯一一个向量,也就是a[:, ?](三维度也是同理的,有a[:, ?1, ?2]),?的取值范围为a同维度的index。

对于上述定位出的向量,通过b中的元素来定位提取向量中的哪一个元素。

上面说得可能有点抽象,实际上b中的每个元素都能在a中提取出一个元素。举个具体点的例子,按照上面所说的,b[0, 0]可以提取a中的一个元素。对于b[0,0],除了维度0外,可以通过维度1来定位出唯一一个向量a[:, 0]。因为b[0, 0]的元素为0,即提取的是a[:, 0]的第0个元素---1,并将其作为output1[0, 0]的提取结果。

下图给出了维度0和维度1,gather运算的图示

2df66e76a24a527ae91a407077d6fde7.png

对于3维或者更高维度的张量gather的原理也是一样的

860f0297c30fd353e5f9f8251f4bf74c.png

4. index_select函数

其他的高级选择函数都比较容易理解,这里简单的提一下。torch.index_select主要是根据传入的tensor来往给定的axis方向来选取张量

import torch

a = torch.arange(9).reshape(3, 3)

torch.index_select(a, 0, torch.tensor([0, 2]))

"""

result:

[[0, 1, 2],

[6, 7, 8]]

"""

5. masked_select函数

实际上就是通过掩码条件来选择元素,像torch.masked_select(x, x>0.5),实际上是和x[x>0.5]等价的,最后返回的是一维张量

import torch

a = torch.rand(5, 3)

# 结果和a[a > 0.5]等价

torch.masked_select(a, a>0.5)

6. nonzero函数

找到非零元素的index

import torch

a = torch.eye(3)

torch.nonzero(a)

"""

result: 对应着非零元素的index

[[0, 0],

[1, 1],

[2, 2]]

"""

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值