pytorch函数(1):torch.gather()的正确理解方法

torch.gather()

pytorch 官方文档

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

作用:沿着由dim指定的轴收集数值。

  • 参数
input (Tensor) – 目标变量,输入
dim (int) – 需要沿着取值的坐标轴
index (LongTensor) – 需要取值的索引矩阵
sparse_grad (bool,optional) – 如果为真,输入将是一个稀疏张量
out (Tensor, optional) – 输出
  • 简单例子
>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1,  1],
        [ 4,  3]])

解释与例子

torch.gather迷惑大家的点主要是在dim的问题上,大家不知道怎么个取值法。 它的官方文档不太好理解,举得例子也太过于简单,下面我将举个例子详细说明gather的原理。

  • dim=0的情况
input = [
    [0.0, 0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2, 1.3],
    [2.0, 2.1, 2.2, 2.3]
]#shape [3,4]
input = torch.tensor(input)
length = torch.LongTensor([
    [2,2,2,2],
    [1,1,1,1],
    [0,0,0,0],
    [0,1,2,0]
])#[4,4]
out = torch.gather(input, dim=0, index=length)
print(out)

结果

tensor([[2.0000, 2.1000, 2.2000, 2.3000],
        [1.0000, 1.1000, 1.2000, 1.3000],
        [0.0000, 0.1000, 0.2000, 0.3000],
        [0.0000, 1.1000, 2.2000, 0.3000]])

这里直接给到大家对应的取值矩阵。
矩阵
gather函数就是通过length矩阵也就是我上述的矩阵来进行取值的。

例 如 X 20 就 代 表 了 i n p u t 的 第 2 行 第 0 列 的 2.0 例如X_{20}就代表了input的第2行第0列的2.0 X20input202.0
因为dim=0代表的是横向,按行取值。

即length矩阵中的数的值代表的是行数,数的位置代表的列数,比如length矩阵中的第三行第三列(从0数起)的数 0,其值是0,代表在input中所取的数是第0行,位置是第三列,则表示在input中所取的数是第三列
X 03 = 0.3 X_{03} =0.3 X03=0.3

  • dim=1
input = [
    [0.0, 0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2, 1.3],
    [2.0, 2.1, 2.2, 2.3]
]#shape [3,4]
input = torch.tensor(input)
length = torch.LongTensor([
    [2,2,2,2],
    [1,1,1,1],
    [0,1,2,0]
])#[3,4]
out = torch.gather(input, dim=1, index=length)
print(out)

结果

tensor([[0.2000, 0.2000, 0.2000, 0.2000],
        [1.1000, 1.1000, 1.1000, 1.1000],
        [2.0000, 2.1000, 2.2000, 2.0000]])

同理与dim=0时相反,数值代表列数,位置代表行数。所以对应的取值矩阵是
[ X 02 X 02 X 02 X 02 X 11 X 11 X 11 X 11 X 20 X 21 X 22 X 20 ] \begin{bmatrix} X_{02}& X_{02}& X_{02}& X_{02}\\ X_{11}& X_{11}& X_{11}& X_{11}\\ X_{20}& X_{21}& X_{22}& X_{20} \end{bmatrix} X02X11X20X02X11X21X02X11X22X02X11X20

总结一下

  • dim=0
    length矩阵中的数的值代表的是行数,所处位置代表的是列数。
    所以此时length矩阵里面的数的值应当小于input.shape[0]=3(不能等于因为是从0计数),length矩阵的的最大列数应当小于等于input.shape[1]=4(此处可以等于因为是列数)

  • dim=1
    length矩阵中的数的值代表的是列数,所处位置代表的是行数。
    所以此时length矩阵里面的数的值应当小于input.shape[1]=4(不能等于因为是从0计数),length矩阵的的最大行数应当小于等于input.shape[0]=3(此处可以等于因为是行数)
    如有疑问欢迎留言询问哦。

### 回答1: torch.gather函数PyTorch中的一个函数,用于在给定维度上按索引从输入张量中提取元素并构建新的张量。 torch.gather函数的语法为:torch.gather(input, dim, index, out=None)。 参数说明: - input:输入张量,即需要从中提取元素的张量。 - dim:要在哪个维度上进行提取操作。 - index:一个包含需要提取元素的索引的张量。 - out:一个可选的输出张量。 在torch.gather函数中,我们会按照dim指定的维度,在input张量上进行提取操作。提取操作是根据index张量中给定的索引值来进行的。最终会构建一个新的张量,其中包含了根据索引从input张量中提取出来的元素。 例如,如果input是一个2维张量,shape为(3,4),而index是一个1维张量,shape为(3,),则dim的取值范围为[0, 1]。如果dim=0,那么提取操作将沿着第一个维度进行,在每一列上按照index张量中对应的值进行元素的提取。如果dim=1,那么提取操作将沿着第二个维度进行,在每一行上按照index张量中对应的值进行元素的提取。 使用torch.gather函数可以灵活地根据给定的索引从输入张量中提取出所需的元素,这对于实现一些特定需求的操作非常有用。例如,可以在处理图像分类任务时,根据预测的类别标签,从softmax输出概率中提取出对应类别的概率,进而用于计算损失函数或者评估模型性能等。 ### 回答2: torch.gather函数是一个PyTorch中的操作函数,用于在指定维度上根据索引获取原始张量中的元素。这个函数的使用方式为: output = torch.gather(input, dim, index, out=None, sparse_grad=False) 其中,input是原始的张量,dim是指定的维度,index是需要提取的元素的索引。函数会根据dim指定的维度,在input张量中提取index中指定的元素,并返回一个新的张量output。 例如,假设input是一个3x4的二维张量,index是一个2x3的二维张量,dim的取值为1,那么torch.gather函数会在input的第1个维度上根据index中的元素索引,提取相应的元素。最终得到的output是一个2x3的张量。 torch.gather函数在很多机器学习任务中非常有用。例如,在序列标注任务中,我们可以使用torch.gather函数根据标签索引来选择对应的预测结果。在图像分类任务中,我们可以根据类别索引使用torch.gather函数进行结果的选择。此外,在自然语言处理任务中,torch.gather函数也可以用来根据单词的索引来选择对应的词向量。 需要注意的是,所提取的元素的维度必须与index的维度一致,否则会引发异常。此外,dim的取值必须在0到input的维度之间,否则也会引发异常。如果不指定out参数,函数会返回一个新的张量作为输出,如果指定了out参数,则会把提取的结果保存到指定的张量中。最后,如果sparse_grad为True,则会返回一个稀疏梯度,否则返回一个密集梯度。 总之,torch.gather函数提供了一种方便和高效地根据索引提取元素的方式,广泛应用于各种机器学习任务中。
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值