1.背景
ensemble learning 通常可以有效提升模型性能,但是通常在在线inference的时候不可能, 被称为cumbersome model
之前的研究表明possible to compress the knowledge in an ensemble into a single model(核心思想)
ensemble model -> single model
或者cumbersome model -> efficient model
这个操作被称为模型蒸馏 distillation
一个常见的思维障碍: 认为trained model的knowledge来源于参数值,因此 hard to change form with same knowledge
更抽象的看法: 模型不属于任何实例,只是一个learned 输入到输出的映射
one-hot 编码, 用cross entropy训练时,模型总是去最大化正确类的概率,但同时,模型的其他预测可以给我们很多关于模型是如何泛化的信息,
比如, BMW 在是别的时候,误分类为benz的概率会比 误分类为sea的概率大很多,从视觉上来讲,benz与bmw 更加接近,但是在训练标注中,没有这部分信息,
而且人为也很难量化,类似的有一个label-smoothing, 但是label-smoothing的平滑有点强行, 缺乏先验知识,效果一般,
感觉distillation 有一点高级版本的label-smooth的感觉,
同理, weight target 也有一点类似的感觉,比如在色情分类中, 如果一个样本是porn, 那误分为normal 的惩罚和误分为 normal 惩罚 normal更多,
都是为了加入更多的先验知识, 即学到correct way to generalize
方法:
用cumbersome模型输出的概率vector作为训练小模型的soft target
在transfer的过程中,既可以用相同的训练集,也可以用独立的迁移数据集
当cumbersome model是集成模型时,用他输出的算数平均或者几何平均作为soft target
soft target 熵更高,对于单个case提供更多信息, much less variance in the gradient between training cases
因此, small model经常可以使用更小的数据集和更大的学习率
提出了Temperature 的概念,可以使label 更加smooth
the transfer set could consist entirely of unlabeled data !!关键
or we could use the original training set,
但是 we found that using the original training works well
add a small term to the objective function that encourages the small model
to predict the true targets as well as matching the soft targets provided by the cumbersome model.
也就是说, 在实际训练的时候,用带标注的训练集作为hard target ,再加上cumbersome生成的soft target 结合,效果会更好
最开始理解为训练的实际target 就是soft target 和true label的一个加权, 后来文章中提到了这个,并说明了一种更好的计算loss的方法
2.Distillation 公式:
soft target : exp(z i /T ) / sigma exp(z j /T )
T 越大,越smooth,
蒸馏方法
1.simplest form of distillation:
纯用无标注的transfer set , 大小模型都用同样的temperature, 训练解说后, 在预测的时候把temperature改为1
2.
如果有correct labels of transfer set可以把这部分信息也加进来
第一个方法: use the correct labels to modify the soft targets 可以理解为用soft target和hard target 的加权作为最终算loss 的target
第二个方法 but we found that a better way is to simply use a weighted average of two different objective functions
即先计算soft loss再计算hard loss 然后再加权
soft loss 大模型和小模型采用相同的temperature, hard loss , 小模型所用的temperature为1
同时,发现,再计算total loss的时候, hard loss的权重要很低, 原因:
在不加weight 的情况下, soft target 产生的梯度只有hard target的 1/T2 因此, 将soft loss 乘以T2 , 保证在temperature 变化的时候,
soft loss 和hard loss对梯度的相对贡献不变
2.1证明了logits 匹配是蒸馏的一种特殊情况, 即拟合vi 和 zi
后面的都是各种测试结果,不详细说了
总结
1.步骤:
1) 用标注数据训练一个cumbersome的模型,可以用各种,,大模型,各种集成
2) 用这个模型的带temperature参数的softmax输出作为soft target,
如果是用完全soft target,可以用一个全新的未标注训练集,然后采用相同的temperature训练,训练完预测时temperature改为1
如果是hard target和soft target 结合的方式, 即loss = loss_hard + loss_soft * T2 (重要)