Transformer 模型自提出以来,彻底改变了自然语言处理(NLP)领域。它的核心由 Encoder 和 Decoder 两部分组成,其中 Decoder 在生成任务(如机器翻译、文本生成)中扮演着至关重要的角色。本文将深入浅出地介绍 Decoder 的结构、训练和推理过程。
1. Decoder 的结构
Decoder 是 Transformer 中负责生成目标序列的部分。它的结构与 Encoder 类似,但有一些关键区别:
1.1 核心组件
Masked Self-Attention:
-
与 Encoder 的 Self-Attention 不同,Decoder 的 Self-Attention 是 掩码的(Masked),即每个词只能关注它前面的词,而不能关注后面的词。这是为了确保生成过程是自回归的(即逐个生成词)。
Cross-Attention:
-
Decoder 还引入了 Cross-Attention,用于将 Encoder 的输出(源序列的表示)K、V矩阵与 Decoder 的输入Query矩阵一起进行多头交叉注意力机制。这使得 Decoder 能够利用源序列的信息生成目标序列。
-
注意:encoder block 有6个,只有最后一层的输出作为每个decoder bolck的输入
-
前馈神经网络(Feed-Forward Network, FFN):
-
每个 Decoder 层还包含一个前馈神经网络,相当于一个夹心饼干🍪,先升维度再降低维度,在增加模型表现力的情况下,保证输出的维度和输入统一,方便进入下一个decoder block块中处理
1.2 输入输出
输入:
-
Decoder 的输入是目标序列的起始部分(例如
<sos> i have a
),通常通过 Shift Right 操作生成。
输出:
-
Decoder 的输出是目标序列的后续部分(例如
i have a cat <eos>
),其中<eos>
是结束符号。
2. 训练与推理过程
2.1 训练过程
在训练阶段,Decoder 使用 Teacher Forcing 技术:
Teacher Forcing
核心思想:训练时在解码时使用真实的目标序列作为输入,而非模型自身生成的输出
-
输入输出格式: 假设目标序列(标准输出)是
i have a cat
,在训练时:-
解码器输入(
dec_inputs
):会在序列开头添加起始符(如<sos>
),即<sos> i have a
。 -
解码器真实输出(
dec_outputs
):会移除起始符,即i have a cat
,每个词对应一个预测位置。
-
-
预测过程: 模型在每个位置(时间步)预测下一个词:
-
输入
<sos>
→ 预测i
-
输入
<sos> i
→ 预测have
-
输入
<sos> i have
→ 预测a
-
输入
<sos> i have a
→ 预测cat
-
每个位置的预测会与真实标签对比,计算交叉熵损失(CrossEntropyLoss)。
-
优点:
-
易于并行化(Transformer 特有优势):所有位置的预测可以同时计算
-
更稳定,避免之前错误预测导致累计误差影响
-
-
缺点:
-
暴露偏差(Exposure Bias):训练时使用真实数据,但推理时依赖模型自身输出,导致训练与推理的输入分布不一致。模型可能在推理时对错误敏感
-
Shift Right for Mask
-
在训练时,目标序列需要向右移动一位(即 Shift Right),并在开头添加
<sos>
。这样做的目的是确保 Decoder 在生成每个词时,只能看到它前面的词,而不能看到它后面的词。这是通过 掩码(Mask) 实现的。
2.2 推理过程
在推理阶段,Decoder 是自回归的:
-
模型逐个生成词,每次生成一个词后,将其作为输入用于生成下一个词。
-
例如,模型首先生成
"i"
,然后输入"i"
生成"have"
,接着输入"i have"
生成"a"
,最后输入"i have a"
生成"cat"
。 -
推理过程与训练过程不一致,因为模型在推理时只能使用自己生成的词作为输入,而不是真实标签。
3 训练与推理的微妙差异
维度 |
训练模式 |
推理模式 |
方式 |
teacher forcing |
自回归 |
输入方式 |
完整序列(并行) |
逐步生成(串行) |
可见信息 |
全体历史(带掩码) |
仅已生成内容 |
误差传递 |
全局优化 |
局部累积 |
速度 |
快(并行计算) |
慢(串行计算) |
4. 总结
-
Decoder 结构:包含 Masked Self-Attention、Cross-Attention 和前馈神经网络。
-
训练与推理:
-
训练时使用 Teacher Forcing并行处理,
-
推理时是自回归的,串行处理,使用前一个输出作为输入,预测下一个词。
-