https://www.cnblogs.com/jins-note/p/9548012.html
</h1>
<div class="clear"></div>
<div class="postBody">
问题导入
在机器学习领域中,常见的一类工作是使用带标签数据训练神经网络实现分类、回归或其他目的,这种训练模型学习规律的方法一般称之为监督学习。在监督学习中,训练数据所对应的标签质量对于学习效果至关重要。如果学习时使用的标签数据都是错误的,那么不可能训练出有效的预测模型。同时,深度学习使用的神经网络往往结构复杂,为了得到良好的学习效果,对于带标签的训练数据的数量也有较高要求,即常被提到的大数据或海量数据。
矛盾在于:给数据打标签这个工作在很多场景下需要人工实现,海量、高质量标签本身费时费力,在经济上相对昂贵。因此,实际应用中的机器学习问题必须面对噪音标签的影响,即我们拿到的每一个带标签数据集都要假定其中是包含噪声的。进一步,由于样本量很大,对于每一个带标签数据集,我们不可能人工逐个检查并校正标签。
基于上述矛盾现状,在实际工作中必须面对以下两点问题
1. 训练集带标签样本中噪音达到什么水平对于模型预测结果会有致命影响
2. 对于任意给定带标签训练集,如何快速找出可能是噪音的样本
本文接下来将围绕这两点通过实验给出分析
数据、神经网络设计和代码
本文以Tensorflow教程中提及的MNIST问题[1]为数据来源和问题定义。此问题为图像识别问题,图片为手写的0-9字符,每个图片格式为28*28灰度图。训练集数据包括55000张手写数字和标签,验证集包括约10000张图片和标签。通过训练神经网络从而实现当输入一张验证集中的图片后,神经网络能够正确预测这张图片的标签。
对于MNIST问题本身,Tensorflow教程[2]描述的包含2个卷积池化层的CNN网络已经足以实现99%左右的预测精度,因此在本实验中,笔者直接引用Tensorflow官方样例中的CNN网络[3]作为预测模型的神经网络。
本文所有代码可以在笔者的Github项目中获得:wangyaobupt/NoisyLabels
噪声标签对于分类器性能的影响
考虑到MNIST是机器学习领域使用多年的数据库,且在其数据上训练的模型已经得到了较好的结果,由此可以合理推断其标签本身的噪声含量较低(这个推理将在下一个章节通过实验证实)。因此,在这一节的实验中,我们假定原始的MNIST的训练集和验证集标签都是无噪声的。
使用如下步骤给标签添加噪声
1. 根据给定的噪声比例noiseLevel,从N个总样本中选择出K个样本,K = N*noiseLevel
2. 对于选出的K个样本中的每一个样本,将其原始标签替换为0-9之间扣除原始标签之外的随机数
上述算法的代码实现如下,testcase2.py提供了完整的可执行程序
# Add random noise to MNIST training set
# input:
# mnist_data: data structure that follow tensorflow MNIST demo
# noise_level: a percentage from 0 to 1, indicate how many percentage of labels are wrong
def addRandomNoiseToTrainingSet(mnist_data, noise_level):
# the data structure of labels refer to DataSet in tensorflow/tensorflow/contrib/learn/python/learn/datasets/mnist.py
label_data_set = mnist_data.train.labels
#print label_data_set.shape
<span class="n">totalNum</span> <span class="o">=</span> <span class="n">label_data_set</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi" style="color: rgba(0, 132, 255, 1)"><span class="hljs-number">0</span></span><span class="p">]</span>
<span class="n">corruptedIdxList</span> <span class="o">=</span> <span class="n">randomSelectKFromN</span><span class="p">(</span><span class="nb" style="color: rgba(0, 132, 255, 1)"><span class="hljs-built_in">int</span></span><span class="p">(</span><span class="n">noise_level</span><span class="o">*</span><span class="n">totalNum</span><span class="p">),</span><span class="n">totalNum</span><span class="p">)</span>
<span class="c1" style="font-style: italic; color: rgba(153, 153, 153, 1)"><span class="hljs-comment">#print 'DE<span class="hljs-doctag">BUG:</span> 1st elements in corruptedIdxList is: ', corruptedIdxList[0], ' length = ', len(corruptedIdxList)</span></span>
<span class="k"><span class="hljs-keyword">for</span></span> <span class="n">cIdx</span> <span class="ow"><span class="hljs-keyword">in</span></span> <span class="n">corruptedIdxList</span><span class="p">:</span>
<span class="c1" style="font-style: italic; color: rgba(153, 153, 153, 1)"><span class="hljs-comment">#print "DE<span class="hljs-doctag">BUG:</span> convert index =