数据蒸馏: Data Distillation: Towards Omni-Supervised Learning

Data Distillation: Towards Omni-Supervised Learning

这是一种挑战真实世界数据的 self-train 的方法,在Kaggle等大数据竞赛中非常有用。

  • Omni-Supervised Learning:全方位监督学习,属于半监督(semi-supervised )中的一种,使用带标签的数据和不带标签的其他数据进行学习,可以一定程度上突破带标签数据的性能限制;
  • data distillation:对没有标签的数据进行多种变换(类似与数据增强),使用单模型进行预测,然后集成预测结果,自动生成图像标签的方法;

问题:

通常情况,将模型自己预测的数据加入到训练集中无法提供有意义的信息,例如将分类置信度高的数据加入训练集中重新训练,但是置信度高说明网络已经可以提取用于识别这种数据的特征,再把这种数据加入到训练集,属于无用的劣质数据,或者说这种数据的用处很小,并且会破坏数据分布,使网络的假阳性过高,模型假优;对网络性能提升有帮助的反而是那些预测的置信度低的数据,但是预测结果又不准,无法加入训练集;

解决方法:

步骤

  • 首先在人工标注的数据集上训练模型
  • 对无标注的数据进行多种变换,使用训练的模型预测每种变化后的数据
    • 多种crop
    • 多种scale
    • 其他(类似于数据增强)
  • 将预测结果集成,当作最终的标签
  • 将标注的数据加入训练集,重新训练模型

预测结果集成

  • 集成的后的准确率高于单一图像的预测
  • 集成的结果生成新的知识,模型可以学习这些知识提高性能,达到自学习的效果(个人观点: 图像变换之后,模型预测时关注的重点各不相同,集成结果可以将这些不同的关注点结合起来,以供模型进行学习)
    集成效果如下图
    在这里插入图片描述

对于集成数据变换后的模型预测结果,作者提出了一种直接生成 hard label 的方法,不用改变损失函数或其他地方来适应普通集成方法生成的 soft label。

  • 普通的集成方法:例如分类,使用模型类别预测结果置信度的平均值当作最终的 label,但是这种方法生成的 label 是概率,属于 soft label,需要改变损失函数以兼容这种 soft label。另外,例如目标检测、人体姿态估计等这种输出是结构化向量的问题,对预测结果进行平均会改变输出的结构,生成错误标签;
  • 本文方法:直接以某种方式集成多种变换后的预测结果,直接输出与人工标注结果相同的、不改变输出空间结构的 hard label; (没看到是什么方式集成的,是我没看懂?上原文)


    在这里插入图片描述

数据蒸馏(重新训练)

  • 在扩增数据重新训练时,保证每个 batch 的图像中既有自动生成标签的图像,也有原本的手工标注的数据;
  • 训练策略需要调整,以适应数据集的增加;

实验验证

作者在COCO数据集上使用人体关键点检测进行了验证,比Mask-RCNN 的 baseline 提高了 2 个百分点的 AP,相比之下,扩增人工标注的数据仅增加了 3 个百分点的AP。

  • 6
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值