知识蒸馏 - 模型压缩方法之一(伪压缩、广义压缩)
模型包括两部分
Net_T (Teacher Model) : 复杂而强大,完整地学习Ground Truth。不进行部署上线
Ground Truth 指的是训练集对监督学习技术的分类的准确性。可以把它理解为真值、真实的有效值或者是标准的答案。标注值
Net_S (Student Model) : 简单而弱小,同时学习Net_T的logit和Ground Truth。是最终应用模型
logit:是模型输出的对于各个类别的概率预测值
损失函数
L = α L s o f t + β L h a r d L = \alpha L_{soft} + \beta L_{hard} \\ L=αLsoft+βLhard
α \alpha α 和 β \beta β 是超参数
{ L s o f t = − ∑ j N p j T l o g ( q j T ) L h a r d = − ∑ j N c j l o g ( q j 1 ) \left\{ \begin{aligned} L_{soft} &=& -\sum_j^N {p_j^T log(q_j^T)} \\ L_{hard} &=& -\sum_j^N {c_j log(q_j^1)} \end{aligned} \right. ⎩ ⎨ ⎧LsoftLhard==−j∑NpjTlog(qjT)−j∑Ncjlog(qj1)
softmax-T
q
i
T
=
e
x
p
(
z
i
/
T
)
∑
j
N
e
x
p
(
z
j
/
T
)
q_i^T = \frac {exp(z_i/T)}{\sum_j^N{exp(z_j/T)}}
qiT=∑jNexp(zj/T)exp(zi/T)
这里的T就是Temperature,是一个在softmax操作之前需要统一除以的小参数。有如下属性:
- 如果 T=1,则就是softmax,根据logit输出各个类别的概率
- 如果T接近于0,则概率最大值接近于1,其他值接近于0,近似于onehot编码
- T越大,输出的结果的分布就越平缓,相当于平滑的作用。起到保留相思信息的作用。