Data Distillation: Towards Omni-Supervised Learning
这是一种挑战真实世界数据的 self-train 的方法,在Kaggle等大数据竞赛中非常有用。
- Omni-Supervised Learning:全方位监督学习,属于半监督(semi-supervised )中的一种,使用带标签的数据和不带标签的其他数据进行学习,可以一定程度上突破带标签数据的性能限制;
- data distillation:对没有标签的数据进行多种变换(类似与数据增强),使用单模型进行预测,然后集成预测结果,自动生成图像标签的方法;
问题:
通常情况,将模型自己预测的数据加入到训练集中无法提供有意义的信息,例如将分类置信度高的数据加入训练集中重新训练,但是置信度高说明网络已经可以提取用于识别这种数据的特征,再把这种数据加入到训练集,属于无用的劣质数据,或者说这种数据的用处很小,并且会破坏数据分布,使网络的假阳性过高,模型假优;对网络性能提升有帮助的反而是那些预测的置信度低的数据,但是预测结果又不准,无法加入训练集;
解决方法:
- 使用多模型预测,将多模型的预测结果集成,当作数据的label,加入训练集,例如 Distilling the Knowledge in a Neural Network
- 使用单模型,但是对数据进行类似于数据增强的变换,将模型对这些同一张图像的不同变换后的预测结果进行集成,当作预测数据的最终label,加入训练集,例如 Data Distillation: Towards Omni-Supervised Learning(就是本篇文章)
步骤
- 首先在人工标注的数据集上训练模型
- 对无标注的数据进行多种变换,使用训练的模型预测每种变化后的数据
- 多种crop
- 多种scale
- 其他(类似于数据增强)
- 将预测结果集成,当作最终的标签
- 将标注的数据加入训练集,重新训练模型
预测结果集成
- 集成的后的准确率高于单一图像的预测
- 集成的结果生成新的知识,模型可以学习这些知识提高性能,达到自学习的效果(个人观点: 图像变换之后,模型预测时关注的重点各不相同,集成结果可以将这些不同的关注点结合起来,以供模型进行学习)
集成效果如下图
对于集成数据变换后的模型预测结果,作者提出了一种直接生成 hard label 的方法,不用改变损失函数或其他地方来适应普通集成方法生成的 soft label。
- 普通的集成方法:例如分类,使用模型类别预测结果置信度的平均值当作最终的 label,但是这种方法生成的 label 是概率,属于 soft label,需要改变损失函数以兼容这种 soft label。另外,例如目标检测、人体姿态估计等这种输出是结构化向量的问题,对预测结果进行平均会改变输出的结构,生成错误标签;
- 本文方法:直接以某种方式集成多种变换后的预测结果,直接输出与人工标注结果相同的、不改变输出空间结构的 hard label; (没看到是什么方式集成的,是我没看懂?上原文)
数据蒸馏(重新训练)
- 在扩增数据重新训练时,保证每个 batch 的图像中既有自动生成标签的图像,也有原本的手工标注的数据;
- 训练策略需要调整,以适应数据集的增加;
实验验证
作者在COCO数据集上使用人体关键点检测进行了验证,比Mask-RCNN 的 baseline 提高了 2 个百分点的 AP,相比之下,扩增人工标注的数据仅增加了 3 个百分点的AP。