最近关注深度学习泛化能力评估与解释的问题,重读文章understanding deep learning requires rethinking generalization,该文是ICLR2017的best paper之一,到目前为止引用量已达2000多,分享一点阅读笔记,方便无暇读文的童鞋参考。
基本概念
泛化能力
- 定义:从已知数据学习的模型对未知数据的预测能力。
- 泛化误差:模型对已知数据的预测能力和对未知数据的预测能力之间的差距,常用训练集和测试集预测能力之间的差距表示。
- 模型预测能力度量指标:auc/roc/f1-score/accuracy/ks等等。
- 下面是随着训练样本规模的增大,训练集和验证集之间的预测准确性gap变小,泛化误差降低,泛化能力提高。
过拟合/欠拟合
- 定义:模型在训练集上预测能力极强,在训练集以外的其他数据集上很差,表示过拟合;模型在训练集上表现较差,表示欠拟合。
- 原因:模型过于复杂,训练样本噪声多、数据量较少等会造成过拟合,如果模型过于简单不足以表达特征到标签的映射函数,或者训练过程提前终止,会导致欠拟合。
- 下图表示三种拟合状态
泛化能力与过拟合/欠拟合
通常来说,过拟合和欠拟合都会导致模型泛化能力差。前者是从偏差和方差的角度评价,后者是从模型迁移到未知测试集上的预测能力gap角度评价。
问题|实验|观点
围绕问题“深度模型在复杂度较高的情况下,有时也会有较好的泛化能力,这个现象与传统的机器学习经验相悖”,进行了一系列实验,并提出了一些观点。
神经网络的表达能力
- 具有一定复杂度的神经网络可拟合白噪声和随机标签,且能够快速收敛(不比拟合有信号数据花费时间更长)。
- 数据中噪声的比例会影响模型的泛化能力和达到过拟合所需时间,噪声比例越高,泛化能力越差,达到过拟合所需时间越长。
- 在同时包含噪声和信号的数据中,模型会率先抓住“信号”,快速收敛,然后对噪声进行暴力拟合,达到过拟合状态,噪声的比例影响达到过拟合的时间;
- 神经网络的参数个数远高于训练样本个数时,模型也可能有较好的泛化能力,泛化误差较小。
- 实验数据如下
正则化对模型泛化能力的影响
- 在神经网络中添加正则项,并不一定能够降低模型的泛化误差,对正则项进行调参优化,往往可以提升测试集的预测准确率。
- 在正则化的诸多方法中,增加weight decay对泛化能力影响较小,但是扩大数据集,提升数据集的丰富度则对泛化能力有较大提升。
- 随机梯度下降算法在通向收敛的过程中,倾向于向”最小正则化“的方向发展,这种隐性正则化因素是导致复杂深度模型泛化能力强的原因之一,这个是论文的一个重要观点。
- 实验对比
有限样本上的表达能力
这部分主要是简化模型结构,来证明SGD中的隐性正则化因素对模型泛化能力的影响。先说明对特定规模的数据,什么规模的模型有完全表达能力,然后简化到线性模型,在线性模型下推理说明SGD在收敛过程中起到的隐性正则化作用。
- 复杂度为O(n)的双层神经网络具有对有限数据集的完全表达能力,其中n是样本的规模,定理描述如下,证明部分可自读论文。
- 模型泛化能力的上限与模型的规模无关,经验上来说,丰富数据可以提升模型的泛化能力。
- SGD中的隐性正则化项越小,模型的泛化能力并不一定越高。
结论
本文首先通过几组实验证明神经网络AlexNet, Inception 以及 ImageNet等具有极强的表达能力,甚至可以拟合噪声和随机标签数据,在这些模型的参数数量超过训练样本的数量时,仍然具有良好的泛化能力,这就对了传统的机器学习经验提出了挑战(模型复杂度过高,往往导致过拟合,从而导致泛化误差高);进而探讨传统的正则化方法如weight decay, dropout 以及data augment等显性正则化方法对模型泛化能力的影响,实验发现,weight decay和dropout对模型的泛化能力影响较小,data augment经验上可以大幅降低泛化误差,在这个结论的基础上,提出深度复杂网络的泛化能力是由什么因素导致的问题,猜想SGD在通向收敛的过程中,是不是存在一些影响模型泛化能力的因素?
在探讨这个问题前,首先提出了对于有限样本来说,线性规模的神经网络对其具有完全表达能力,然后简化到线性模型,理论推导线性模型在SGD收敛过程中,总是向着正则化最小的方向收敛,因此得出SGD收敛过程中隐性正则化特质可能是造成复杂神经网络泛化能力强的原因之一。
一点思考
- 从文中的对比实验可以看出,在实际应用中,对于有限的数据集,不需要涉及过于复杂的网络结构(只需相对于样本集先行规模),就可以完全表达数据集的标签映射函数。
- 传统的正则化方法并不一定能够降低泛化误差,但是经验上,正则化项经过调优后,往往可以提升测试集的预测准确性。
- 提升训练集的数据丰富度可以大幅度提升模型的泛化能力。
- 有必要研究一些泛化能力强的模型的结构,理解导致泛化能力强的原因,从而在设计模型时应用之。