1. 概要
GPT-2 是使用 transformer 的解码块构建的,而 BERT 用的是编码块,但一个关键不同在于 GPT-2 是和传统语言模型一样,每次只输出一个token;模型实际运作方式是将已经输出的token作为下一轮输入的一部分,这也叫“自回归”
GPT-2 和后来的模型比如 TransformerXL 与 XLNet 都是天然自回归的。BERT 并没这样做。放弃自回归使 BERT 可以结合单词上下文(左右两侧的单词)获取更好的效果。XLNet 找到了某种整合上下文的新方式,所以自回归和上下文关联的特点它都有。
自关注层(self-attention layer)中 mask future token 并不是像 BERT 显式的用 [mask] 将其替换掉,而是阻碍其从右边token里提取信息。比如我们执行到了第四步,可以看到只有当前和之前出现了的token能参与计算
2. GPT2 流程
GPT2 的总体结构如图5所示,总的来说就是transformer的decoder部分,与transfomer的区别在于:(1)去除了encoder-decoder子层 (2)模型能处理的序列长度达到了 4000(transfomer 是512)
2.1 输入
输入由两部分组成:token embedding 和 position embedding
第一个解码块开始接手,先经 mask self-attention 处理,然后送入前馈神经网络。转换块处理完成后将解析token得到的向量发给上层解码器。这一过程在每一块中都是一样的,但每一块的mask self-attention和前馈神经网络参数各不相同。
之后,将第一步的输出添加到输入序列中,进行下一步的预测
2.2 mask self-attention
self-attention流程和transformer一样,通过q , k 向量获得权重,最后将v 向量加权和,区别在于每次进行self-attention时只能与当前词的前面词做attention, 如图7所示,在输入为it
时, 只能同it
前<s> a robot must obey the orders given
这8个词做attention, 不能同it
后面的词做attention
mask self-attention的实现如图8所示,遮罩型和原始自注意直到第二步前都是一致的(求 q, k, v, 并用q, v 打分)。我们来看第二个token,这里,最后两个标识被盖住了,模型在打分这一步遇到了阻碍。基本上后续token的相关系数(权重)都是 0 ,所以模型也就不能”偷看“了。
遮挡效果通常是借助 attention mask
的矩阵实现的。试想现在有一个 4 个词构成的序列(”robot must obey orders”)。对一个语言模型来讲,这个序列会被拆成 4 步处理——每次一词(假定每个词就是一个token)。同时这些模型都是批处理的工作模式,不妨假设一批 4 个,这样整个序列都会在同一批内接受处理。
相乘后,为 score 带上attention mask
。这样想要遮挡的单元格的值就会是负无穷或是一个非常小的负数,如下图所示, attention mask其实就是一个上三角矩阵,上三角元素都是一个负无穷的数
接着执行 softmax 变换就得到了参与 self-attention 运算的相关系数值(权重)
- 模型处理第一行的时候,因为只有一个词(robot),所有精力都给它
- 到了第二行,就有两个词了,对must进行处理,分配 48% 的注意力给 robot,52% 留在must
2.3 输出
当顶层解码块处理完成,模型会将其输出的向量与嵌入矩阵相乘, 结果就是词表中各个词的得分。
当然我们可以直接选择最高分的单词(top-k = 1)。但最好还是考虑一下其他的词。所以更好的做法是将得分视为概率对词表进行采样(得分更高的词更有可能被选出来)。折中办法是将 top-k 设置为 40,让模型从 40 个得分最高的词里选。
这样,模型就完成了一次迭代得到了一个单词输出。模型会持续迭代直到得到整个序列(1024 个标识)的输出或者解析出序列截止标识。
2.4 注意事项
- 上文中提到的token,实际上 GPT-2 使用字节对编码(Byte Pair Encoding)来对词表创建token,这意味着token通常只是单词的一部分。
- 上文中的 GPT-2 运行模式为 eval 模式,所以才会每次只处理一个词。训练时模型会面对更长的文本并同时处理多个标识,训练期间模型处理的批大小(512)也超过了评估模式的
- tranformer中使用了ResNet 和 LN ,在GPT中也使用了,上文并未提出
- 每个块都有一个权重矩阵
W
Q
W^Q
WQ
W
K
W^K
WK
W
V
W^V
WV 来计算
q
,
k
,
v
q ,k, v
q,k,v向量,而整个模型只有一个token embedding和一个position embedding 矩阵,如下图所示
关注微信公众号 funNLPer 了解更多AI算法