PyTorch 函数解释:torch.gather()


原文链接请参考:https://dreamhomes.top/posts/201906081516.html


参考官网:torch.gather

用法:torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor

收集输入的特定维度dim指定位置index的数值。

对于一个三维tensor,结果如下:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

参数说明:

  • input (Tensor) – 输入张量
  • dim (int) – 索引维度
  • index (LongTensor) – 收集元素的索引
  • out (Tensor, optional) – 目标张量
  • sparse_grad (bool,optional) – 输入为稀疏张量

直接看官网解释有点不明白,参考另一文章的实例说明:https://blog.csdn.net/cpluss/article/details/90260550

在序列标注问题上,我们给每一个单词都标上一个标签。不妨假设我们有4个句子,每个句子的长度不一定相同,标签如下:

input = [
    [2, 3, 4, 5],
    [1, 4, 3],
    [4, 2, 2, 5, 7],
    [1]
]

上例中有四个句子,长度分别为4,3,5,1,其中第一个句子的标签为2,3,4,5。我们知道,处理自然语言问题时,一般都需要进行padding,即将不同长度的句子padding到同一长度,以0为padding,那么上述经padding后变为:

input = [
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0]
]

那么问题来了,现在我们想获得每个句子中最后一个词语的标签,该怎么得到呢?既是,第一句话中的5,第二句话中的3,第三句话中7,第四句话中的1。

此时就需要用gather函数。

此时我们的input就是填充之后的tensor,dim=1, index就是各个句子的长度,即[[4],[3],[5],[1]]。之所以维度是4*1,是为了满足index维度和input维度之间的关系(讲解参数时有讲)。

代码如下所示:

In [26]: import torch
    ...: input = [
    ...:     [2, 3, 4, 5, 0, 0],
    ...:     [1, 4, 3, 0, 0, 0],
    ...:     [4, 2, 2, 5, 7, 0],
    ...:     [1, 0, 0, 0, 0, 0]
    ...: ]
    ...: input = torch.tensor(input)
    ...: #注意index的类型
    ...: length = torch.LongTensor([[4],[3],[5],[1]])
    ...: #index之所以减1,是因为序列维度是从0开始计算的
    ...: out = torch.gather(input, 1, length-1)
    ...: out
Out[26]:
tensor([[5],
        [3],
        [7],
        [1]])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值