20240621日志:大模型压缩-从闭源大模型蒸馏

location:beijing
涉及知识:大模型压缩、知识蒸馏
在这里插入图片描述

Fig. 1 大模型压缩-知识蒸馏

1. 核心内容

本文提出在一个贝叶斯估计框架内估计闭源语言模型的输出分布,包括先验估计和后验估计。先验估计的目的是通过闭源模型生成的语料库(可能包含模型的粗粒度信息)得到先验分布;后验估计使用代理模型来更新先验分布并生成后验分布。利用这两个分布来进行知识蒸馏。

2. 方法

该文章的创新点是在知识蒸馏的过程中,使用一个代理模型作为教师模型和学生模型的中介,该项目配置如Table. 1

Table. 1 项目配置
项目方法
benchmarksBBH\ARC\AGIEval\MMLU\CSQA\GSM8K\
teacher modelGPT-4
proxy modelLLaMA-33B
student modelLLaMA-7B/13B

一些参数表示如下表

Table. 2 参数表示
变量含义
T \mathcal{T} T闭源的教师模型
S \mathcal{S} S学生模型
M \mathcal{M} M开源的代理模型
X X X输入的token序列
Y Y Y输出的token序列
p Y t p_{Y_t} pYt T \mathcal{T} T输出的概率Pr ( Y t ( Y_{t} (Yt | X , Y < t ) X, Y_{< t}) X,Y<t)
q Y t q_{Y_t} qYt S \mathcal{S} S输出的概率Pr ( Y t (Y_{t} (Yt | X , Y < t ) X,Y_{<t}) X,Y<t)
P Y t P_{Y_t} PYt p Y t p_{Y_t} pYt相关的离散随机变量

用指示函数 I Y t = w \mathbb{I}_{Y_t=\boldsymbol{w}} IYt=w(其实不是空心的I应该是空心的1,没法在CSDN打出来)表示 T \mathcal{T} T t t t时刻产生的one-hot编码标签。
传统的目标函数可以表示为
L t traditional = − ∑ w ∈ V I Y t = w log ⁡ q Y t = w + ∑ w ∈ V p Y t = w log ⁡ p Y t = w q Y t = w (1) \mathcal{L}_{t}^{\text{traditional}}=-\sum_{w\in\mathbb{V}}\mathbb{I}_{Y_{t}=w}\log q_{Y_{t}=w}+\sum_{w\in\mathbb{V}}p_{Y_{t}=w}\log\frac{p_{Y_{t}=w}}{q_{Y_{t}=w}}\tag{1} Lttraditional=wVIYt=wlogqYt=w+wVpYt=wlogqYt=wpYt=w(1)式中 V \mathbb{V} V表示词典, w w w是词典中的一个token,可以看出, L t traditional \mathcal{L}_{t}^{\text{traditional}} Lttraditional由两部分组成,第一部分表示由硬标签(Fig.2)产出的交叉熵损失(交叉熵与相对熵在第三章详细说明),第二部分表示用软标签计算出的KL损失,一般情况下由于 p Y t p_{Y_{t}} pYt很难得到,第二项是被忽略的。
在这里插入图片描述

Fig.2 硬标签与软标签

这篇论文就是解决第二项的问题。

2.1 先验估计

先验估计的目的是使用 T \mathcal{T} T生成的语料库 C \mathcal{C} C,得到每一步 t t t的近似 p Y t p_{Y_{t}} pYt的粗粒度估计 p ^ Y t \hat{p}_{Y_t} p^Yt,来自改良的n-gram算法(基于第n个项目的出现只与前面n-1个项目有关)来实现,对于给定一个输出token序列 Y ≤ t ∈ C Y_{\leq t}\in\mathcal{C} YtC,假设 Y t = w t Y_{t}=w_t Yt=wt其中 w t w_t wt V \mathbb{V} V中的一个token,对于 V \mathbb{V} V中的某个token w w w如果有 w = w t w=w_t w=wt,有
p ^ Y t = w = # ( Y t = w , Y t − 1 = w t − 1 , … , Y t − n = w t − n ) γ # ( Y t − 1 = w t − 1 , … , Y t − n = w t − n ) + γ − 1 γ (2) \hat{p}_{Y_t=w}=\frac{\#(Y_t=w,Y_{t-1}=w_{t-1},\ldots,Y_{t-n}=w_{t-n})}{\gamma\#(Y_{t-1}=w_{t-1},\ldots,Y_{t-n}=w_{t-n})}+\frac{\gamma-1}{\gamma}\tag{2} p^Yt=w=γ#(Yt1=wt1,,Ytn=wtn)#(Yt=w,Yt1=wt1,,Ytn=wtn)+γγ1(2)或者
p ^ Y t = w = # ( Y t = w , Y t − 1 = w t − 1 , … , Y t − n = w t − n ) γ # ( Y t − 1 = w t − 1 , … , Y t − n = w t − n ) (3) \hat{p}_{Y_t=w}=\frac{\#(Y_t=w,Y_{t-1}=w_{t-1},\ldots,Y_{t-n}=w_{t-n})}{\gamma\#(Y_{t-1}=w_{t-1},\ldots,Y_{t-n}=w_{t-n})}\tag{3} p^Yt=w=γ#(Yt1=wt1,,Ytn=wtn)#(Yt=w,Yt1=wt1,,Ytn=wtn)(3)式中, # \# #代表语料库 C \mathcal{C} C中出现某一token的数量, n n n代表窗口大小, γ \gamma γ是个超参数,由此可得到一个 p Y t p_{Y_{t}} pYt的粗略估计 p ^ Y t \hat{p}_{Y_t} p^Yt

2.2 后验估计

后验估计用来改善先验估计,后验估计使用贝叶斯估计框架,引入 T \mathcal{T} T的一个代理模型 M \mathcal{M} M(大于 S \mathcal{S} S), M \mathcal{M} M已经由 T \mathcal{T} T生成的 C \mathcal{C} C微调,该估计使用代理 M \mathcal{M} M生成的连续样本来细化 p ^ Y t \hat{p}_{Y_{t}} p^Yt
假设 p Y t p_{Y_{t}} pYt的值可以用一个离散(更好理解)的随机变量 P Y t P_{Y_t} PYt描述, P Y t P_{Y_t} PYt的数值取自m个数值 p 1 , p 2 , … , p m p^{1},p^{2},\ldots,p^{m} p1,p2,,pm,在0~1服从均匀分布。根据 p ^ Y t \hat{p}_{Y_t} p^Yt,可以重写 P Y t P_{Y_t} PYt的概率质量函数(连续的叫概率密度函数,离散的叫这个)为
E ( P Y t ) = ∑ i = 1 m p i Pr ⁡ ( P Y t = p i ) = p ^ Y t (4) \mathbb{E}(P_{Y_t})=\sum_{i=1}^mp^i\Pr(P_{Y_t}=p^i)=\hat{p}_{Y_t}\tag{4} E(PYt)=i=1mpiPr(PYt=pi)=p^Yt(4)
只要期望 E ( P Y t ) = p ^ Y t \mathbb{E}(P_{Y_t})=\hat{p}_{Y_t} E(PYt)=p^Yt,概率质量函数就可以变化。把 X X X Y < t Y_{<t} Y<t喂给 M \mathcal{M} M得到 t t t时刻的样本 w ^ ∈ V \hat{w}\in\mathbb{V} w^V,给定 w ^ \hat{w} w^ w ∈ V w\in\mathbb{V} wV,事件 A A A定义为如果 w ^ = w \hat{w}=w w^=w,A=1;否则A=0。
如果事件A=1发生,根据贝叶斯定理:
Pr ⁡ ( P Y t = w = p i ∣ A = 1 ) ∝ Pr ⁡ ( A = 1 ∣ P Y t = w = p i ) Pr ⁡ ( P Y t = w = p i ) = p i Pr ⁡ ( P Y t = w = p i ) (5) \Pr(P_{Y_t=w}=p^i|A=1)\propto\Pr(A=1|P_{Y_t=w}=p^i)\Pr(P_{Y_t=w}=p^i)=p^i\Pr(P_{Y_t=w}=p^i)\tag{5} Pr(PYt=w=piA=1)Pr(A=1∣PYt=w=pi)Pr(PYt=w=pi)=piPr(PYt=w=pi)(5)式中 w ∈ V , i ∈ { 1 , 2 , … , m } w\in\mathbb{V},i\in\{1,2,\ldots,m\} wV,i{1,2,,m},通过下式得出一个归一化因子,则 Pr ⁡ ( P Y t = w = p i ∣ A = 1 ) \operatorname*{Pr}(P_{Y_{t}=w}=p^{i}|A=1) Pr(PYt=w=piA=1)可以用 1 η p i Pr ⁡ ( P Y t = w = p i ) \frac1\eta p^i\Pr(P_{Y_t=w}=p^i) η1piPr(PYt=w=pi)来计算
η = ∑ i = 1 m p i Pr ⁡ ( P Y t = w = p i ) (6) \eta=\sum_{i=1}^mp^i\Pr(P_{Y_t=w}=p^i)\tag{6} η=i=1mpiPr(PYt=w=pi)(6)如果事件A=0发生,根据贝叶斯定理:
Pr ⁡ ( P Y t = w = p i ∣ A = 0 ) ∝ Pr ⁡ ( A = 0 ∣ P Y t = w = p i ) Pr ⁡ ( P Y t = w = p i ) = ( 1 − p i ) Pr ⁡ ( P Y t = w = p i ) (7) \Pr(P_{Y_{t}=w}=p^{i}|A=0)\propto\Pr(A=0|P_{Y_{t}=w}=p^{i})\Pr(P_{Y_{t}=w}=p^{i})=(1-p^{i})\Pr(P_{Y_{t}=w}=p^{i})\tag{7} Pr(PYt=w=piA=0)Pr(A=0∣PYt=w=pi)Pr(PYt=w=pi)=(1pi)Pr(PYt=w=pi)(7)式中 w ∈ V , i ∈ { 1 , 2 , … , m } w\in\mathbb{V},i\in\{1,2,\ldots,m\} wV,i{1,2,,m},同样通过下式得出一个归一化因子
η = ∑ i = 1 m ( 1 − p i ) Pr ⁡ ( P Y t = w = p i ) (8) \begin{aligned}\eta=\sum_{i=1}^m{(1-p^i)}\Pr(P_{Y_t=w}=p^i)\end{aligned}\tag{8} η=i=1m(1pi)Pr(PYt=w=pi)(8) Pr ⁡ ( P Y t = w = p i ∣ A = 0 ) \operatorname*{Pr}(P_{Y_{t}=w}=p^{i}|A=0) Pr(PYt=w=piA=0)可由 1 η ( 1 − p i ) Pr ⁡ ( P Y t = w = p i ) \frac1\eta(1-p^i)\Pr(P_{Y_t=w}=p^i) η1(1pi)Pr(PYt=w=pi)得出。
这样在A无论为0还是1都能有所替换,一次迭代结束, P r ( P Y t = p i ) \mathrm{Pr}(P_{Y_{t}}=p^{i}) Pr(PYt=pi) Pr ⁡ ( P Y t = w = p i ∣ A = 0 ) \operatorname*{Pr}(P_{Y_{t}=w}=p^{i}|A=0) Pr(PYt=w=piA=0) Pr ⁡ ( P Y t = w = p i ∣ A = 1 ) \operatorname*{Pr}(P_{Y_{t}=w}=p^{i}|A=1) Pr(PYt=w=piA=1)替换,然后进入下一次迭代。经过多轮采样,可以得到最终的概率质量函数 Pr ⁡ ( P Y t = p i ∣ M ) \operatorname*{Pr}(P_{Y_{t}}=p^{i}|\mathcal{M}) Pr(PYt=piM) p Y t p_{Y_{t}} pYt可以用期望来代替
E ( P Y t ∣ M ) = ∑ i = 1 m p i Pr ⁡ ( P Y t = p i ∣ M ) (9) \mathbb{E}(P_{Y_t}|\mathcal{M})=\sum_{i=1}^mp^i\Pr(P_{Y_t}=p^i|\mathcal{M})\tag{9} E(PYtM)=i=1mpiPr(PYt=piM)(9) E ( P Y t ∣ M ) \mathbb{E}(P_{Y_t}|\mathcal{M}) E(PYtM)即为后验估计。
该过程可以用下图3表示
在这里插入图片描述

Fig.3 后验估计过程

2.3 目标函数

t t t步的目标函数由三部分组成,用指示函数 I Y t = w \mathbb{I}_{Y_t=\boldsymbol{w}} IYt=w表示 T \mathcal{T} T t t t时刻产生的one-hot编码标签。第一部分的目标函数是交叉熵损失 L t c e = − ∑ w ∈ V I Y t = w log ⁡ q Y t = w \mathcal{L}_{t}^{\mathrm{ce}} = -\sum_{w\in\mathbb{V}}\mathbb{I}_{Y_{t}=w}\log q_{Y_{t}=w} Ltce=wVIYt=wlogqYt=w,第二部分基于先验估计 L t k l = ∑ w ∈ V p ^ Y t = w log ⁡ p ^ Y t = w q Y t = w \mathcal{L}_{t}^{\mathrm{kl}} = \sum_{w\in\mathbb{V}}\hat{p}_{Y_{t}=w}\log\frac{\hat{p}_{Y_{t}=w}}{q_{Y_{t}=w}} Ltkl=wVp^Yt=wlogqYt=wp^Yt=w,第三部分基于后验估计 L t ∣ M k l = ∑ w ∈ V E ( P Y t = w ∣ M ) log ⁡ E ( P Y t = w ∣ M ) q Y t = w \mathcal{L}_{t|\mathcal{M}}^{\mathrm{kl}}=\sum_{w\in\mathbb{V}}\mathbb{E}(P_{Y_{t}=w}|\mathcal{M})\log\frac{\mathbb{E}(P_{Y_{t}=w}|\mathcal{M})}{q_{Y_{t}=w}} LtMkl=wVE(PYt=wM)logqYt=wE(PYt=wM),最终得到目标函数
L = 1 T ∑ t = 1 T ( L t c e + α L t k l + β L t ∣ M k l ) (10) \mathcal{L}=\frac{1}{T}\sum_{t=1}^{T}(\mathcal{L}_{t}^{\mathrm{ce}}+\alpha\mathcal{L}_{t}^{\mathrm{kl}}+\beta\mathcal{L}_{t|\mathcal{M}}^{\mathrm{kl}})\tag{10} L=T1t=1T(Ltce+αLtkl+βLtMkl)(10)式中 α \alpha α β \beta β都是超参数。
总结一下如图4
在这里插入图片描述

Fig. 4 总体目标函数

3. 交叉熵损失函数与Kullback-Leibler(KL)损失函数

在信息论中,期望使用公式来表示事件所包含的信息的量度。

信息量,期望一个事件发生的概率越小,信息量就越大;而大概率的信息量较小,同时期望两个事件同时发生的信息量等于两个事件的信息量相加,由此可以规定一个事件的信息量为
I ( x i ) = − log ⁡ b P ( x i ) (11) I(x_i) = -\log_b P(x_i)\tag{11} I(xi)=logbP(xi)(11)
信息熵 𝐻(𝑋),也称为熵,是随机变量𝑋的期望信息量,可以通过对其所有可能结果的信息量求加权平均来计算:
H ( X ) = − ∑ i = 1 n P ( x i ) log ⁡ b P ( x i ) (12) H(X) = -\sum_{i=1}^{n} P(x_i) \log_b P(x_i)\tag{12} H(X)=i=1nP(xi)logbP(xi)(12)信息熵用来评估一个随机变量的不确定性,不确定性越大(对投色子,各数字概率密度均匀,取出任何数的概率相同),熵越大;不确定性越小(对扑克牌,普通牌与大小王的概率密度差距很大,取出普通牌的不确定性小),熵越小。

交叉熵假设随机变量𝑋的真实概率密度p,预测概率密度q,定义q对p的平均信息量的估计,叫做交叉熵,定义为公式
H ( p , q ) = ∑ p i I i q = − ∑ p i l o g 2 ( q i ) (13) H(p,q)=\sum p_iI_i^q=-\sum p_ilog_2(q_i)\tag{13} H(p,q)=piIiq=pilog2(qi)(13)交叉熵越小,预测的分布与真实的分布差异越小。且交叉熵总是大于熵的值。

KL散度也称为相对熵,是一种衡量两个概率分布差异的指标。KL散度是不对称的,即从分布P到分布Q的KL散度与从Q到P的KL散度不同。对于两个概率分布𝑃和𝑄定义在相同的概率空间上,KL散度定义为:
K L ( P ∥ Q ) = ∑ x [ P ( x ) ( I P − I Q ) ] = ∑ x P ( x ) log ⁡ ( P ( x ) Q ( x ) ) (14) \mathrm{KL}(P\parallel Q)=\sum_{x}[P(x)(I_P-I_Q)]=\sum_{x}P(x)\log\left(\frac{P(x)}{Q(x)}\right)\tag{14} KL(PQ)=x[P(x)(IPIQ)]=xP(x)log(Q(x)P(x))(14)
对于连续概率分布,求和变成积分。当两分布完全相同,则 K L ( P ∥ Q ) = 0 \mathrm{KL}(P\parallel Q)=0 KL(PQ)=0,KL熵用来衡量两分布的相似程度,KL熵越小,两分布越相似。

  • 21
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值