Bias Variance Trade-off
Prediction Error
motivation
prediction error的来源有两种,一种是bias,另外一种是variance。他们分别对应了模型的underfitting problem和overfitting problem。当优化一个模型时,减小bias和减小variance的方法一般来说是冲突的。因此,明白模型的error主要是由哪种问题引起,可以帮助我们选择正确的优化方向。
bias
bias是模型为了更容易地解决问题,而把问题简化的程度。更详细来说,是模型的预测和对应ground truth之间的差距的均值。它代表了模型预测的平均水平。High bias problem一般是由于模型过于简单,无法学到data的pattern造成的。当然,data中的noise过多,也会造成High bias的结果。
High bias problem也叫uderfitting,当模型有High bias problem时,它在训练集和测试集上的performance一般都很差,且两者之间不会很有明显的gap。下图为一个比较典型的High bias模型训练时的loss变化:
从一开始,训练集上的loss就下降缓慢,或者是训练很久之后,训练集上的loss还持续下降。(此处暂不考虑lr带来的影响)
variance
variance是指对于某一输入,使用从同一分布采样出来的不同数据集训练的模型,这些模型的预测值的variance。当有High variance问题时,模型对训练集比较敏感:一些模型对某些samples能很精准地被预测,而另外一些模型在这些sample上的的performance可能会很差。High variance problem一般是由于模型过度学习:不仅模拟了data中有用的pattern,而且把数据集中的noise也一起学习了。这导致模型极度受到数据集中noise分布的影响,无法学习到真正的映射。
High variance problem也叫overfitting,当模型有High variance problem时,它在训练集和测试集上的performance的差距会很明显:由于过度学习,它在训练集上的表现会很好,但是在测试集上的performance会有明显的下降。下图为一个比较典型的High variance模型训练时的loss变化:
loss在训练集上持续下降,说明模型有足够的能力学习数据的pattern,但是在测试集上的loss在train到一定程度以后反而上升,说明模型开始过度学习一些无关的pattern。
comparison
下图生动地解释了high bias和high variance的不同:
derivation
假设现在需要模拟的映射为:
y
=
f
(
x
)
+
e
r
r
o
r
y=f(x)+error
y=f(x)+error
用模型 f ^ ( x ) \hat{f}(x) f^(x) 来逼近 f ( x ) f(x) f(x),在某点的期望方差为:
E r r o r ( x ) = E [ ( y − f ^ ( x ) ) 2 ] = ( f ^ ( x ) − f ( x ) ) 2 + E [ ( f ^ ( x ) − E [ f ^ ( x ) ] ) 2 ] + σ e 2 = bias 2 + V a r i a n c e + σ e 2 \begin{aligned} Error(x)&=E[(y-\hat{f}(x))^2] \\ \\ &=(\hat{f}(x)-f(x))^2+E[(\hat{f}(x)-E[\hat{f}(x)])^2]+\sigma^2_e \\ \\ &=\text{bias}^2+Variance+\sigma^2_e \end{aligned} Error(x)=E[(y−f^(x))2]=(f^(x)−f(x))2+E[(f^(x)−E[f^(x)])2]+σe2=bias2+Variance+σe2
Analysis
如果模型在测试集上的表现不好,如何判断模型是哪一种问题呢?这会决定之后我们的改进方向。
loss curve
上一节提到的,通过观察loss curve在训练集和测试集上的分别的走势,可以区分模型是否存在underfitting或者overfitting的问题。
learning curve
也可以先固定模型的规模,只改变数据集大小,来观察模型performance随着数据集规模的不同,在训练集和测试集上分别的变化。
1. underfitting
当数据集规模增大时,数据的复杂度也在增大,模型越难准确模拟,因此在训练集上的error会增加。而且由于问题复杂度在前期迅速增加,error上升的速度会较快。但是由于模型能学习到更多有用的pattern,因此在测试集上的error会降低。
若模型存在underfitting的问题,当数据集增大到一定程度时,因为模型的能力到达瓶颈,它无法再学习新的pattern,致使error的变化趋于平缓。当然由于variance问题不严重,在测试集和训练集上的表现相差不大。
2. overfitting
模型在训练集上和测试集上的performance差距明显。因为模型表达力够强,增加数据集规模不会像在简单模型上一样,引起error的骤升。
Solution
general methods
平衡bias和variance以总体减小prediction error:
reduce bias
- 增加模型复杂度
- 检查数据集是否有太多噪音
reduce variance
- 使用更简单的模型:减少参数数目,正则化,dropout
- 增加数据规模
- early stop
个人理解:
underfitting和overfitting问题其实是task的难度和模型复杂度不匹配引发的问题:
- 对于underfitting来说,task太复杂,模型的学习能力不够,因此解决问题的关键是提高模型的复杂度。
- 而对于overfitting的模型来说,task或者说数据集中underlying的关系比较简单,在足够的训练之后,模型有能力“记住”所有samples,从而放弃了学习samples共同的feature。此时的解决方法着重在削减模型的能力,迫使它去学习underlying的特征。增加数据集这个方法,可以潜在地增加数据集的复杂度,给模型更多数据个维度之间的依赖关系的信息,从而提升了task的难度。Early stop是一个常用的训练小tip,模型即使表达力再强,不经过足够的update也无法完全“记住“samples。因此,不过度训练模型也可以限制模型的表达能力,让模型学到一些samples的共同特征就及时停止。
Reference
https://towardsdatascience.com/understanding-the-bias-variance-tradeoff-165e6942b229
https://machinelearningmastery.com/learning-curves-for-diagnosing-machine-learning-model-performance/
Machine Learning (coursera)