sklearn.model_selection.KFold

本文介绍了sklearn.model_selection.KFold的使用,讲解了如何进行K折交叉验证,重点讨论了n_splits参数、shuffle参数和random_state的影响。通过实例展示了当数据集不能均等划分时,设置不同的参数如何影响结果,并演示了如何获取n_splits的值。
摘要由CSDN通过智能技术生成

K折交叉验证:sklearn.model_selection.KFold(n_splits=3, shuffle=False, random_state=None)

思路:将训练/测试数据集划分n_splits个互斥子集,每次用其中一个子集当作验证集,剩下的n_splits-1个作为训练集,进行n_splits次训练和测试,得到n_splits个结果

注意点:对于不能均等份的数据集,其前n_samples % n_splits子集拥有n_samples // n_splits + 1个样本,其余子集都只有n_samples // n_splits样本

参数说明:

n_splits:表示划分几等份

shuffle:在每次划分时,是否进行洗牌

①若为Falses时,其效果等同于random_state等于整数,每次划分的结果相同

②若为True时,每次划分的结果都不一样,表示经过洗牌,随机取样的

random_state:随机种子数

属性:

①get_n_splits(X=None, y=None, groups=None):获取参数n_splits的值

②split(X, y=None, groups=None):将数据集划分成训练集和测试集,返回索引生成器

通过一个不能均等划分的栗子,设置不同参数值,观察其结果

①设置shuffle=False,运行两次,发现两次结果相同

[python]  view plain  copy
  1. In [1]: from sklearn.model_selection import KFold  
  2.    ...: import numpy as np  
  3.    ...: X = np.arange(24).reshape(12,2)  
  4.    ...: y = np.random.choice([1,2],12,p=[0.4,0.6])  
  5.    ...: kf = KFold(n_splits=5,shuffle=False)  
  6.    ...: for train_index , test_index in kf.split(X):  
  7.    ...:     print('train_index:%s , test_index: %s ' %(train_index,test_index))  
  8.    ...:  
  9.    ...:  
  10. train_index:[ 3  4  5  6  7  8  9 10 11] , test_index: [0 1 2]  
  11. train_index:[ 0  1  2  6  7  8  9 10 11] , test_index: [3 4 5]  
  12. train_index:[ 0  1  2  3  4  5  8  9 10 11] , test_index: [6 7]  
  13. train_index:[ 0  1  2  3  4  5  6  7 10 11] , test_index: [8 9]  
  14. train_index:[0 1 2 3 4 5 6 7 8 9] , test_index: [10 11]  
  15.   
  16. In [2]: from sklearn.model_selection import KFold  
  17.    ...: import numpy as np  
  18.    ...: X = np.arange(24).reshape(12,2)  
  19.    ...: y = np.random.choice([1,2],12,p=[0.4,0.6])  
  20.    ...: kf = KFold(n_splits=5,shuffle=False)  
  21.    ...: for train_index , test_index in kf.split(X):  
  22.    ...:     print('train_index:%s , test_index: %s ' %(train_index,test_index))  
  23.    ...:  
  24.    ...:  
  25. train_index:[ 3  4  5  6  7  8  9 10 11] , test_index: [0 1 2]  
  26. train_index:[ 0  1  2  6  7  8  9 10 11] , test_index: [3 4 5]  
  27. train_index:[ 0  1  2  3  4  5  8  9 10 11] , test_index: [6 7]  
  28. train_index:[ 0  1  2  3  4  5  6  7 10 11] , test_index: [8 9]  
  29. train_index:[0 1 2 3 4 5 6 7 8 9] , test_index: [10 11]  
②设置shuffle=True时,运行两次,发现两次运行的结果不同

[python]  view plain  copy
  1. In [3]: from sklearn.model_selection import KFold  
  2.    ...: import numpy as np  
  3.    ...: X = np.arange(24).reshape(12,2)  
  4.    ...: y = np.random.choice([1,2],12,p=[0.4,0.6])  
  5.    ...: kf = KFold(n_splits=5,shuffle=True)  
  6.    ...: for train_index , test_index in kf.split(X):  
  7.    ...:     print('train_index:%s , test_index: %s ' %(train_index,test_index))  
  8.    ...:  
  9.    ...:  
  10. train_index:[ 0  1  2  4  5  6  7  8 10] , test_index: [ 3  9 11]  
  11. train_index:[ 0  1  2  3  4  5  9 10 11] , test_index: [6 7 8]  
  12. train_index:[ 2  3  4  5  6  7  8  9 10 11] , test_index: [0 1]  
  13. train_index:[ 0  1  3  4  5  6  7  8  9 11] , test_index: [ 2 10]  
  14. train_index:[ 0  1  2  3  6  7  8  9 10 11] , test_index: [4 5]  
  15.   
  16. In [4]: from sklearn.model_selection import KFold  
  17.    ...: import numpy as np  
  18.    ...: X = np.arange(24).reshape(12,2)  
  19.    ...: y = np.random.choice([1,2],12,p=[0.4,0.6])  
  20.    ...: kf = KFold(n_splits=5,shuffle=True)  
  21.    ...: for train_index , test_index in kf.split(X):  
  22.    ...:     print('train_index:%s , test_index: %s ' %(train_index,test_index))  
  23.    ...:  
  24.    ...:  
  25. train_index:[ 0  1  2  3  4  5  7  8 11] , test_index: [ 6  9 10]  
  26. train_index:[ 2  3  4  5  6  8  9 10 11] , test_index: [0 1 7]  
  27. train_index:[ 0  1  3  5  6  7  8  9 10 11] , test_index: [2 4]  
  28. train_index:[ 0  1  2  3  4  6  7  9 10 11] , test_index: [5 8]  
  29. train_index:[ 0  1  2  4  5  6  7  8  9 10] , test_index: [ 3 11]  
③设置shuffle=True和random_state=整数,发现每次运行的结果都相同

[python]  view plain  copy
  1. In [5]: from sklearn.model_selection import KFold  
  2.    ...: import numpy as np  
  3.    ...: X = np.arange(24).reshape(12,2)  
  4.    ...: y = np.random.choice([1,2],12,p=[0.4,0.6])  
  5.    ...: kf = KFold(n_splits=5,shuffle=True,random_state=0)  
  6.    ...: for train_index , test_index in kf.split(X):  
  7.    ...:     print('train_index:%s , test_index: %s ' %(train_index,test_index))  
  8.    ...:  
  9.    ...:  
  10. train_index:[ 0  1  2  3  5  7  8  9 10] , test_index: [ 4  6 11]  
  11. train_index:[ 0  1  3  4  5  6  7  9 11] , test_index: [ 2  8 10]  
  12. train_index:[ 0  2  3  4  5  6  8  9 10 11] , test_index: [1 7]  
  13. train_index:[ 0  1  2  4  5  6  7  8 10 11] , test_index: [3 9]  
  14. train_index:[ 1  2  3  4  6  7  8  9 10 11] , test_index: [0 5]  
  15.   
  16. In [6]: from sklearn.model_selection import KFold  
  17.    ...: import numpy as np  
  18.    ...: X = np.arange(24).reshape(12,2)  
  19.    ...: y = np.random.choice([1,2],12,p=[0.4,0.6])  
  20.    ...: kf = KFold(n_splits=5,shuffle=True,random_state=0)  
  21.    ...: for train_index , test_index in kf.split(X):  
  22.    ...:     print('train_index:%s , test_index: %s ' %(train_index,test_index))  
  23.    ...:  
  24.    ...:  
  25. train_index:[ 0  1  2  3  5  7  8  9 10] , test_index: [ 4  6 11]  
  26. train_index:[ 0  1  3  4  5  6  7  9 11] , test_index: [ 2  8 10]  
  27. train_index:[ 0  2  3  4  5  6  8  9 10 11] , test_index: [1 7]  
  28. train_index:[ 0  1  2  4  5  6  7  8 10 11] , test_index: [3 9]  
  29. train_index:[ 1  2  3  4  6  7  8  9 10 11] , test_index: [0 5]  

4.用enumerate,可以在输出每份的同时,输出每份的索引

from sklearn.model_selection import StratifiedKFold

X = np.ones(10)
y = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
kfold = StratifiedKFold(n_splits=NFOLDS, shuffle=True, random_state=218)
kf = kfold.split(X, y)
for i, (train_fold, validate) in enumerate(kf):
    print(i,train_fold,validate)

0 [0 1 3 4 5 7 9] [2 6 8]
1 [0 1 2 3 4 6 7 8 9] [5]
2 [0 2 3 5 6 7 8 9] [1 4]
3 [1 2 3 4 5 6 8 9] [0 7]
4 [0 1 2 4 5 6 7 8] [3 9]


5.n_splits属性值获取方式

[python]  view plain  copy
  1. In [8]: kf.split(X)  
  2. Out[8]: <generator object _BaseKFold.split at 0x00000000047FF990>  
  3.   
  4. In [9]: kf.get_n_splits()  
  5. Out[9]: 5  
  6.   
  7. In [10]: kf.n_splits  
  8. Out[10]: 5  

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值