机器学习好伙伴之scikit-learn的使用——学习曲线

机器学习好伙伴之scikit-learn的使用——学习曲线

什么是学习曲线呢,其内容主要包含当训练量增加时,loss的变化情况。
在这里插入图片描述

什么是学习曲线

学习曲线主要反应的是学习的一个过程,常用的表示方法是训练集的loss和测试集的loss与训练量之间的关系。其示意图如下:
在这里插入图片描述

sklearn中学习曲线的实现

在进行学习曲线的绘制之前,首先要导入学习曲线的绘制的模块。

from sklearn.model_selection import learning_curve

学习曲线的绘制的重要函数是:

learning_curve(
	estimator, 
	X, y, 
	train_sizes=array([0.1, 0.325, 0.55, 0.775, 1. ]), 
	cv=None, 
	scoring=None, 
	exploit_incremental_learning=False, 
	n_jobs=1, 
	pre_dispatch='all', 
	verbose=0
)

其常用参数如下:
1、estimator:用于预测的模型
2、X:预测的特征数据
3、y:预测结果
4、train_sizes:训练样本相对的或绝对的数字,这些量的样本将会生成learning curve,当其为[0.1, 0.325, 0.55, 0.775, 1. ]时代表使用10%训练集训练,32.5%训练集训练,55%训练集训练,77.5%训练集训练100%训练集训练时的分数。
5、cv:交叉验证生成器或可迭代的次数
6、scoring:调用的方法
可进行的scoring方式具体可以查阅
https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
在这里插入图片描述
使用方式如下:

train_sizes, train_loss, test_loss = learning_curve(
    SVC(gamma=0.01), X, y, cv=10, scoring='neg_mean_squared_error',
    train_sizes=np.linspace(.1, 1.0, 5))

代表使用SVM的分类模型,输入特征为X,输出label为y,进行10折交叉验证,通过均值平方差的方式计分,学习曲线分为5段。
其一共具有3个返回值,分别是train_sizes, train_loss, test_loss,其中train_loss指的是训练集的loss,其shape为(5,10),第n行对应学习曲线的第n段,第n行的内容代表着第n段的10折交叉验证的结果;test_loss的含义与train_loss类似,其对应的是测试集的loss。

应用示例

代码源自莫烦python教学网站

# 学习曲线模块
from sklearn.model_selection import learning_curve 
# 导入digits数据集
from sklearn.datasets import load_digits 
# 支持向量机
from sklearn.svm import SVC 
import matplotlib.pyplot as plt
import numpy as np

digits = load_digits()
X = digits.data
y = digits.target

# neg_mean_squared_error代表求均值平方差
train_sizes, train_loss, test_loss = learning_curve(
    SVC(gamma=0.01), X, y, cv=10, scoring='neg_mean_squared_error',
    train_sizes=np.linspace(.1, 1.0, 5))

# loss值为负数,需要取反
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

# 设置样式与label
plt.plot(train_sizes, train_loss_mean, 'o-', color="r",
         label="Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",
        label="Cross-validation")

plt.xlabel("Training examples")
plt.ylabel("Loss")
# 显示图例
plt.legend(loc="best")
plt.show()

实验结果为:
在这里插入图片描述
如上图所示的训练结果存在过拟合的现象。
调整GAMMA = 0.001后过拟合现象消失,Cross-validation不再上升。
在这里插入图片描述

  • 2
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Bubbliiiing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值