集成算法学习案例

"""
集成学习算法演示 - 随机森林分类器

功能:
1. 完整实现从数据加载到模型评估的机器学习流程
2. 包含详细的参数说明和代码注释
3. 支持自定义数据集输入
4. 扩展模型评估指标和特征重要性可视化

依赖:
- scikit-learn >= 1.0.0
- numpy >= 1.17.0
"""

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

class EnsembleLearningModel:
    """
    集成学习模型类(以随机森林为例)
    
    Attributes:
        model (RandomForestClassifier): 随机森林分类器实例
        X_train (ndarray): 训练特征数据
        X_test (ndarray): 测试特征数据
        y_train (ndarray): 训练标签数据
        y_test (ndarray): 测试标签数据
        feature_names (list): 特征名称列表
    """
    
    def __init__(self, n_estimators=100, criterion='gini', max_depth=None, random_state=42):
        """
        初始化模型参数
        
        Args:
            n_estimators (int): 基学习器(决策树)数量,默认100
            criterion (str): 分裂准则,可选'gini'或'entropy',默认'gini'
            max_depth (int): 树的最大深度,None表示不限制,默认None
            random_state (int): 随机种子,保证可复现性
        """
        self.model = RandomForestClassifier(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            random_state=random_state
        )
        self.feature_names = None
        self.classes = None

    def load_dataset(self, dataset='iris', custom_data=None):
        """
        加载数据集(支持内置数据集和自定义数据集)
        
        Args:
            dataset (str): 内置数据集名称,可选'iris'(默认)
            custom_data (tuple): 自定义数据集元组 (X, y, feature_names, target_names)
        
        Returns:
            tuple: (特征数据, 标签数据, 特征名称, 目标名称)
        """
        if dataset == 'iris':
            data = datasets.load_iris()
            X = data.data
            y = data.target
            self.feature_names = data.feature_names
            self.classes = data.target_names
        elif custom_data:
            X, y, self.feature_names, self.classes = custom_data
        else:
            raise ValueError("未指定有效数据集或自定义数据")
        return X, y

    def prepare_data(self, X, y, test_size=0.3, random_state=42, scale=True):
        """
        数据预处理和划分
        
        Args:
            X (ndarray): 特征数据
            y (ndarray): 标签数据
            test_size (float): 测试集比例,默认0.3
            random_state (int): 随机种子
            scale (bool): 是否对特征数据标准化,默认True
        
        Returns:
            tuple: (训练特征, 测试特征, 训练标签, 测试标签)
        """
        # 划分数据集
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )
        
        # 特征标准化(提升决策树类算法性能)
        if scale:
            scaler = StandardScaler()
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
        
        self.X_train, self.X_test = X_train, X_test
        self.y_train, self.y_test = y_train, y_test
        return X_train, X_test, y_train, y_test

    def train_model(self):
        """训练模型并记录训练时间"""
        start_time = np.datetime64('now')
        self.model.fit(self.X_train, self.y_train)
        end_time = np.datetime64('now')
        print(f"模型训练完成,耗时:{(end_time - start_time).item().total_seconds():.2f}秒")

    def predict(self, X):
        """
        预测样本类别
        
        Args:
            X (ndarray): 待预测特征数据
        
        Returns:
            ndarray: 预测类别标签
        """
        return self.model.predict(X)

    def evaluate(self, y_true, y_pred):
        """
        模型评估
        
        Args:
            y_true (ndarray): 真实标签
            y_pred (ndarray): 预测标签
        
        Returns:
            dict: 评估结果字典
        """
        accuracy = accuracy_score(y_true, y_pred)
        report = classification_report(y_true, y_pred, target_names=self.classes)
        confusion = confusion_matrix(y_true, y_pred)
        
        print("模型评估报告:")
        print(f"准确率:{accuracy:.4f}")
        print("分类报告:\n", report)
        print("混淆矩阵:\n", confusion)
        
        return {
            'accuracy': accuracy,
            'classification_report': report,
            'confusion_matrix': confusion
        }

    def plot_feature_importances(self, title="特征重要性分析", figsize=(8, 5)):
        """
        可视化特征重要性
        
        Args:
            title (str): 图表标题
            figsize (tuple): 图表尺寸
        """
        importances = self.model.feature_importances_
        indices = np.argsort(importances)[::-1]  # 降序排列
        
        plt.figure(figsize=figsize)
        plt.title(title)
        plt.bar(range(len(indices)), importances[indices], align='center')
        plt.xticks(range(len(indices)), [self.feature_names[i] for i in indices], rotation=45)
        plt.xlabel("特征")
        plt.ylabel("重要性得分")
        plt.tight_layout()
        plt.show()

def main():
    """主函数:完整流程控制"""
    # 初始化模型并设置参数
    model = EnsembleLearningModel(
        n_estimators=200,       # 增加基学习器数量提升鲁棒性
        criterion='entropy',    # 使用信息增益准则
        max_depth=10,          # 限制树深度防止过拟合
        random_state=42
    )
    
    # 加载内置数据集
    X, y = model.load_dataset(dataset='iris')
    
    # 数据预处理和划分
    X_train, X_test, y_train, y_test = model.prepare_data(
        X, y, test_size=0.2, scale=True  # 调整测试集比例并启用特征标准化
    )
    
    # 训练模型
    model.train_model()
    
    # 预测
    y_pred = model.predict(X_test)
    
    # 评估模型
    model.evaluate(y_true=y_test, y_pred=y_pred)
    
    # 可视化特征重要性
    model.plot_feature_importances()

if __name__ == "__main__":
    main()

    

  1. 面向对象设计

    • 创建EnsembleLearningModel类封装完整流程
    • 包含数据集加载、数据预处理、模型训练、预测、评估、可视化等完整功能模块
    • 支持内置数据集和自定义数据集输入
  2. 增强功能

    • 新增特征标准化选项(通过scale参数控制)
    • 包含分类报告和混淆矩阵评估
    • 实现特征重要性可视化功能(需安装 matplotlib)
    • 支持分层抽样划分数据集(通过stratify=y保证类别分布均衡)
  3. 代码规范

    • 完整的文档字符串(docstring)说明类和方法功能
    • 详细的参数说明和类型标注
    • 符合 PEP8 规范的代码格式
    • 增加异常处理准备(可扩展 try-except 结构)
  4. 可配置性

    • 所有关键参数(训练比例、模型超参数、可视化配置)均可通过类初始化或方法参数调整
    • 支持灵活的数据集输入方式
  5. 评估体系

    • 除准确率外,新增分类报告(精确率 / 召回率 / F1 值)和混淆矩阵
    • 提供可视化分析工具(特征重要性柱状图)

使用说明:

  1. 安装依赖:

bash

pip install scikit-learn numpy matplotlib

  1. 运行方式:

bash

python ensemble_learning.py

  1. 自定义扩展:

    • 加载自定义数据集时,传入格式为 (X, y, feature_names, target_names) 的元组
    • 通过修改模型初始化参数调整随机森林配置
    • plot_feature_importances方法中可自定义图表样式
  2. 输出内容:

    • 控制台输出详细评估指标
    • 弹出窗口显示特征重要性可视化图表
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值