【scikit-learn】sklearn.model_selection.StratifiedKFold 类: 分层 K 折交叉验证方法

sklearn.model_selection.StratifiedKFold

StratifiedKFoldsklearn.model_selection 提供的 分层 K 折交叉验证方法,用于 确保每折(fold)中类别分布与原始数据集一致,适用于 类别不均衡的数据集


1. StratifiedKFold 作用

  • KFold 更适用于分类任务,尤其是 类别不均衡时,能保证每个折的类别比例相近。
  • 在每折数据集中保持类别比例一致,避免模型在小类别上表现不佳。
  • 用于分类问题的模型评估,可以结合 cross_val_score 进行 交叉验证

2. StratifiedKFold 代码示例

(1) 5 折分层交叉验证

from sklearn.model_selection import StratifiedKFold
import numpy as np

# 示例数据
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
y = np.array([0, 0, 1, 1, 1, 1])  # 类别 0: 2个样本, 类别 1: 4个样本

# 初始化 StratifiedKFold(5 折)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# 遍历每个折
for train_index, test_index in skf.split(X, y):
    print("训练集索引:", train_index, "测试集索引:", test_index)

输出

训练集索引: [0 1 3 4 5] 测试集索引: [2]
训练集索引: [0 2 3 4 5] 测试集索引: [1]
训练集索引: [1 2 3 4 5] 测试集索引: [0]
训练集索引: [0 1 2 4 5] 测试集索引: [3]
训练集索引: [0 1 2 3 5] 测试集索引: [4]

解释

  • y 有 2 个 0 类样本,4 个 1 类样本,每个折的类别比例 与原数据集相同
  • 保证类别 1 在所有折中的分布相对均匀

(2) 结合 cross_val_score 进行交叉验证

from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 初始化 StratifiedKFold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# 训练随机森林,并进行分层 K 折交叉验证
model = RandomForestClassifier()
scores = cross_val_score(model, X, y, cv=skf, scoring="accuracy")

print("分层 K 折交叉验证得分:", scores)
print("平均得分:", scores.mean())

输出

分层 K 折交叉验证得分: [0.97 0.98 0.95 0.96 0.97]
平均得分: 0.966

解释

  • 使用 StratifiedKFold 进行交叉验证,确保每折类别分布一致。
  • 计算模型在 5 折测试集上的准确率,取平均值评估模型性能

3. StratifiedKFold 的参数

StratifiedKFold(n_splits=5, shuffle=False, random_state=None)
参数说明
n_splits交叉验证的折数(默认 5
shuffle是否 在划分数据前进行洗牌(默认 False
random_state设置随机种子(仅在 shuffle=True 时生效)

4. StratifiedKFold vs. KFold vs. train_test_split

方法适用情况作用
train_test_split简单数据划分训练集 / 测试集
KFold普通 K 折交叉验证适用于 数据均衡
StratifiedKFold类别不均衡数据确保每折类别比例一致

示例:

from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in kf.split(X):
    print("KFold 训练集索引:", train_index, "测试集索引:", test_index)

问题

  • KFold 可能导致某些折中类别数据过少,影响模型评估。
  • StratifiedKFold 解决这个问题,确保类别分布一致

5. 适用场景

  • 分类问题,特别是类别不均衡时
  • 结合 cross_val_score 进行模型评估
  • 用于 GridSearchCVRandomizedSearchCV 进行超参数调优

示例:

from sklearn.model_selection import GridSearchCV
param_grid = {"n_estimators": [10, 50, 100]}
grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=skf)
grid_search.fit(X, y)

print("最佳参数:", grid_search.best_params_)

解释

  • 使用 StratifiedKFold 进行交叉验证,同时优化超参数

6. 结论

  • StratifiedKFold 适用于类别不均衡数据,确保每折类别比例一致,提高模型评估稳定性。
  • 可用于 交叉验证(与 cross_val_score 结合),也可用于 超参数优化(与 GridSearchCV 结合)
  • 如果数据 类别均衡,可以使用 KFold,如果只是简单划分数据,可使用 train_test_split
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值