python中sklearn实现交叉验证

http://blog.csdn.net/ztchun/article/details/71169530


1、概述

在实验数据分析中,有些算法需要用现有的数据构建模型,如卷积神经网络(CNN),这类算法称为监督学习(Supervisied Learning)。构建模型需要的数据称为训练数据。

模型构建完后,需要利用数据验证模型的正确性,这部分数据称为测试数据。测试数据不能用于构建模型中,只能用于最后检验模型的准确性。

有时候模型的构建的过程中,也需要检验模型,辅助模型构建。所以会将训练数据分为两个部分,1)训练数据;2)验证数据。 
将数据分类就要采用交叉验证的方法,个人写的交叉验证算法不可避免有一定缺陷,考虑使用强大sklearn包可以实现交叉验证算法。

2、python实现

请注意:以下的方法实现根据最新的sklearn版本实现,老版本的函数很多已经过期。

2.1 K次交叉检验(K-Fold Cross Validation)

K次交叉检验的大致思想是将数据大致分为K个子样本,每次取一个样本作为验证数据,取余下的K-1个样本作为训练数据。

from sklearn.model_selection import KFold
import numpy as np
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([1, 2, 3, 4])
kf = KFold(n_splits=2)

for train_index, test_index in kf.split(X):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

2.2 Stratified k-fold

StratifiedKFold()这个函数较常用,比KFold的优势在于将k折数据按照百分比划分数据集,每个类别百分比在训练集和测试集中都是一样,这样能保证不会有某个类别的数据在训练集中而测试集中没有这种情况,同样不会在训练集中没有全在测试集中,这样会导致结果糟糕透顶。

from sklearn.model_selection import StratifiedKFold
import numpy as np

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 1, 1])
skf = StratifiedKFold(n_splits=2)
for train_index, test_index in skf.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

2.3 train_test_split

随机根据比例分配训练集和测试集。这个函数可以调整随机种子。

import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(10).reshape((5, 2)), range(5)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

3、总结

sklearn功能非常强大提供的交叉验证函数也非常多,如:

           'BaseCrossValidator',
           'GridSearchCV',
           'TimeSeriesSplit',
           'KFold',
           'GroupKFold',
           'GroupShuffleSplit',
           'LeaveOneGroupOut',
           'LeaveOneOut',
           'LeavePGroupsOut',
           'LeavePOut',
           'ParameterGrid',
           'ParameterSampler',
           'PredefinedSplit',
           'RandomizedSearchCV',
           'ShuffleSplit',
           'StratifiedKFold',
           'StratifiedShuffleSplit',
           'check_cv',
           'cross_val_predict',
           'cross_val_score',
           'fit_grid_point',
           'learning_curve',
           'permutation_test_score',
           'train_test_split',
           'validation_curve'
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

感兴趣的可以查看sklearn源码。


  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值