[ACL 2024] Revisiting Knowledge Distillation for Autoregressive Language Models

Introduction

  • 作者提出 Autoregressive KD with Adaptive Teaching Modes (ATKD),通过对难易样本采用不同的学习策略来解决 larger teachers might dramatically result in a poorer student, especially when the model capability gap is large 的问题,可以作为一种通用的学习策略提升不同的已有 KD 算法的精度
    在这里插入图片描述

Method

Rethinking Knowledge Distillation for Autoregressive LMs

  • Reformulation of L K L \mathcal L_{\mathbf {KL}} LKL. KL 散度可以被分解为 ground truth 类别上的 binary KL loss K L ( p b t ∣ ∣ q b t ) \mathrm{KL}(\mathrm{p}_\mathrm{b}^t||\mathrm{q}_\mathrm{b}^t) KL(pbt∣∣qbt) 和非 ground truth 类别上的 KL loss K L ( p ^ t ∣ ∣ q ^ t ) \mathrm{KL}(\hat{\mathrm{p}}^\mathrm{t}||\hat{\mathrm{q}}^\mathrm{t}) KL(p^t∣∣q^t),前者可以帮助 student 学习 target 相关的信息,被称为 target-oriented knowledge distillation (TKD),后者可以帮助 student 学习 non-target 中蕴含的知识,被称为 diversity-oriented knowledge distillation (DKD);此外,这两部分的蒸馏损失被加上了一个权值 p \ g t t p_{\backslash g_t}^t p\gtt,该项反映了 teacher 的 uncertainty,被称为 uncertainty coefficient (UNC)
    L K L = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + ∑ j = 1 , j ≠ g t C p j t log ⁡ ( p j t q j t ) ) = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t )       + p \ g t t ∑ j = 1 , j ≠ g t C p ^ j t ( log ⁡ ( p ^ j t q ^ j t ) + log ⁡ ( p \ g t t q \ g t t ) ) = ∑ t = 1 T ( p g t t log ⁡ ( p g t t q g t t ) + p ∖ g t t log ⁡ ( p ∖ g t t q ∖ g t t )       + p ∖ g t t ∑ j = 1 , j ≠ g t C p ^ j t log ⁡ ( p ^ j t q ^ j t ) = ∑ t = 1 T ( K L ( p b t ∣ ∣ q b t ) + p \ g t t K L ( p ^ t ∣ ∣ q ^ t ) ) \begin{aligned} \mathcal{L}_{\mathrm{KL}}& =\sum_{t=1}^{T}(p_{g_{t}}^{t}\log(\frac{p_{g_{t}}^{t}}{q_{g_{t}}^{t}})+\sum_{j=1,j\neq g_{t}}^{C}p_{j}^{t}\log(\frac{p_{j}^{t}}{q_{j}^{t}})) \\&=\sum_{t=1}^T\left(p_{g_t}^t\log(\frac{p_{g_t}^t}{q_{g_t}^t})\right. \\ &\ \ \ \ \ +p_{\backslash g_{t}}^{t}\sum_{j=1,j\neq g_{t}}^{C}\hat{p}_{j}^{t}\left(\log(\frac{\hat{p}_{j}^{t}}{\hat{q}_{j}^{t}})+\log(\frac{p_{\backslash g_{t}}^{t}}{q_{\backslash g_{t}}^{t}})\right) \\ &=\sum_{t=1}^{T}\left(p_{g_{t}}^{t}\log(\frac{p_{g_{t}}^{t}}{q_{g_{t}}^{t}})+p_{\setminus g_{t}}^{t}\log(\frac{p_{\setminus g_{t}}^{t}}{q_{\setminus g_{t}}^{t}})\right. \\ &\ \ \ \ \ +p_{\setminus g_t}^t\sum_{j=1,j\neq g_t}^C\hat{p}_j^t\log(\frac{\hat{p}_j^t}{\hat{q}_j^t}) \\ &=\sum_{t=1}^T\left(\mathrm{KL}(\mathrm{p}_\mathrm{b}^t||\mathrm{q}_\mathrm{b}^t)+p_{\backslash g_t}^t\mathrm{KL}(\hat{\mathrm{p}}^\mathrm{t}||\hat{\mathrm{q}}^\mathrm{t})\right) \end{aligned} LKL=t=1T(pgttlog(qgttpgtt)+j=1,j=gtCpjtlog(qjtpjt))=t=1T(pgttlog(qgttpgtt)     +p\gttj=1,j=gtCp^jt(log(q^jtp^jt)+log(q\gttp\gtt))=t=1T(pgttlog(qgttpgtt)+pgttlog(qgttpgtt)     +pgttj=1,j=gtCp^jtlog(q^jtp^jt)=t=1T(KL(pbt∣∣qbt)+p\gttKL(p^t∣∣q^t))其中, T T T 为序列长度, p , q p,q p,q 分别为 teacher 和 student 的概率分布, g t gt gt 为 teacher 预测的 ground-truth 类别, p g t t = exp ⁡ ( z g t t ) ∑ j = 1 C exp ⁡ ( z j t ) , p ∖ g t t = ∑ k = 1 , k ≠ g t C exp ⁡ ( z k t ) ∑ j = 1 C exp ⁡ ( z j t ) , p ^ i t = exp ⁡ ( z i t ) ∑ j = 1 , j ≠ g t C exp ⁡ ( z j t ) p_{g_t}^t=\frac{\exp(z_{g_t}^t)}{\sum_{j=1}^C\exp(z_j^t)},p_{\setminus g_t}^t=\frac{\sum_{k=1,k\neq g_t}^C\exp(z_k^t)}{\sum_{j=1}^C\exp(z_j^t)},\hat{p}_i^t=\frac{\exp(z_i^t)}{\sum_{j=1,j\neq g_t}^C\exp(z_j^t)} pgtt=j=1Cexp(zjt)exp(zgtt),pgtt=j=1Cexp(zjt)k=1,k=gtCexp(zkt),p^it=j=1,j=gtCexp(zjt)exp(zit) p i t = p ∖ g t t ⋅ p ^ i t p_i^t=p_{\setminus g_t}^t\cdot \hat{p}_i^t pit=pgttp^it p b t = [ p g t t , p ∖ g t t ] \mathrm{p}_{\mathrm{b}}^t=[p_{g_t}^t,p_{\setminus g_t}^t] pbt=[pgtt,pgtt]
  • Empirical Analyses. (1) UNC measures the learning difficulties of tokens, where the hard-to-learn ones are more important for KD. 根据 p \ g t t p_{\backslash g_t}^t p\gtt 的大小可以把 tokens 分为难样本 (top-50% uncertainty) 和简单样本,实验发现难样本对 student 的学习更重要,尤其是 student 和 teacher 差距比较大的时候,这可能是因为难样本能让 student 学到丰富的类间信息,同时避免过拟合
    在这里插入图片描述(2) DKD contributes more (than TKD) but is greatly suppressed, especially for the larger teachers. 作者对 TKD 和 DKD 做了解耦,去除了权重 p \ g t t p_{\backslash g_t}^t p\gtt 来考察它们各自的作用,作者发现 DKD 显著优于 TKD,但在 KL loss 中,由于 p \ g t t p_{\backslash g_t}^t p\gtt 的存在,DKD 的权值被降低了,并且这一现象在更大规模的模型中尤为显著,这也是作者认为的导致 larger teachers might dramatically result in a poorer student 的原因在这里插入图片描述在这里插入图片描述(3) TKD plays different roles in tokens with different learning difficulties. TKD 在简单样本上可能会导致 student 过拟合,从而影响泛化性;在难样本上能降低难样本的学习难度,从而提升 student 精度
    在这里插入图片描述

Improving Knowledge Distillation with Adaptive Teaching Modes

  • Autoregressive KD with Adaptive Teaching Modes (ATKD). 基于上述观察很容易想到,不同的 tokens 根据其难易程度,应该有不同的学习策略;简单样本仅使用 DKD,难样本 (top-50% uncertainty) 使用 DKD + TKD
    L K L e = − ∑ t ∈ D e K L ( p ^ t ∣ ∣ q ^ t ) , L K L h = − ∑ t ∈ D h K L ( p b t ∣ ∣ q b t ) + K L ( p ^ t ∣ ∣ q ^ t ) \begin{aligned} &\mathcal{L}_\mathrm{KL}^{e} =-\sum_{t\in\mathcal{D}_e}\mathrm{KL}(\mathbf{\hat{p}^t}||\mathbf{\hat{q}^t}), \\ &\mathcal{L}_{\mathrm{KL}}^h =-\sum_{t\in\mathcal{D}_h}\mathrm{KL}(\mathbf{p_b^t}||\mathbf{q_b^t})+\mathrm{KL}(\mathbf{\hat{p}^t}||\mathbf{\hat{q}^t}) \end{aligned} LKLe=tDeKL(p^t∣∣q^t),LKLh=tDhKL(pbt∣∣qbt)+KL(p^t∣∣q^t)最终的损失函数为简单样本和难样本上损失的加权和 L K L a l l = λ ∗ L K L e + ( 1 − λ ) ∗ L K L h \mathcal{L}_{\mathrm{KL}}^{all}=\lambda*\mathcal{L}_{\mathrm{KL}}^e+(1-\lambda)*\mathcal{L}_{\mathrm{KL}}^h LKLall=λLKLe+(1λ)LKLh其中, λ = 0.2 \lambda=0.2 λ=0.2

Experiments

  • Compared Results. S NLG \mathcal S_{\textrm{NLG}} SNLG 为语言生成任务,由 GPT-4 打分; S NLU \mathcal S_{\textrm{NLU}} SNLU 为语言理解任务,为 benchmark 得分
    在这里插入图片描述在这里插入图片描述
  • Ablation Study. (1) Impact of ratio k k k. k k k 用于确定 top- k k k uncertainty 的 tokens 为难样本;(2) Impact of coefficient λ λ λ. 用于确定难易样本损失的权重
    在这里插入图片描述

References

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值