"""
集成学习算法演示 - 随机森林分类器
功能:
1. 完整实现从数据加载到模型评估的机器学习流程
2. 包含详细的参数说明和代码注释
3. 支持自定义数据集输入
4. 扩展模型评估指标和特征重要性可视化
依赖:
- scikit-learn >= 1.0.0
- numpy >= 1.17.0
"""
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
class EnsembleLearningModel:
"""
集成学习模型类(以随机森林为例)
Attributes:
model (RandomForestClassifier): 随机森林分类器实例
X_train (ndarray): 训练特征数据
X_test (ndarray): 测试特征数据
y_train (ndarray): 训练标签数据
y_test (ndarray): 测试标签数据
feature_names (list): 特征名称列表
"""
def __init__(self, n_estimators=100, criterion='gini', max_depth=None, random_state=42):
"""
初始化模型参数
Args:
n_estimators (int): 基学习器(决策树)数量,默认100
criterion (str): 分裂准则,可选'gini'或'entropy',默认'gini'
max_depth (int): 树的最大深度,None表示不限制,默认None
random_state (int): 随机种子,保证可复现性
"""
self.model = RandomForestClassifier(
n_estimators=n_estimators,
criterion=criterion,
max_depth=max_depth,
random_state=random_state
)
self.feature_names = None
self.classes = None
def load_dataset(self, dataset='iris', custom_data=None):
"""
加载数据集(支持内置数据集和自定义数据集)
Args:
dataset (str): 内置数据集名称,可选'iris'(默认)
custom_data (tuple): 自定义数据集元组 (X, y, feature_names, target_names)
Returns:
tuple: (特征数据, 标签数据, 特征名称, 目标名称)
"""
if dataset == 'iris':
data = datasets.load_iris()
X = data.data
y = data.target
self.feature_names = data.feature_names
self.classes = data.target_names
elif custom_data:
X, y, self.feature_names, self.classes = custom_data
else:
raise ValueError("未指定有效数据集或自定义数据")
return X, y
def prepare_data(self, X, y, test_size=0.3, random_state=42, scale=True):
"""
数据预处理和划分
Args:
X (ndarray): 特征数据
y (ndarray): 标签数据
test_size (float): 测试集比例,默认0.3
random_state (int): 随机种子
scale (bool): 是否对特征数据标准化,默认True
Returns:
tuple: (训练特征, 测试特征, 训练标签, 测试标签)
"""
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=test_size, random_state=random_state, stratify=y
)
# 特征标准化(提升决策树类算法性能)
if scale:
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
self.X_train, self.X_test = X_train, X_test
self.y_train, self.y_test = y_train, y_test
return X_train, X_test, y_train, y_test
def train_model(self):
"""训练模型并记录训练时间"""
start_time = np.datetime64('now')
self.model.fit(self.X_train, self.y_train)
end_time = np.datetime64('now')
print(f"模型训练完成,耗时:{(end_time - start_time).item().total_seconds():.2f}秒")
def predict(self, X):
"""
预测样本类别
Args:
X (ndarray): 待预测特征数据
Returns:
ndarray: 预测类别标签
"""
return self.model.predict(X)
def evaluate(self, y_true, y_pred):
"""
模型评估
Args:
y_true (ndarray): 真实标签
y_pred (ndarray): 预测标签
Returns:
dict: 评估结果字典
"""
accuracy = accuracy_score(y_true, y_pred)
report = classification_report(y_true, y_pred, target_names=self.classes)
confusion = confusion_matrix(y_true, y_pred)
print("模型评估报告:")
print(f"准确率:{accuracy:.4f}")
print("分类报告:\n", report)
print("混淆矩阵:\n", confusion)
return {
'accuracy': accuracy,
'classification_report': report,
'confusion_matrix': confusion
}
def plot_feature_importances(self, title="特征重要性分析", figsize=(8, 5)):
"""
可视化特征重要性
Args:
title (str): 图表标题
figsize (tuple): 图表尺寸
"""
importances = self.model.feature_importances_
indices = np.argsort(importances)[::-1] # 降序排列
plt.figure(figsize=figsize)
plt.title(title)
plt.bar(range(len(indices)), importances[indices], align='center')
plt.xticks(range(len(indices)), [self.feature_names[i] for i in indices], rotation=45)
plt.xlabel("特征")
plt.ylabel("重要性得分")
plt.tight_layout()
plt.show()
def main():
"""主函数:完整流程控制"""
# 初始化模型并设置参数
model = EnsembleLearningModel(
n_estimators=200, # 增加基学习器数量提升鲁棒性
criterion='entropy', # 使用信息增益准则
max_depth=10, # 限制树深度防止过拟合
random_state=42
)
# 加载内置数据集
X, y = model.load_dataset(dataset='iris')
# 数据预处理和划分
X_train, X_test, y_train, y_test = model.prepare_data(
X, y, test_size=0.2, scale=True # 调整测试集比例并启用特征标准化
)
# 训练模型
model.train_model()
# 预测
y_pred = model.predict(X_test)
# 评估模型
model.evaluate(y_true=y_test, y_pred=y_pred)
# 可视化特征重要性
model.plot_feature_importances()
if __name__ == "__main__":
main()
-
面向对象设计:
- 创建
EnsembleLearningModel
类封装完整流程 - 包含数据集加载、数据预处理、模型训练、预测、评估、可视化等完整功能模块
- 支持内置数据集和自定义数据集输入
- 创建
-
增强功能:
- 新增特征标准化选项(通过
scale
参数控制) - 包含分类报告和混淆矩阵评估
- 实现特征重要性可视化功能(需安装 matplotlib)
- 支持分层抽样划分数据集(通过
stratify=y
保证类别分布均衡)
- 新增特征标准化选项(通过
-
代码规范:
- 完整的文档字符串(docstring)说明类和方法功能
- 详细的参数说明和类型标注
- 符合 PEP8 规范的代码格式
- 增加异常处理准备(可扩展 try-except 结构)
-
可配置性:
- 所有关键参数(训练比例、模型超参数、可视化配置)均可通过类初始化或方法参数调整
- 支持灵活的数据集输入方式
-
评估体系:
- 除准确率外,新增分类报告(精确率 / 召回率 / F1 值)和混淆矩阵
- 提供可视化分析工具(特征重要性柱状图)
使用说明:
- 安装依赖:
bash
pip install scikit-learn numpy matplotlib
- 运行方式:
bash
python ensemble_learning.py
-
自定义扩展:
- 加载自定义数据集时,传入格式为 (X, y, feature_names, target_names) 的元组
- 通过修改模型初始化参数调整随机森林配置
- 在
plot_feature_importances
方法中可自定义图表样式
-
输出内容:
- 控制台输出详细评估指标
- 弹出窗口显示特征重要性可视化图表