1. 概览
KFold和StratifiedKFold的作用都是用于配合交叉验证的需求,将数据分割成训练集和测试集。
2. 区别
KFold随机分割数据,不会考虑数据的分布情况。
StratifiedKFold会根据原始数据的分布情况,分割出同分布的数据。
3. 实验
3.1 代码
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
import numpy as np
X = np.array([
[11,12,13,14],
[21,22,23,24],
[31,32,33,34],
[41,42,43,44],
[51,52,53,54],
[61,62,63,64],
[71,72,73,74],
[81,82,83,84]
])
y = np.array([1,1,0,0,0,1,1,0])
KFold = KFold(n_splits=2, shuffle=True)
sKFold = StratifiedKFold(n_splits=2, shuffle=True)
print("===KFold split===")
for train, test in KFold.split(X, y):
print('train index:\n',train)
print('train X value:\n',X[train])
print('train y value:\n',y[train])
print('test index:\n',test)
print('test X value:\n',X[test])
print('test y value:\n',y[test])
print()
print("===sKFold split===")
for train, test in sKFold.split(X, y):
print('train index:\n',train)
print('train X value:\n',X[train])
print('train y value:\n',y[train])
print('test index:\n',test)
print('test X value:\n',X[test])
print('test y value:\n',y[test])
print()
3.2 输出
===KFold split===
train index:
[0 4 5 6]
train X value:
[[11 12 13 14]
[51 52 53 54]
[61 62 63 64]
[71 72 73 74]]
train y value:
[1 0 1 1]
test index:
[1 2 3 7]
test X value:
[[21 22 23 24]
[31 32 33 34]
[41 42 43 44]
[81 82 83 84]]
test y value:
[1 0 0 0]
train index:
[1 2 3 7]
train X value:
[[21 22 23 24]
[31 32 33 34]
[41 42 43 44]
[81 82 83 84]]
train y value:
[1 0 0 0]
test index:
[0 4 5 6]
test X value:
[[11 12 13 14]
[51 52 53 54]
[61 62 63 64]
[71 72 73 74]]
test y value:
[1 0 1 1]
===sKFold split===
train index:
[2 4 5 6]
train X value:
[[31 32 33 34]
[51 52 53 54]
[61 62 63 64]
[71 72 73 74]]
train y value:
[0 0 1 1]
test index:
[0 1 3 7]
test X value:
[[11 12 13 14]
[21 22 23 24]
[41 42 43 44]
[81 82 83 84]]
test y value:
[1 1 0 0]
train index:
[0 1 3 7]
train X value:
[[11 12 13 14]
[21 22 23 24]
[41 42 43 44]
[81 82 83 84]]
train y value:
[1 1 0 0]
test index:
[2 4 5 6]
test X value:
[[31 32 33 34]
[51 52 53 54]
[61 62 63 64]
[71 72 73 74]]
test y value:
[0 0 1 1]
4.分析结果
原始数据y的值为[1,1,0,0,0,1,1,0],标签为0和1样本一样多。
观察上述分割y的值:
- KFold分割得到的是[1 0 1 1]、[1 0 0 0] ,0和1的数量是不同的(与原始数据分布不同)。
- StratifiedKFold分割得到的是[0 0 1 1]、[1 1 0 0],0和1的数量是相同的(与原始数据分布相同)。
注意:
- 多次运行KFold,也可能随机得到与原始数据同分布的分割结果,只是随机得到的。
- 多次运行StratifiedKFold,每次都可以得到与原始数据同分布的结果,是稳定的。