机器学习好伙伴之scikit-learn的使用——验证曲线

机器学习好伙伴之scikit-learn的使用——验证曲线

什么是验证曲线呢,其内容主要包含当超参数变化时,loss的变化情况。
在这里插入图片描述

什么是验证曲线

验证曲线主要反应的是当超参数变化时,模型的训练状况,常用的表示方法是训练集的loss和测试集的loss与超参数之间的关系,其作用是可以帮助我们选择合适的超参数。其示意图如下:
在这里插入图片描述

sklearn中验证曲线的实现

在进行验证曲线的绘制之前,首先要导入验证曲线的绘制的模块。

from sklearn.model_selection import validation_curve 

验证曲线的绘制的重要函数是:

validation_curve(
	estimator, 
	X, y, 
	param_name, 
	param_range, 
	groups=None, 
	cv=’warn’, 
	scoring=None, 
	n_jobs=None, 
	pre_dispatch=all, 
	verbose=0, 
	error_score=raise-deprecating’
)

其常用参数如下:
1、estimator:用于预测的模型
2、X:预测的特征数据
3、y:预测结果
4、param_name:超参数的名称
5、param_range:超参数的取值范围
6、cv:交叉验证生成器或可迭代的次数
7、scoring:调用的方法
可进行的scoring方式具体可以查阅
https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
在这里插入图片描述
使用方式如下:

# 从1e-6到1e-2次方,分五段
param_range = np.logspace(-6, -2, 5)

#使用validation_curve快速找出参数对模型的影响
train_loss, test_loss = validation_curve(
    SVC(), X, y, param_name='gamma', 
    param_range=param_range, cv=10, scoring='neg_mean_squared_error')

代表使用SVM的分类模型,输入特征为X,输出label为y,进行10折交叉验证,通过均值平方差的方式计分,学习曲线分为5段,进行绘制的超参数是gamma,选取的范围是param_range。
其一共具有2个返回值,分别是train_loss, test_loss,其中train_loss指的是训练集的loss,其shape为(5,10),第n行对应学习曲线的第n段,第n行的内容代表着第n段的10折交叉验证的结果;test_loss的含义与train_loss类似,其对应的是测试集的loss。

应用示例

代码源自莫烦python教学网站

# 验证曲线模块
from sklearn.model_selection import validation_curve 
# 导入digits数据集
from sklearn.datasets import load_digits 
# 支持向量机
from sklearn.svm import SVC 
import matplotlib.pyplot as plt
import numpy as np

digits = load_digits()
X = digits.data
y = digits.target

# 建立参数测试集
# 从1e-6到1e-2次方,分五段
param_range = np.logspace(-6, -2, 5)

#使用validation_curve快速找出参数对模型的影响
train_loss, test_loss = validation_curve(
    SVC(), X, y, param_name='gamma', param_range=param_range, cv=10, scoring='neg_mean_squared_error')
    
# loss值为负数,需要取反
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

# 设置样式与label
plt.plot(param_range, train_loss_mean, 'o-', color="r",
         label="Training")
plt.plot(param_range, test_loss_mean, 'o-', color="g",
        label="Cross-validation")

plt.xlabel("Training examples")
plt.ylabel("Loss")
# 显示图例
plt.legend(loc="best")
plt.show()

实验结果为:
在这里插入图片描述

  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Bubbliiiing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值