白盒蒸馏原理简述

白盒蒸馏

1. 什么是白盒蒸馏

白盒蒸馏是指在蒸馏过程中使用到教师模型的参数或 logits 的 蒸馏技术.

接下来我们会介绍经典的白盒蒸馏方法和代码实现。

2. MiniLLM

大模型能力的强大也伴随着参数量的膨胀,为了以合理的成本部署大模型,如何将大模型的知识蒸馏到小模型是一个问题。从前,面对有限的状态空间(比如有限的分类类别),教师模型和学生模型的参数量都足以学习每一种类别的模式;而在大模型自回归生成的场景下,学生模型参数变少后,天然地失去了和大模型同等的表达能力,从而传统的蒸馏可能效果不佳。

MiniLLM是一种针对生成式语言模型的全新的KD方法,它是一种白盒蒸馏方法,这种方法使用逆向KL散度,理论上使得学生模型模仿教师模型概率较大的生成结果,忽略教师模型概率不大的生成结果。这样做一定程度放弃了模型生成的多样性,从而实现高性价比的LLM部署落地。

2.1 前向KL散度

前向KL散度是传统蒸馏时使用的损失函数,这里我们再复习一下它的概念:

假设老师分布为 p p p, 学生分布为 q θ q_\theta qθ, θ \theta θ 是学生模型的参数。

前向KL散度可以看成是两个分布相似程度的定义(注意KL散度具有不对称性,不是距离):

K L ( p ∣ ∣ q θ ) = ∑ i p ( i ) l o g p ( i ) q θ ( i ) KL(p||q_\theta) = \sum_i p(i)log\frac{p(i)}{q_\theta(i)} KL(p∣∣qθ)=ip(i)logqθ(i)p(i)

从定义可以看出,在 p p p分布为 0 0 0的地方, q q q分布无论为多少,都不影响这一项为 0 0 0,所以当我们最小化前向KL散度时, q q q会在老师概率分布小的地方分配大的概率。对应到大模型生成上,就是在老师模型输出可能性很小的地方,学生模型却放大了这种可能性,显然这是不符合模型生成预期的。

2.2 逆向KL散度

reversed KL:

K L ( q θ ∣ ∣ p ) = ∑ i q θ ( i ) l o g q θ ( i ) p ( i ) = − ∑ i q θ ( i ) l o g p ( i ) q θ ( i ) KL(q_\theta ||p) = \sum_i q_\theta(i)log\frac{q_\theta(i)}{p(i)} = -\sum_i q_\theta(i)log\frac{p(i)}{q_\theta(i)} KL(qθ∣∣p)=iqθ(i)logp(i)qθ(i)=iqθ(i)logqθ(i)p(i)

蒸馏时,使用逆向KL散度代替前向KL散度。最小化逆向KL散度时,老师分布大的地方,学生分布也同步变大,而老师分布小的地方,学生分布会更小。下面这张经典的图片可以看出前向和后向KL的差异。

在这里插入图片描述

这样直观上看,使用逆向KL散度更加符合生成模型的场景。

2.3 基于策略梯度的优化

MiniLLM的论文中提出了另一个新颖的视角——逆向KL其实可以等价于强化学习,并进行了公式推导。策略梯度是一种强化学习算法:将期望的回报写成一个可导的函数,然后求使得这个函数的最大的策略(比如使用梯度上升)。

由于这部分涉及较多数学公式推导和强化学习,有兴趣的同学可以查看论文自行学习。

3. BabyLlama(实践)

将小模型蒸馏直接应用到了大模型上。它的损失函数是以下两种损失的加权和:

  • 和硬损失的交叉熵
  • 和软损失的KL散度

损失函数:

def compute_loss(self, model, inputs, return_outputs=False):
        # 硬损失,即和ground truth的交叉熵
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss

        # compute teacher output
        with torch.no_grad():
            all_teacher_logits = []
            for teacher in self.teachers:
                outputs_teacher = teacher(**inputs)
                all_teacher_logits.append(outputs_teacher.logits)
            avg_teacher_logits = torch.stack(all_teacher_logits).mean(dim=0)

        # assert size
        assert outputs_student.logits.size() == avg_teacher_logits.size()

        # 软损失,和教师模型输出分布的KL散度
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = (
            loss_function(
                F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
                F.softmax(avg_teacher_logits / self.args.temperature, dim=-1),
            )
            * (self.args.temperature ** 2)
        )
        # Return weighted student loss
        loss = self.args.alpha * student_loss + (1.0 - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值