pytorch中的gather函数_Pytorch中torch.gather函数祥解

本文详细解释了PyTorch中的torch.gather()函数,用于在指定维度上收集Tensor的值。通过示例介绍了如何按照不同维度取值,并提供了一个在多分类问题中获取标签对应概率的应用场景。
摘要由CSDN通过智能技术生成

原创申明:本文为作者原创,转载请注明出处!

引言:在多分类中,torch.gather常用来取出标签所对应的概率,但对于刚开始接触Pytorch的同学来说,torch.gather()可能不太好理解,这里做一些说明和演示,帮助理解。

官方说明

gather( input', dim, index, out=None, sparse_grad=False)

Gathers values along an axis specified by dim

沿着给定的维度dim收集值

Args: 参数(初学者可只看前三个参数)

input (Tensor): the source tensor 源tensor(Tensor类型)

dim (int): the axis along which to index 要进行索引的轴方向(int类型)

index (LongTensor): the indices of elements to gather(LongTensor类型)

out (Tensor, optional): the destination tensor 返回值(Tensor类型)

sparse_grad(bool,optional): If True, gradient w.r.t. :attr:input will be a sparse tensor. 若为真,这关于input的梯度为sparse tensor

注意:index的维度要和input中dim所指的维度相同

例子说明

1) 按照dim = 0, 取一个2*2 tensor的对角线上的数值

#按照dim = 0, 取一个2*2tensor的对角线上的数值

import torch

a = torch.Tensor([[1, 2],

[3, 4]])

b = torch.gather(a, dim = 0, index=torch.LongTensor([[0, 1]]))

print('a = ', a)

print('b = ', b)

输出如下:

a = tensor([[1., 2.],

[3., 4.]])

b = tensor([[1., 4.]])

说明:

可以看到a的dim=0, 即行方向的维度和index的维度是匹配的,就是说a和index由行方向从左往右看,有2列,即有2个样本,行方向是匹配的。另外,函数输出的tensor和index大小相同。

上面代码的操作逻辑是:

在a中,由行

math?formula=%5Ccolor%7Bred%7D%7B%E4%BB%8E%E5%B7%A6%E5%BE%80%E5%8F%B3%7D看,有两个样本,索引分别为0和1;每个样本有两个特征,每个特征中

math?formula=%5Ccolor%7Bred%7D%7B%E4%BB%8E%E4%B8%8A%E5%BE%80%E4%B8%8B%7D索引分别为0和1;依据index中的索引值,取第0样本的第0个特征1,再取第1个样本的第1个特征4

2) 按照dim = 1, 取一个2*2 tensor的对角线上的数值

#按照dim = 1, 取一个2*2 tensor的对角线上的数值

import torch

a = torch.Tensor([[1, 2],

[3, 4]])

c = torch.gather(a, dim = 1, index=torch.LongTensor([[0],

[1]]))

print('a = ', a)

print('c = ', c)

输出如下:

a = tensor([[1., 2.],

[3., 4.]])

c = tensor([[1.],

[4.]])

说明:

可以看到a的dim=1, 即列方向的维度和index的维度是匹配的,就是说a和index由列方向从上往下看,有2行,即有2个样本,列方向是匹配的。另外,函数输出的tensor和index大小相同。

上面代码的操作逻辑是:

在a中,由列

math?formula=%5Ccolor%7Bred%7D%7B%E4%BB%8E%E4%B8%8A%E5%BE%80%E4%B8%8B%7D看,有两个样本,索引分别为0和1;每个样本有两个特征,每个特征中

math?formula=%5Ccolor%7Bred%7D%7B%E4%BB%8E%E5%B7%A6%E5%BE%80%E5%8F%B3%7D索引分别为0和1;依据index中的索引值,取第0样本的第0个特征1,再取第1个样本的第1个特征4。

3)更复杂一点的例子

index变为2*2的longtensor

#

import torch

a = torch.Tensor([[1, 2],

[3, 4]])

d = torch.gather(a, dim= 0, index=torch.LongTensor([[0, 0],

[1, 0]]))

print('a = ', a)

print('d = ', d)

输出:

a = tensor([[1., 2.],

[3., 4.]])

d = tensor([[1., 2.],

[3., 2.]])

说明:

index可看做是行[[0, 0]] 和 [[1, 0]]的组合,从上往下,先[[0, 0]] 再[[1, 0]],根据例子1)中的逻辑可知输出为d。如果是dim = 1, 则index按照列[[0, 1]] T 和 [[0, 0]]T的组合(T表示转置),从左往右,先[[0, 1]] T 再 [[0, 0]]T,按照2)中的逻辑,得可输出。

实际中的一个例子

有三个标签[0, 1, 2],即三个类别。现在知道两个样本(A 和 B)所得到的三个标签的概率分别为[0.1, 0.3, 0.6]和[0.3, 0.2, 0.5], 用myY_hat表示, 这两个样本的真实标签分别为0和2, 那么我们很容易知道A所预测的真实标签的概率为0.1, B所预测的真实标签的概率为0.5,A误分类,B正确分类。那么用程序这么获得标签对应的概率呢,这里就可以用gather函数。

myY_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])

myY = torch.LongTensor([0, 2])

print(myY.view(-1, 1))

print(myY_hat.gather(1, myY.view(-1, 1)))

输出:

tensor([[0],

[2]])

tensor([[0.1000],

[0.5000]])

附:

Tensor的基本数据类型有五种:

32位浮点型:torch.FloatTensor,pyorch.Tensor()默认的就是这种类型。

64位整型:torch.LongTensor。

32位整型:torch.IntTensor。

16位整型:torch.ShortTensor。

64位浮点型:torch.DoubleTensor。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: torch.gather函数PyTorch的一个函数,用于在给定维度上按索引从输入张量提取元素并构建新的张量。 torch.gather函数的语法为:torch.gather(input, dim, index, out=None)。 参数说明: - input:输入张量,即需要从提取元素的张量。 - dim:要在哪个维度上进行提取操作。 - index:一个包含需要提取元素的索引的张量。 - out:一个可选的输出张量。 在torch.gather函数,我们会按照dim指定的维度,在input张量上进行提取操作。提取操作是根据index张量给定的索引值来进行的。最终会构建一个新的张量,其包含了根据索引从input张量提取出来的元素。 例如,如果input是一个2维张量,shape为(3,4),而index是一个1维张量,shape为(3,),则dim的取值范围为[0, 1]。如果dim=0,那么提取操作将沿着第一个维度进行,在每一列上按照index张量对应的值进行元素的提取。如果dim=1,那么提取操作将沿着第二个维度进行,在每一行上按照index张量对应的值进行元素的提取。 使用torch.gather函数可以灵活地根据给定的索引从输入张量提取出所需的元素,这对于实现一些特定需求的操作非常有用。例如,可以在处理图像分类任务时,根据预测的类别标签,从softmax输出概率提取出对应类别的概率,进而用于计算损失函数或者评估模型性能等。 ### 回答2: torch.gather函数是一个PyTorch的操作函数,用于在指定维度上根据索引获取原始张量的元素。这个函数的使用方式为: output = torch.gather(input, dim, index, out=None, sparse_grad=False) 其,input是原始的张量,dim是指定的维度,index是需要提取的元素的索引。函数会根据dim指定的维度,在input张量提取index指定的元素,并返回一个新的张量output。 例如,假设input是一个3x4的二维张量,index是一个2x3的二维张量,dim的取值为1,那么torch.gather函数会在input的第1个维度上根据index的元素索引,提取相应的元素。最终得到的output是一个2x3的张量。 torch.gather函数在很多机器学习任务非常有用。例如,在序列标注任务,我们可以使用torch.gather函数根据标签索引来选择对应的预测结果。在图像分类任务,我们可以根据类别索引使用torch.gather函数进行结果的选择。此外,在自然语言处理任务torch.gather函数也可以用来根据单词的索引来选择对应的词向量。 需要注意的是,所提取的元素的维度必须与index的维度一致,否则会引发异常。此外,dim的取值必须在0到input的维度之间,否则也会引发异常。如果不指定out参数,函数会返回一个新的张量作为输出,如果指定了out参数,则会把提取的结果保存到指定的张量。最后,如果sparse_grad为True,则会返回一个稀疏梯度,否则返回一个密集梯度。 总之,torch.gather函数提供了一种方便和高效地根据索引提取元素的方式,广泛应用于各种机器学习任务
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值