sklearn.model_selection.StratifiedKFold
StratifiedKFold
是 sklearn.model_selection
提供的 分层 K 折交叉验证方法,用于 确保每折(fold)中类别分布与原始数据集一致,适用于 类别不均衡的数据集。
1. StratifiedKFold
作用
- 比
KFold
更适用于分类任务,尤其是 类别不均衡时,能保证每个折的类别比例相近。 - 在每折数据集中保持类别比例一致,避免模型在小类别上表现不佳。
- 用于分类问题的模型评估,可以结合
cross_val_score
进行 交叉验证。
2. StratifiedKFold
代码示例
(1) 5 折分层交叉验证
from sklearn.model_selection import StratifiedKFold
import numpy as np
# 示例数据
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
y = np.array([0, 0, 1, 1, 1, 1]) # 类别 0: 2个样本, 类别 1: 4个样本
# 初始化 StratifiedKFold(5 折)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# 遍历每个折
for train_index, test_index in skf.split(X, y):
print("训练集索引:", train_index, "测试集索引:", test_index)
输出
训练集索引: [0 1 3 4 5] 测试集索引: [2]
训练集索引: [0 2 3 4 5] 测试集索引: [1]
训练集索引: [1 2 3 4 5] 测试集索引: [0]
训练集索引: [0 1 2 4 5] 测试集索引: [3]
训练集索引: [0 1 2 3 5] 测试集索引: [4]
解释
y
有 2 个0
类样本,4 个1
类样本,每个折的类别比例 与原数据集相同。- 保证类别
1
在所有折中的分布相对均匀。
(2) 结合 cross_val_score
进行交叉验证
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
# 初始化 StratifiedKFold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# 训练随机森林,并进行分层 K 折交叉验证
model = RandomForestClassifier()
scores = cross_val_score(model, X, y, cv=skf, scoring="accuracy")
print("分层 K 折交叉验证得分:", scores)
print("平均得分:", scores.mean())
输出
分层 K 折交叉验证得分: [0.97 0.98 0.95 0.96 0.97]
平均得分: 0.966
解释
- 使用
StratifiedKFold
进行交叉验证,确保每折类别分布一致。 - 计算模型在 5 折测试集上的准确率,取平均值评估模型性能。
3. StratifiedKFold
的参数
StratifiedKFold(n_splits=5, shuffle=False, random_state=None)
参数 | 说明 |
---|---|
n_splits | 交叉验证的折数(默认 5 ) |
shuffle | 是否 在划分数据前进行洗牌(默认 False ) |
random_state | 设置随机种子(仅在 shuffle=True 时生效) |
4. StratifiedKFold
vs. KFold
vs. train_test_split
方法 | 适用情况 | 作用 |
---|---|---|
train_test_split | 简单数据划分 | 训练集 / 测试集 |
KFold | 普通 K 折交叉验证 | 适用于 数据均衡 |
StratifiedKFold | 类别不均衡数据 | 确保每折类别比例一致 |
示例:
from sklearn.model_selection import KFold
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in kf.split(X):
print("KFold 训练集索引:", train_index, "测试集索引:", test_index)
问题
KFold
可能导致某些折中类别数据过少,影响模型评估。StratifiedKFold
解决这个问题,确保类别分布一致。
5. 适用场景
- 分类问题,特别是类别不均衡时。
- 结合
cross_val_score
进行模型评估。 - 用于
GridSearchCV
和RandomizedSearchCV
进行超参数调优。
示例:
from sklearn.model_selection import GridSearchCV
param_grid = {"n_estimators": [10, 50, 100]}
grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=skf)
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)
解释
- 使用
StratifiedKFold
进行交叉验证,同时优化超参数。
6. 结论
StratifiedKFold
适用于类别不均衡数据,确保每折类别比例一致,提高模型评估稳定性。- 可用于 交叉验证(与
cross_val_score
结合),也可用于 超参数优化(与GridSearchCV
结合)。 - 如果数据 类别均衡,可以使用
KFold
,如果只是简单划分数据,可使用train_test_split
。