拟合学习曲线的纵坐标表示什么_用学习曲线 learning curve 来判别过拟合问题

[导读]学习曲线就是通过画出不同训练集大小时训练集和交叉验证的准确率,可以看到模型在新数据上的表现,进而来判断模型是否方差偏高或偏差过高,以及增大训练集是否可以减小过拟合。

学习曲线是什么?

学习曲线就是通过画出不同训练集大小时训练集和交叉验证的准确率,可以看到模型在新数据上的表现,进而来判断模型是否方差偏高或偏差过高,以及增大训练集是否可以减小过拟合。

怎么解读?

当训练集和测试集的误差收敛但却很高时,为高偏差。

左上角的偏差很高,训练集和验证集的准确率都很低,很可能是欠拟合。

我们可以增加模型参数,比如,构建更多的特征,减小正则项。

此时通过增加数据量是不起作用的。

当训练集和测试集的误差之间有大的差距时,为高方差。

当训练集的准确率比其他独立数据集上的测试结果的准确率要高时,一般都是过拟合。

右上角方差很高,训练集和验证集的准确率相差太多,应该是过拟合。

我们可以增大训练集,降低模型复杂度,增大正则项,或者通过特征选择减少特征数。

理想情况是是找到偏差和方差都很小的情况,即收敛且误差较小。

怎么画?

在画学习曲线时,横轴为训练样本的数量,纵轴为准确率。

例如同样的问题,左图为我们用 naive Bayes 分类器时,效果不太好,分数大约收敛在 0.85,此时增加数据对效果没有帮助。

右图为 SVM(RBF kernel),训练集的准确率很高,验证集的也随着数据量增加而增加,不过因为训练集的还是高于验证集的,有点过拟合,所以还是需要增加数据量,这时增加数据会对效果有帮助。

上图的代码如下:

模型这里用 GaussianNB 和 SVC 做比较,

模型选择方法中需要用到 learning_curve 和交叉验证方法 ShuffleSplit。import numpy as npimport matplotlib.pyplot as pltfrom sklearn.naive_bayes import GaussianNBfrom sklearn.svm import SVCfrom sklearn.datasets import load_digitsfrom sklearn.model_selection import learning_curvefrom sklearn.model_selection import ShuffleSplit

首先定义画出学习曲线的方法,

核心就是调用了  sklearn.model_selection 的 learning_curve,

学习曲线返回的是 train_sizes, train_scores, test_scores,

画训练集的曲线时,横轴为 train_sizes, 纵轴为 train_scores_mean,

画测试集的曲线时,横轴为 train_sizes, 纵轴为 test_scores_mean:def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,

n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):~~~

train_sizes, train_scores, test_scores = learning_curve(

estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)

train_scores_mean = np.mean(train_scores, axis=1)

test_scores_mean = np.mean(test_scores, axis=1)

~~~

在调用 plot_learning_curve 时,首先定义交叉验证 cv 和学习模型 estimator。

这里交叉验证用的是 ShuffleSplit, 它首先将样例打散,并随机取 20% 的数据作为测试集,这样取出 100 次,最后返回的是 train_index, test_index,就知道哪些数据是 train,哪些数据是 test。

estimator 用的是 GaussianNB,对应左图:cv = ShuffleSplit(n_splits=100, test_size=0.2, random_state=0)estimator = GaussianNB()

plot_learning_curve(estimator, title, X, y, ylim=(0.7, 1.01), cv=cv, n_jobs=4)

再看 estimator 是 SVC 的时候,对应右图:cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)estimator = SVC(gamma=0.001)

plot_learning_curve(estimator, title, X, y, (0.7, 1.01), cv=cv, n_jobs=4)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值