Online Knowledge Distillation via Collaborative Learning(阅读笔记)
模型图
h是对数据的随即增强 函数,logits i是每个模型的输出。Lce为交叉熵损失函数,Lkd为KL散度。
本文中每个Network都是学生,每个学生生成的logits经过本文提出的方法集成一个Ensemble logits,把它作为教师教给每个学生Knowledge。由于这个Ensemble logits是在线生成的(每次输出都会产生不同的Ensemble logits),所以说是通过协作学习的在线知识蒸馏。
Method
本文思路还是比较清晰的,每个Network有两个损失函数,一个是Lce(module_out,label),还有一个则为KL散度,KL散度是衡量两个概率分布相似性的函数,这里Lkd=(softmax(out/T),softmax(Ensemble/T)),T为温度参数。
论文中给出了四个生成Ensemble logits的方法,把我的理解分享一下吧。
1.KDCL–Naive:
Zt是Ensembles,Zk是每个Network的logits,y是真实label。
这里方法很简答,寻找那个Lce最小的zk,然后把它最为zt。
2.KDCL-Linear:
Linear方法思路也很简单,然后寻找一个合适的线性组合,把所有的Zk加权相加后得到的Ensembles log与y的Lce最小。
这里就相当于在原来模型基础上又引入了一个优化问题,即寻找最优线性组合,使得满足我们的条件。公式中Z为zk的矩阵,a为权重。
3.KDCL-Minlogits
第三个我有点没看明白,也没找到他的官方code。所以在这简单说一下我的理解:
Network得到的logit中对每一类都有一个得分,经过softmax之后,得分高的类经过softmax生成的概率就大,比如预测为c类时,作者希望c类得分高,其余类得分都要小得多,所以作者就是要找这么一个logit,使得预测为c类时,Ensemble logit除了c外其他位置得分尽可能小,但是它的式子我没看到,可能我理解错了,有兴趣的可以去原文看下。
4.KDCL-General:
作者首先引入了一个泛化误差的概念:
输入一个值为x的数据输出为f(x),真实label值为t,此时的泛化误差为
在此基础上,作者把每个模型的预测输出fx加了权重w后当作 Ensemble logits,然后用总的泛化误差来衡量这个logits,w的初始值设为 1/m(共m个网络)。 具体公式如下:
5.ICL
为了提高模型在数据域的扰动不变性(鲁棒性),作者提出了ICL这种数据取样方法。具体就是对每个子网路抽取具有相同数据增强方法的图片进行处理生成Ensemble Logits。
创新:
1.ICL
2.在线:
每一次计算都可以根据每个学生网络的输出生成Ensemble logits,不用预训练复杂的Teacher。