项目实战(十) - - GPT-2实现文本生成
GPT-2实现文本生成
由于GPT-2主要基于Transformer的Decoder模块,前两节笔记中已将大部分要点详细介绍,本节更多的关注GPT-2不同的部分
1. Result 呈现
GPT-2实现文本生成的成果展现,给定一个输入,模型会将后续向量依次输出,从而生成句子子,理解了前面语言模型实战博客的过程,这里就比较容易了
2. GPT-2 VS BERT
-
结构差异
GPT-2 是使用「transformer 解码器模块」构建的,而 BERT 则是通过「transformer 编码器」模块构建的 -
任务差异
GPT-2 就像传统的语言模型一样,一次只输出一个单词(token);BERT训练两个任务:①Masked Language Model; ②Next Sentence Prediction -
模型差异
GPT-2,以及一些诸如 TransformerXL 和 XLNet 等后续出现的模型,本质上都是自回归模型,BERT不同。虽然没有使用自回归机制,但 BERT 获得了结合单词前后的上下文信息的能力,从而取得了更好的效果
XLNet 使用了自回归,并且引入了一种能够同时兼顾前后的上下文信息的方法
3. Self-Attention VS Masked Self-Attention
Self-Attention模块允许一个位置看到它右侧单词的信息(如下左图),而Masked Self-Attention模块则不允许看到后方要预测的信息(如下右图)
另模型只关注之前生成的向量,对要预测的向量进行屏蔽
For Example:
通过将查询矩阵和键矩阵相乘来计算注意力得分
在相乘之后,我们加上注意力掩模三角矩阵。它将我们想要屏蔽的单元格设置为负无穷或非常大的负数
// Attention mask.
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
// We create a 3D attention mask from a 2D tensor mask.
// Sizes are [batch_size, 1, 1, to_seq_length]
// So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
// this attention mask is more simple than the triangular masking of causal attention
// used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
// Since attention_mask is 1.0 for positions we want to attend and 0.0 for
// masked positions, this operation will create a tensor which is 0.0 for
// positions we want to attend and -10000.0 for masked positions.
// Since we are adding it to the raw scores before the softmax, this is
// effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) // fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
然后,对每一行执行 softmax 操作,从而得到我们在自注意力机制中实际使用的注意力得分:
4. Sampling
GPT-2 中有top-k参数和top-p参数,分属两种sampling方式:
top-k采样,模型会从概率前 k 大的单词中抽样选取下一个单词
Top-p采样,设定概率阈值,取满足阈值条件的样本进行采样
- top-k/top-p采样中,k/p值的影响:
在无条件生成长文本的深度模型中,较大的k/较小的p值代表跟高的熵。较小的k值/较大的p值,生成的文本,往往更简单,重复度高
Reference:https://www.jiqizhixin.com/articles/2019-08-26-12