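How KFold splits the dataset: it cuts the samples into n_splits folds purely by their order (after an optional shuffle); the labels are not used.
How StratifiedKFold splits the dataset: it assigns samples so that the class distribution in each training set and validation set stays as close as possible to that of the original dataset.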
1. KFold
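KFold takes three parameters:
n_splits: the number of folds to split the data into, i.e. the k in k-fold cross-validation. The default is 5 in scikit-learn 0.22 and later (it was 3 in earlier versions);
shuffle: default False. Whether to shuffle the samples before splitting (many scikit-learn functions have a parameter like this). If True, the data is shuffled first and then split; if False, the folds are taken in the original order;
random_state: default None. The seed of the random number generator; it only has an effect when shuffle is True (scikit-learn 0.24 and later raise a ValueError if it is set while shuffle is False).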
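To see what shuffle and random_state change, here is a minimal sketch on 10 toy samples (the array `X` and the helper `val_folds` are illustrative, not from the original example). Without shuffling the validation folds are consecutive index blocks; with shuffle=True and a fixed seed the folds differ from that order but are reproducible.

```python
import numpy as np
from sklearn.model_selection import KFold

X = np.arange(20).reshape(10, 2)  # 10 toy samples, 2 features each (illustrative)

def val_folds(kf):
    """Collect the validation indices of every fold as plain lists."""
    return [val_idx.tolist() for _, val_idx in kf.split(X)]

# shuffle=False: each validation fold is the next consecutive block of indices
ordered = val_folds(KFold(n_splits=5))
# shuffle=True with the same seed twice: the split is shuffled but reproducible
seeded_a = val_folds(KFold(n_splits=5, shuffle=True, random_state=42))
seeded_b = val_folds(KFold(n_splits=5, shuffle=True, random_state=42))

print(ordered)               # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
print(seeded_a)              # shuffled folds, fixed by random_state=42
print(seeded_a == seeded_b)  # True
```

Either way, the validation folds together cover every sample exactly once; shuffling only changes which samples end up in which fold.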
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
import numpy as np
X = np.array([[10, 1], [20, 2], [30, 3], [40, 4], [50,5], [60,6], [70,7],[80,8],[90,9],[100,10],
[110, 1], [120, 2], [130, 3], [140, 4], [150,5], [160,6], [170,7],[180,8],[190,9],[200,10]])
# Five classes, ratio 1:1:1:1:1 (4 samples each)
Y1 = np.array([1,1,2,3,3,2,4,4,5,5,1,1,2,3,3,2,4,4,5,5])
# Two classes, ratio 2:3 (8 samples of class 1, 12 of class 2)
Y2 = np.array([1,1,1,1,2,2,2,2,2,2,1,1,1,1,2,2,2,2,2,2])
kfolds = KFold(n_splits=5, shuffle=False)
# Note: split() returns index arrays, not the samples themselves
for (trn_idx, val_idx) in kfolds.split(X):  # no labels needed
    print((trn_idx, val_idx))
    print((len(trn_idx), len(val_idx)))
Output:
(array([ 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), array([0, 1, 2, 3]))
(16, 4)
(array([ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), array([4, 5, 6, 7]))
(16, 4)
(array([ 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 16, 17, 18, 19]), array([ 8, 9, 10, 11]))
(16, 4)
(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19]), array([12, 13, 14, 15]))
(16, 4)
(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), array([16, 17, 18, 19]))
(16, 4)
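The output shows that, without shuffling, each validation fold is simply the next block of consecutive indices and the training set is everything else. A minimal numpy sketch of that rule follows; `contiguous_kfold` is a hypothetical helper, not part of scikit-learn. For sample counts that do not divide evenly it assumes the rule from the KFold documentation (the first n_samples % n_splits folds get one extra sample), and the loop checks the sketch against KFold itself.

```python
import numpy as np
from sklearn.model_selection import KFold

def contiguous_kfold(n_samples, n_splits):
    """Hypothetical helper: yield (train_idx, val_idx) like unshuffled KFold.

    Every fold is a consecutive block; the first n_samples % n_splits folds
    get one extra sample when the division is uneven.
    """
    indices = np.arange(n_samples)
    fold_sizes = np.full(n_splits, n_samples // n_splits, dtype=int)
    fold_sizes[: n_samples % n_splits] += 1
    start = 0
    for size in fold_sizes:
        stop = start + size
        val_idx = indices[start:stop]
        trn_idx = np.concatenate([indices[:start], indices[stop:]])
        yield trn_idx, val_idx
        start = stop

# 22 samples do not divide evenly into 5 folds, which exercises the uneven case
X = np.zeros((22, 2))
pairs = zip(contiguous_kfold(22, 5), KFold(n_splits=5).split(X))
for (a_trn, a_val), (b_trn, b_val) in pairs:
    assert np.array_equal(a_trn, b_trn) and np.array_equal(a_val, b_val)
    print(len(a_trn), a_val.tolist())
```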
2. StratifiedKFold
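StratifiedKFold takes the same parameters as KFold. The difference is that its split() method also needs the labels y, because it uses them to stratify the folds.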
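Because the parameters are the same, shuffle and random_state work here too. A minimal sketch, reusing the Y2 labels from this article (the feature array `X` is a placeholder): even with shuffle=True, each validation fold keeps the 2:3 class ratio, since the samples are shuffled within each class rather than across the whole dataset.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X = np.arange(40).reshape(20, 2)  # placeholder features for 20 samples
# Same labels as Y2: class 1 has 8 samples, class 2 has 12 (ratio 2:3)
Y2 = np.array([1,1,1,1,2,2,2,2,2,2,1,1,1,1,2,2,2,2,2,2])

skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
for trn_idx, val_idx in skf.split(X, Y2):  # y is required for stratification
    # Count the samples of class 1 and class 2 in this validation fold
    counts = [int(np.sum(Y2[val_idx] == c)) for c in (1, 2)]
    print(val_idx, counts)
```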
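# StratifiedKFold: the class proportions in each training and validation set stay as close as possible to those of the original dataset
# Split (X, Y1)
# Y1 has 5 classes in the ratio 1:1:1:1:1
# To keep that ratio exactly, each validation fold holds x samples of every class, i.e. 1*x+1*x+1*x+1*x+1*x = 5x samples
# (here 20 samples / 2 folds = 10 = 5*2, so x = 2)
stratifiedKFolds = StratifiedKFold(n_splits=2, shuffle=False)
for (trn_idx, val_idx) in stratifiedKFolds.split(X, Y1):
    print((trn_idx, val_idx))
    print((len(trn_idx), len(val_idx)))
print('################################################')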
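# Split (X, Y2)
# Y2 has 2 classes in the ratio 2:3
# So each validation fold holds 2x + 3x = 5x samples
# (here 20 samples / 4 folds = 5, so x = 1: 2 samples of class 1 and 3 of class 2)
stratifiedKFolds = StratifiedKFold(n_splits=4, shuffle=False)
for (trn_idx, val_idx) in stratifiedKFolds.split(X, Y2):
    print((trn_idx, val_idx))
    print((len(trn_idx), len(val_idx)))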
Output:
(array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
(10, 10)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19]))
(10, 10)
################################################
(array([ 2, 3, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), array([0, 1, 4, 5, 6]))
(15, 5)
(array([ 0, 1, 4, 5, 6, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), array([2, 3, 7, 8, 9]))
(15, 5)
(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 17, 18, 19]), array([10, 11, 14, 15, 16]))
(15, 5)
(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]), array([12, 13, 17, 18, 19]))
(15, 5)
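To make the difference between the two splitters concrete, the sketch below counts the classes in every validation fold for KFold and StratifiedKFold on the same Y2 labels with n_splits=4. `class_counts` is a hypothetical helper, and `X` is a placeholder feature array; KFold.split accepts y but ignores it.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.arange(40).reshape(20, 2)  # placeholder features for 20 samples
Y2 = np.array([1,1,1,1,2,2,2,2,2,2,1,1,1,1,2,2,2,2,2,2])

def class_counts(splitter, X, y):
    """Hypothetical helper: per fold, count class-1 and class-2 samples in the validation set."""
    return [[int(np.sum(y[val_idx] == c)) for c in (1, 2)]
            for _, val_idx in splitter.split(X, y)]

print("KFold:          ", class_counts(KFold(n_splits=4), X, Y2))
print("StratifiedKFold:", class_counts(StratifiedKFold(n_splits=4), X, Y2))
```

With this label order, unshuffled KFold produces validation folds with class counts [4, 1], [0, 5], [4, 1], [0, 5], so two folds contain no class-1 samples at all, while StratifiedKFold gives [2, 3] in every fold, matching the 2:3 ratio of the whole dataset.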