1. 引言
Label Smoothing 又被称之为标签平滑,常常被用在分类网络中来作为防止过拟合的一种手段,整体方案简单易用,在小数据集上可以取得非常好的效果。
Label Smoothing 做为一种简单的训练trick,可以通过很少的代价(只需要修改target的编码方式),即可获得准确率的提升,本文就其原理和具体实现进行介绍,希望可以帮主大家理解其背后的具体原理。
2. 初识
我们首先来看Label Smoothing的公式,在介绍之前我们先来观察一下传统的 one-hot encoding的公式,如下:
而Label Smoothing 导入了一个factor机制,公式改变如下:
只看公式,多少有些难懂,好嘛!我们来举个栗子瞧瞧啦。。。
不妨假设我们今天有四个类别,分别为dog,cat,bird,turtle,我们对其进行编码,即
dog = 0 , cat = 1 , bird = 2 , turtle = 3
- 我们采用one-hot对其进行编码,结果如下:
dog = [ 1 , 0 , 0 , 0 ]
cat = [ 0 , 1, 0 , 0 ]
bird = [ 0 , 0 , 1 , 0 ]
turtle = [ 0, 0 , 0 , 1 ]
- 采用Label smoothing对其编码,引入一个factor来将其几率分配给其他类别 ,这里假设factor = 0.1,则生成的标签如下:
dog = [ 0.9 , 0.03 , 0.03 , 0.03 ]
cat = [ 0.03 , 0.9 , 0.03 , 0.03 ]
bird = [ 0.03 , 0.03 , 0.9 , 0.03 ]
turtle = [ 0.03 , 0.03 , 0.03 , 0.9 ]
3. 深入
有了上述直观的理解,相必大家对Label Smoothing有了简单的认识,接着我们来思考这样的改变会对损失函数带来什么样的影响。
为此我们先来看一下分类任务中最常见的cross entropy损失函数,如下:
接着我们使用上述四个类别,来看看正确分类时Loss的计算,如下:
观察上图可以看出,如果整体分类呈现正常梯度下降的话,使用Label Smoothing相比不使用的loss下降相对比较小。
那反过来,如果网络越学预测效果越差呢?
通过上图可以看出,就算训练阶段预测错误时使用Label Smoothing的loss也相比之前惩罚的更小(扣得更少)。
4. 实现
我们来对Label Smoothing技术,作如下总结:
- 使用了Label Smoothing损失函数后,在训练阶段预测正确时 loss 不会下降得太快,预测错误的時候 loss 不会惩罚得太多,使其不容易陷入局部最优点,这在一定程度可以抑制网络过拟合的现象。
- 对于分类类别比较接近的场景,网络的预测不会过于绝对,在引入Label Smoothing技巧后,通过分配这些少数的几率也可以使得神经网络在训练的时候不这么绝对。
接着,我们来用Python对其实现,代码如下:
def label_smoothing(labels, factor=0.1):
num_labels = labels.get_shape().as_list()[-1]
labels = ((1-factor) * labels) + (factor/ num_labels)
return labels
5. 经验分享
在实际调参的一些经验分享如下:
- 不管是在object detection的分类网络或者是多分类网络导入label smoothing皆有不错的效果,基本上算轻松又容易提升准确度的做法
- 当数据量足够多的时候,Label smoothing这个技巧很容易使网络变得欠拟和。
- factor通常设置为0.1,之前做对比实验试过使用0.2,0.3等参数,会发现皆无较好的效果,反而使网络变得难以收敛。
- 可以利用label smoothing的特性来做点微小的改动,比如遇上相似类型的事物时,可以将factor分配给相似的类别,而不是全部类别,这通常会有不错的效果。
6. 总结
本文对Label Smoothing技术进行了简单的介绍,通过简单的例子来增加大家的直观认识,最后分享了该技巧的代码实现。
您学废了吗?
关注公众号《AI算法之道》,获取更多AI算法资讯。
参考: 论文