Pytorch:torch.Tensor.scatter()和torch.Tensor.gather

miguemath指出:torch.Tensor.scatter_()是torch.gather()函数的方向反向操作。两个函数可以看成一对兄弟函数。gather用来解码one hot,scatter_用来编码one hot。

接下来我们用例子看一下:

1. Tensor.scatter(dim, index, src) → Tensor

它是torch.Tensor.scatter_()的错位版本,即:

scatter_(dim, index, src, reduce=None) → Tensor

该函数用来scatter

对于一个3-D的tensor,self会被更新为

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

link:官方文档1

在自然语言处理当中,相应任务的完成离不开vocabulary,vocab将相应的词语token映射为id,我们假设有一组词语,他们的id号是[1,2,0,3], 我们可以用scatter函数得到他们的onehot编码。

>>> index = torch.tensor([1,2,0,3])
>>> index = index.unsqueeze(-1)
>>> index
tensor([[1],[2],[0],[3]])
>>> onehot = torch.zeros(4, 4)
>>> onehot
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
>>> onehot.scatter_(1, index, 1)
tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.]])

这里该函数的执行逻辑应该是:
onehot[i][index[i][j]] = 1 for i in range(index.shape[0]) for j in range(index.shape[1])

这里第三个参数可以为1的原因是该函数除了src还有一个输入参数, 在src没有指定时使用(详见文档):

value (float) – the source element(s) to scatter, incase src is not specified

2. Tensor.gather(dim, index) → Tensor

该函数等价于:

torch.gather(input, dim, index, sparse_grad=False, out=None) → Tensor

这里的input就是我们的使用对象self,这两个等价函数的输出如下:

对于一个3-D的tensor,output为

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

link:官方文档2

我们可以利用gather函数将one hot编码中的具体数值取出来(一般很少是从one hot编码中取,而是从模型得到的、关于词表的一个概率向量中取出想要词汇的概率)

>>> prob_vector= torch.rand((3,5))
>>> prob_vector
tensor([[0.5145, 0.7593, 0.0271, 0.2807, 0.0975],
        [0.8461, 0.1193, 0.7042, 0.3711, 0.7330],
        [0.3504, 0.7746, 0.5122, 0.8491, 0.7501]])
>>> index = torch.tensor([2,1,3])
>>> index = index.unsqueeze(-1)
>>> index
tensor([[2],
        [1],
        [3]])
>>> target_prob = prob_vector.gather(1,index)
>>> target_prob
tensor([[0.0271],
        [0.1193],
        [0.8491]])

我们现在有一个概率向量prob_vector,我们想取出与ground truth词汇对应的那个概率,并计算交叉熵。比如第一时间步groud truth词汇的id为2,后面分别为1,3,这时,我们可以通过gather取出相应的概率,就可以很简单地计算交叉熵(将取出的概率取log()即可,因为ground truth词汇的one hot vector在对应位置概率就是1,其他位置全为零),理论上这比整个概率向量(比如[0.5145, 0.7593, 0.0271, 0.2807, 0.0975])直接去与ground truth的one hot编码(比如[0,0,0,1,0])去做交叉熵效率更高。

同样,如果不是在求交叉熵的场景,你也可以利用gather函数从原tensor中取出任何的元素。

3、参考资料

[1]https://pytorch.org/docs/stable/
[2]one hot编码:torch.Tensor.scatter_()函数用法详解

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值