Pytorch结合语言模型理解torch.gather
torch.gather一直是我觉得不太理解的函数,往往在代码里也比较少会碰到,偶尔碰到了,去查一下官方文档,查了也不太看的懂,这里结合语言模型,来解释一下torch.gather的意义,一切还是从官方文档说起
文档解释
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Parameters
- input (Tensor) – the source tensor
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to gather
- out (Tensor, optional) – the destination tensor
- sparse_grad (bool,optional) – If True, gradient w.r.t. input will be a sparse tensor.
文档的解释就是在某个维度上gather(收集?)值,并且给了一个3D tensor的例子
我们来看函数的三个主要的输入
- input,就是我们要gather的值所在的Tensor
- dim,就是我们gather的维度
- index,确定了要gather的维度后,我们还需要确认需要在这个维度上gather哪些值,这由index确定
相信根据上面的三维例子,我们很容易能写出这个函数在二维Tensor上的效果
out[i][j] = input[index[i][j]][j] # if dim == 0
out[i][j] = input[i][index[i][j]] # if dim == 1
虽然类似的各种维度的结果我们都能很轻易的写出,但相信读者和我一样疑惑,这个解释太抽象了,具体地说,这玩意到底有啥用?
今天我在看语言模型的Loss的时候,发现gather竟然出现在了里面,于是我仔细思考了一下,发现了gather函数的方便之处,这些还是要从语言模型讲起。
语言模型
简单来讲,假设我们现在需要构建一个以序列为输出的模型(你可以想象image captioning或者机器翻译),此时我们的Decoder需要输出一个序列,具体地,它在每个time step输出一个分布,然后我们在这个分布中采样一个词,输入到模型中,再得到下一个time step的输出,如此反复直到它生成一个End of Sentence(EOS) token,这就是一个普通的RNN based Decoder
这样的模型的损失函数往往就是下面这个式子
l
o
s
s
=
−
∑
t
=
1
T
log
(
p
θ
(
y
t
∣
y
<
t
)
)
loss=-\sum_{t=1}^{T}\log(p_\theta(y_t|y_{<t}))
loss=−t=1∑Tlog(pθ(yt∣y<t))
其中 y y y是ground truth序列
假设我们前面的模型都已经写好了,我们已经算出了所需要的结果,它是一个 T × v o c a b _ s i z e T\times vocab\_size T×vocab_size的矩阵,其中的每行都做了log softmax,我们也有ground truth,是一个T维的向量,每一维上是一个数字,是词对应的编码,现在我们要计算损失,应该怎么办呢?
为了更加直观,就举个实际的例子T=2,vocab_size=3,输出的最后结果如下
(
m
a
n
h
e
y
E
O
S
−
0.1
−
0.2
−
0.3
−
1.1
−
1.2
−
2.3
−
1.3
−
0.6
−
1.3
)
\left( \begin{array}{ccc} man & hey & EOS\\ -0.1 & -0.2 & -0.3\\ -1.1 & -1.2 & -2.3\\ -1.3 & -0.6 & -1.3\\ \end{array} \right)
⎝⎜⎜⎛man−0.1−1.1−1.3hey−0.2−1.2−0.6EOS−0.3−2.3−1.3⎠⎟⎟⎞
假设我们的ground truth是hey man,那么loss就应该是-[(-0.2)+(-1.1)+(-1.3)],因为第一个词是hey,所以我们要取第一行第二列的结果,第二个词是man,所以要取第二行第一列的结果,加上最后的EOS,是第三行第三列的结果,为了把这三个值gather起来,我们就可以通过ground truth g t = [ 1 , 0 , 2 ] gt=[1,0,2] gt=[1,0,2](这是根据word2idx={‘man’:0,‘hey’:1,'EOS’2}得到的)这样做
loss = torch.sum(torch.gather(output,1,gt.unsqueeze(1)))
于是我们就借助ground truth将需要计算loss的对数似然取出来啦!
这就是我目前遇到的一个gather的具体用法