Pytorch结合语言模型理解torch.gather

Pytorch结合语言模型理解torch.gather

torch.gather一直是我觉得不太理解的函数,往往在代码里也比较少会碰到,偶尔碰到了,去查一下官方文档,查了也不太看的懂,这里结合语言模型,来解释一下torch.gather的意义,一切还是从官方文档说起

文档解释

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

Parameters

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather
  • out (Tensor, optional) – the destination tensor
  • sparse_grad (bool,optional) – If True, gradient w.r.t. input will be a sparse tensor.

文档的解释就是在某个维度上gather(收集?),并且给了一个3D tensor的例子

我们来看函数的三个主要的输入

  • input,就是我们要gather的所在的Tensor
  • dim,就是我们gather的维度
  • index,确定了要gather的维度后,我们还需要确认需要在这个维度上gather哪些值,这由index确定

相信根据上面的三维例子,我们很容易能写出这个函数在二维Tensor上的效果

out[i][j] = input[index[i][j]][j]  # if dim == 0
out[i][j] = input[i][index[i][j]]  # if dim == 1

虽然类似的各种维度的结果我们都能很轻易的写出,但相信读者和我一样疑惑,这个解释太抽象了,具体地说,这玩意到底有啥用?

今天我在看语言模型的Loss的时候,发现gather竟然出现在了里面,于是我仔细思考了一下,发现了gather函数的方便之处,这些还是要从语言模型讲起。

语言模型

简单来讲,假设我们现在需要构建一个以序列为输出的模型(你可以想象image captioning或者机器翻译),此时我们的Decoder需要输出一个序列,具体地,它在每个time step输出一个分布,然后我们在这个分布中采样一个词,输入到模型中,再得到下一个time step的输出,如此反复直到它生成一个End of Sentence(EOS) token,这就是一个普通的RNN based Decoder

这样的模型的损失函数往往就是下面这个式子
l o s s = − ∑ t = 1 T log ⁡ ( p θ ( y t ∣ y < t ) ) loss=-\sum_{t=1}^{T}\log(p_\theta(y_t|y_{<t})) loss=t=1Tlog(pθ(yty<t))

其中 y y y是ground truth序列

假设我们前面的模型都已经写好了,我们已经算出了所需要的结果,它是一个 T × v o c a b _ s i z e T\times vocab\_size T×vocab_size的矩阵,其中的每行都做了log softmax,我们也有ground truth,是一个T维的向量,每一维上是一个数字,是词对应的编码,现在我们要计算损失,应该怎么办呢?

为了更加直观,就举个实际的例子T=2,vocab_size=3,输出的最后结果如下
( m a n h e y E O S − 0.1 − 0.2 − 0.3 − 1.1 − 1.2 − 2.3 − 1.3 − 0.6 − 1.3 ) \left( \begin{array}{ccc} man & hey & EOS\\ -0.1 & -0.2 & -0.3\\ -1.1 & -1.2 & -2.3\\ -1.3 & -0.6 & -1.3\\ \end{array} \right) man0.11.11.3hey0.21.20.6EOS0.32.31.3

假设我们的ground truth是hey man,那么loss就应该是-[(-0.2)+(-1.1)+(-1.3)],因为第一个词是hey,所以我们要取第一行第二列的结果,第二个词是man,所以要取第二行第一列的结果,加上最后的EOS,是第三行第三列的结果,为了把这三个值gather起来,我们就可以通过ground truth g t = [ 1 , 0 , 2 ] gt=[1,0,2] gt=[1,0,2](这是根据word2idx={‘man’:0,‘hey’:1,'EOS’2}得到的)这样做

loss = torch.sum(torch.gather(output,1,gt.unsqueeze(1)))

于是我们就借助ground truth将需要计算loss的对数似然取出来啦!

这就是我目前遇到的一个gather的具体用法

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值