目录
5. KFold, StratifiedKFold,StratifiedShuffleSplit, GroupKFold区别以及Stratified Group KFold 实现
在机器学习,一般不能直接拿整个数据集取训练,而采用cross-validation方法来训练。增强随机性减小噪声等,来减少过拟合,从而有限的数据中获取学习到更全面的信息,增强模型的泛化能力。在sklearn中,经常使用的有:KFold, StratifiedKFold,StratifiedShuffleSplit, GroupKFold。逐一解释使用区别,使用一个简单的df。(一般情况下, n_splits=5/10
)
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold,\
StratifiedShuffleSplit, GroupKFold, GroupShuffleSplit
df2 = pd.DataFrame([[6.5, 1, 2],
[8, 1, 0],
[61, 2, 1],
[54, 0, 1],
[78, 0, 1],
[119, 2, 2],
[111, 1, 2],
[23, 0, 0],
[31, 2, 0]], columns=['h', 'w', 'class'])
df2
h w class
0 6.5 1 2
1 8.0 1 0
2 61.0 2 1
3 54.0 0 1
4 78.0 0 1
5 119.0 2 2
6 111.0 1 2
7 23.0 0 0
8 31.0 2 0
1. KFold 使用
X = df2.drop(['class'], axis=1)
y = df2['class']
floder = KFold(n_splits=3, random_state=2020, shuffle=True)
for train_idx, test_idx in floder.split(X,y):
print("KFold Spliting:")
print('Train index: %s | test index: %s' % (train_idx, test_idx))
# print(X.iloc[train_idx], y.iloc[train_idx], '\n', X.iloc[test_idx], y.iloc[test_idx])
===================================================================
KFold Spliting:
Train index: [0 1 3 5 6 8] | test index: [2 4 7]
KFold Spliting:
Train index: [0 2 3 4 7 8] | test index: [1 5 6]
KFold Spliting:
Train index: [1 2 4 5 6 7] | test index: [0 3 8]
注意划分后得到的是针对数据的索引。我们现在只关注其test index,可以发现每次划分得到的索引不是按照class
对应的类别均匀划分的,如第一次[2,4,7]
对应类别是1,1,0
. 其实 train index也一样,2,0,1,2,2,0
.这在很多时候是不满足要求的,因为我们很多时候希望每次划分得到的train dataset/valid dataset
其中对应的target类别是均匀的。
有意思的是,你将 n_splits=8或9
试试,可以看到不同划分数目,得到test index数目是不一样的。如 n_splits=8
时, 第1 folds中test index size为 n_samples // n_splits + 1= 2
,其余为1。
The first
n_samples % n_splits
folds have sizen_samples // n_splits + 1
, other folds have sizen_samples // n_splits
, wheren_samples
is the number of samples. —— kfold
现在我们知道,KFold不能按照target类别来均匀划分,如果数据集必须按target类别来划分呢?那就要用到 StratifiedKFold
。
2. StratifiedKFold使用
sfolder = StratifiedKFold(n_splits=3, random_state=2020, shuffle=True)
for train_idx, test_idx in sfolder.split(X,y):
print(