《Distilling the Knowledge in a Neural Network》
模型蒸馏
所谓模型蒸馏就是将训练好的复杂模型推广能力“知识”迁移到一个结构更为简单的网络中。或者通过简单的网络去学习复杂模型中“知识”。其基本流程如下图:
基本可以分为两个阶段:
原始模型训练:
根据提出的目标问题,设计一个或多个复杂网络(N1,N2,…,Nt)。
收集足够的训练数据,按照常规CNN模型训练流程,并行的训练1中的多个网络得到。得到(M1,M2,…,Mt)
精简模型训练:
根据(N1,N2,…,Nt)设计一个简单网络N0。
收集简单模型训练数据,此处的训练数据可以是训练原始网络的有标签数据,也可以是额外的无标签数据。
将2中收集到的样本输入原始模型(M1,M2,…,Mt),修改原始模型softmax层中温度参数T为一个较大值如T=20。每一个样本在每个原始模型可以得到其最终的分类概率向量,选取其中概率至最大即为该模型对于当前样本的判定结果。对于t个原始模型就可以t概率向量。然后对t概率向量求取均值作为当前样本最后的概率输出向量,记为soft_target,保存。
标签融合2中收集到的数据定义为hard_target,有标签数据的hard_target取值为其标签值1,无标签数据hard_taret取值为0。Target =a*hard_target + b*soft_target(a+b=1)。Target最终作为训练数据的标签去训练精简模型。参数a,b是用于控制标签融合权重的,推荐经验值为(a=0.1 b=0.9)
设置精简模型softmax层温度参数与原始复杂模型产生Soft-target时所采用的温度,按照常规模型训练精简网络模型。
部署时将精简模型中的softmax温度参数重置为1,即采用最原始的softmax