Pytorch中的torch.gather函数的理解

Pytorch中的torch.gather函数的理解

Pytorch中的torch.gather函数

pytorch比tensorflow更加编程友好,准备用pytorch试着做一些实验。
先看一下简单的用法示例代码,然后结合官方示例来解读:

b = torch.Tensor([[1,2,3],[4,5,6]])
#   1 2 3
#   4 5 6 
print b
index_0 = torch.LongTensor([[1],[2]]) 
#  [[1],  元素对应行标为[[0],  列标为[[0],
#   [2]]                [1]]         [0]]  
#  index_0[0][0] == 1, index_0[1][0] == 2
print (torch.gather(b, dim=1, index=index_0))
#dim=1, out[i][j][k] = input[i][index[i][j][k]][k]
#  [[1], 替换列标 [[0], out元素对  [[0],  列标为[[1],
#   [2]]          [0]]  应b的行标  [1]]         [2]]
#  out[0][0]= b[0][1]= 2 ,
#  out[1][0]= b[1][2]= 6
index_1 = torch.LongTensor([[0,1],[2,0]])
#  [[0,1],   元素对应行标为[[0,0],  列标为[[0,1],
#   [2,0]]                 [1,1]]         [0,1]]
#index_1[0][0] == 0, index_1[0][1] == 1, index_1[1][0] == 2, index_1[1][1] == 0
print (torch.gather(b, dim=1, index=index_1))
#dim=1, out[i][j][k] = input[i][index[i][j][k]][k]
# [[0,1], 替换列标[[0,1],  out元素对 [[0,0],  列标为[[0,1],
#  [2,0]]         [0,1]]  应b的行标  [1,1]]         [2,0]]  
# out[0][0]= b[0][0]= 1, out[0][1]= b[0][1]= 2,
# out[1][0]= b[1][2]= 6, out[1][1]= b[1][0]= 4,
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
#  [[0,1,1],   元素对应行标为[[0,0,0],  列标为[[0,1,2],
#   [0,0,0]]                 [1,1,1]]         [0,1,2]]
#...
print (torch.gather(b, dim=0, index=index_2))
#dim=0,  out[i][j][k] = input[index[i][j][k]][j][k] 
# [[0,1,1],替换行标[[0,0,0],   out元素对 [[0,1,1], 列标为[[0,1,2], 
# [0,0,0]]         [1,1,1]]    应b的行标  [0,0,0]]       [0,1,2]]  
# out[0][0]= b[0][0]= 1,out[0][1]= b[1][1]= 5, out[0][2]= b[1][2]= 6 ,
# out[1][0]= b[0][0]= 1,out[1][1]= b[0][1]= 2, out[1][2]= b[0][2]= 3 ,

输出结果:

1  2  3
4  5  6
[torch.FloatTensor of size 2x3]

tensor([[2.],
        [6.]])

 1  2
 6  4
[torch.FloatTensor of size 2x2]

 1  5  6
 1  2  3
[torch.FloatTensor of size 2x3]

结合上面的例子来看官方解读及示例,官方解读是给了三个公式:

torch.gather(input, dim, index, out=None) → Tensor
 '''
    Gathers values along an axis specified by dim.

    For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
    out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
    out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2

    Parameters:	

        input (Tensor) – The source tensor
        dim (int) – The axis along which to index
        index (LongTensor) – The indices of elements to gather
        out (Tensor, optional) – Destination tensor

    Example:
'''
    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2] 

可以看出index的形状和input的维度是一致的,都是二维的,里面的index元素数值不能超过input的界限,比如行的不能超过1,列的不能超过2。

理解了这几个式子也就记住了这个方法的用法。

### 回答1: torch.gather函数PyTorch的一个函数,用于在给定维度上按索引从输入张量提取元素并构建新的张量。 torch.gather函数的语法为:torch.gather(input, dim, index, out=None)。 参数说明: - input:输入张量,即需要从提取元素的张量。 - dim:要在哪个维度上进行提取操作。 - index:一个包含需要提取元素的索引的张量。 - out:一个可选的输出张量。 在torch.gather函数,我们会按照dim指定的维度,在input张量上进行提取操作。提取操作是根据index张量给定的索引值来进行的。最终会构建一个新的张量,其包含了根据索引从input张量提取出来的元素。 例如,如果input是一个2维张量,shape为(3,4),而index是一个1维张量,shape为(3,),则dim的取值范围为[0, 1]。如果dim=0,那么提取操作将沿着第一个维度进行,在每一列上按照index张量对应的值进行元素的提取。如果dim=1,那么提取操作将沿着第二个维度进行,在每一行上按照index张量对应的值进行元素的提取。 使用torch.gather函数可以灵活地根据给定的索引从输入张量提取出所需的元素,这对于实现一些特定需求的操作非常有用。例如,可以在处理图像分类任务时,根据预测的类别标签,从softmax输出概率提取出对应类别的概率,进而用于计算损失函数或者评估模型性能等。 ### 回答2: torch.gather函数是一个PyTorch的操作函数,用于在指定维度上根据索引获取原始张量的元素。这个函数的使用方式为: output = torch.gather(input, dim, index, out=None, sparse_grad=False) 其,input是原始的张量,dim是指定的维度,index是需要提取的元素的索引。函数会根据dim指定的维度,在input张量提取index指定的元素,并返回一个新的张量output。 例如,假设input是一个3x4的二维张量,index是一个2x3的二维张量,dim的取值为1,那么torch.gather函数会在input的第1个维度上根据index的元素索引,提取相应的元素。最终得到的output是一个2x3的张量。 torch.gather函数在很多机器学习任务非常有用。例如,在序列标注任务,我们可以使用torch.gather函数根据标签索引来选择对应的预测结果。在图像分类任务,我们可以根据类别索引使用torch.gather函数进行结果的选择。此外,在自然语言处理任务torch.gather函数也可以用来根据单词的索引来选择对应的词向量。 需要注意的是,所提取的元素的维度必须与index的维度一致,否则会引发异常。此外,dim的取值必须在0到input的维度之间,否则也会引发异常。如果不指定out参数,函数会返回一个新的张量作为输出,如果指定了out参数,则会把提取的结果保存到指定的张量。最后,如果sparse_grad为True,则会返回一个稀疏梯度,否则返回一个密集梯度。 总之,torch.gather函数提供了一种方便和高效地根据索引提取元素的方式,广泛应用于各种机器学习任务
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值