Contents
Introduction
- The authors propose Generalized Knowledge Distillation (GKD), a unified KD framework that covers on-policy, off-policy, and mixed-policy KD with arbitrary divergences. Experiments show on-policy distillation works best overall, but no single divergence is optimal across all tasks.
Method
Generalized Knowledge Distillation (GKD)
- On-policy / Off-policy / Mixed-policy KD. Off-policy KD distills on a fixed dataset $(X, Y)$: if $Y$ comes from a fixed dataset this is Supervised KD, and if $Y$ is generated by the teacher it is Sequence-Level KD. Off-policy KD, however, suffers from a train-inference mismatch for the student: the input sequences the student is trained on are distributed differently from the sequences it produces at inference time, which degrades its generation quality.
To address this, the authors propose on-policy KD, which uses sequences generated by the student itself as training data (sampled with temperature $\gamma = 1$ to encourage diversity). The student here has already been SFT-ed, so it produces reasonably sensible training data from the start (a minimal sampling sketch follows this item).
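A minimal sketch of this on-policy data-collection step, assuming a Hugging Face-style `generate` interface (`student`, `tokenizer`, and `collect_on_policy_batch` are hypothetical names, not from the paper):

```python
import torch

@torch.no_grad()
def collect_on_policy_batch(student, tokenizer, prompts, max_new_tokens=128):
    """Sample output sequences from the (SFT-ed) student itself; these
    self-generated sequences are what the teacher is later queried on
    for distillation."""
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    # Temperature gamma = 1 with sampling (not greedy) to keep the samples diverse.
    outputs = student.generate(
        **inputs,
        do_sample=True,
        temperature=1.0,
        max_new_tokens=max_new_tokens,
    )
    return outputs  # y ~ p_S(. | x), used as distillation targets
```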
- Choice of Divergence. Forward KL $\mathcal D_{KL}(P\|Q)$, reverse KL $\mathcal D_{KL}(Q\|P)$, and the generalized JSD, $\mathrm{JSD}(\beta)$, which interpolates between the forward and reverse KL using a bounded coefficient $0<\beta<1$; the gradients of $\mathrm{JSD}(\beta)$ behave similarly to the forward KL and the reverse KL when $\beta$ is close to 0 and 1, respectively. The mode-averaging behaviour of the forward KL can lead to low-quality generations when the capability gap between student and teacher is large, while the mode-seeking behaviour of the reverse KL avoids this but may hurt the student's generation diversity; $\mathrm{JSD}(\beta)$ interpolates between the two extremes (concrete definitions are sketched below).
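For concreteness, a small PyTorch sketch of the three divergences at the token level, computed from teacher and student logits (an illustration under the definitions above, not the paper's reference implementation):

```python
import torch
import torch.nn.functional as F

def forward_kl(teacher_logits, student_logits):
    """D_KL(P || Q) with P = teacher, Q = student: mass-covering / mode-averaging."""
    p = F.softmax(teacher_logits, dim=-1)
    log_p = F.log_softmax(teacher_logits, dim=-1)
    log_q = F.log_softmax(student_logits, dim=-1)
    return (p * (log_p - log_q)).sum(-1).mean()

def reverse_kl(teacher_logits, student_logits):
    """D_KL(Q || P): mode-seeking, concentrates the student on teacher modes."""
    return forward_kl(student_logits, teacher_logits)

def generalized_jsd(teacher_logits, student_logits, beta=0.5, eps=1e-12):
    """JSD(beta) = beta * KL(P || M) + (1 - beta) * KL(Q || M),
    where M = beta * P + (1 - beta) * Q and 0 < beta < 1."""
    p = F.softmax(teacher_logits, dim=-1)
    q = F.softmax(student_logits, dim=-1)
    m = beta * p + (1.0 - beta) * q
    kl_pm = (p * (torch.log(p + eps) - torch.log(m + eps))).sum(-1)
    kl_qm = (q * (torch.log(q + eps) - torch.log(m + eps))).sum(-1)
    return (beta * kl_pm + (1.0 - beta) * kl_qm).mean()
```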
- Generalized KD (GKD). The authors propose GKD, a unified framework that subsumes on-policy, off-policy, and mixed-policy KD together with arbitrary divergences.
Here, $\lambda\in[0,1]$ controls the mix of on-policy (student-generated) and off-policy (fixed-dataset) data, and $\mathcal D$ can be any divergence (the objective is reconstructed below). In experiments, on-policy distillation works best, while the choice of divergence depends on the downstream task, the sampling strategy, and other factors; no single divergence is universally optimal.
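The objective referenced above, written out as a reconstruction (with $p_T$ the teacher, $p_S^\theta$ the student, and $\mathcal D$ a token-level divergence averaged over output positions; treat the exact form as indicative rather than a verbatim copy of the paper):

$$
L_{\mathrm{GKD}}(\theta) := (1-\lambda)\,\mathbb{E}_{(x,y)\sim(X,Y)}\big[\mathcal{D}(p_T \,\|\, p_S^\theta)(y\mid x)\big] + \lambda\,\mathbb{E}_{x\sim X}\,\mathbb{E}_{y\sim p_S^\theta(\cdot\mid x)}\big[\mathcal{D}(p_T \,\|\, p_S^\theta)(y\mid x)\big]
$$

$$
\mathcal{D}(p_T \,\|\, p_S^\theta)(y\mid x) = \frac{1}{L_y}\sum_{n=1}^{L_y}\mathcal{D}\big(p_T(\cdot\mid y_{<n},x)\,\big\|\,p_S^\theta(\cdot\mid y_{<n},x)\big)
$$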
RL Fine-tuning + On-policy GKD
- GKD can be plugged seamlessly into the RLHF framework: the on-policy distillation loss is combined with an RL objective so that non-differentiable rewards can also be optimized (a reconstructed form of the combined objective follows this item).
we recommend using reverse KL or JSD (0.9) when integrating GKD with RL.
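A reconstruction of the combined objective this refers to (maximized over $\theta$; $r$ is the reward and $\alpha\in[0,1]$ trades off the RL term against the distillation term; notation and exact form are indicative):

$$
\mathbb{E}_{x\sim X}\Big[(1-\alpha)\,\mathbb{E}_{y\sim p_S^\theta(\cdot\mid x)}\big[r(y)\big] \;-\; \alpha\,\mathbb{E}_{y\sim p_S^\theta(\cdot\mid x)}\big[\mathcal{D}(p_T \,\|\, p_S^\theta)(y\mid x)\big]\Big]
$$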
Experiments
- Abstractive Summarization. (1) Comparison to baselines. ImitKD and f-distill can be viewed as "mixed"-data variants of GKD ($\lambda = 0.5$) with forward KL and total variation distance as the divergence, respectively.
(2) Data efficiency and scaling.
(3) GKD Ablations. On-policy distillation performs best; the best divergence depends on the sampling strategy: with temperature sampling, mode-seeking divergences work better, whereas with greedy sampling the different divergences perform about the same.
(4) Choosing GKD Divergence. The optimal choice of divergence is temperature-dependent
(5) On-policy GKD with RL. The authors further combine GKD with RL using textual entailment feedback as the reward (RLEF) to reduce hallucination in the model's summaries (faithful summaries must be textually entailed by their input documents).
- Machine Translation. (1) On-policy GKD outperforms commonly-used KD approaches
(2) GKD Ablations.
- Arithmetic Reasoning. (1) Comparison to baselines.
(2) GKD Ablations.
- Task-agnostic Distillation: Instruction Tuning.