鸢尾花数据集使用随机森林模型建立并进行超参数寻优及界面封装GUI

3 篇文章 0 订阅
1 篇文章 0 订阅

鸢尾花数据集使用随机森林模型建立并进行超参数寻优及界面封装GUI

以鸢尾花数据集为例,进行超参数寻优,寻优包括估计器个数,树深度、最小分割叶子节点数,最大特证数目等,可以在输入框中直接输入,点击按钮即可进行超参数寻优,并显示进度条。

import sys
import numpy as np
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QVBoxLayout, QWidget,
    QPushButton, QLineEdit, QLabel, QProgressBar
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

class WorkerThread(QThread):
    progress = pyqtSignal(int)
    result = pyqtSignal(object)

    def __init__(self, n_estimators_range):
        super().__init__()
        self.n_estimators_range = n_estimators_range

    def run(self):
        # 加载示例数据集
        data = load_iris()
        X, y = data.data, data.target

        # 定义参数网格
        param_grid = {'n_estimators': self.n_estimators_range}

        # 创建随机森林分类器
        rf = RandomForestClassifier()

        # 创建GridSearchCV对象
        grid_search = GridSearchCV(rf, param_grid, cv=5, n_jobs=-1)

        # 进行超参数寻优
        n_iterations = len(self.n_estimators_range) * 5  # 总的迭代次数
        for i, params in enumerate(grid_search.param_grid['n_estimators']):
            # 拷贝grid_search对象,并仅包含一个参数进行训练
            sub_grid_search = GridSearchCV(rf, {'n_estimators': [params]}, cv=5, n_jobs=-1)
            sub_grid_search.fit(X, y)
            current_progress = ((i + 1) * 5 / n_iterations) * 100
            self.progress.emit(int(current_progress))

        # 重新拟合整个参数网格
        grid_search.fit(X, y)

        # 获取最优参数
        best_params = grid_search.best_params_
        self.result.emit(best_params)
        self.progress.emit(100)  # 完成时进度设为100%

class MainWindow(QMainWindow):
    def __init__(self):
        super().__init__()

        # 创建主窗口部件
        self.setWindowTitle("Random Forest Hyperparameter Tuning")
        self.setGeometry(100, 100, 400, 200)

        # 创建布局
        layout = QVBoxLayout()

        # 创建输入框和标签
        self.label = QLabel("Enter range of n_estimators (comma separated):")
        self.input = QLineEdit("10,50,100,150,200")
        layout.addWidget(self.label)
        layout.addWidget(self.input)

        # 创建进度条
        self.progress_bar = QProgressBar()
        layout.addWidget(self.progress_bar)

        # 创建按钮
        self.start_button = QPushButton("Start")
        self.start_button.clicked.connect(self.start_tuning)
        layout.addWidget(self.start_button)

        # 创建结果显示标签
        self.result_label = QLabel("Best Parameters: ")
        layout.addWidget(self.result_label)

        # 设置中央部件
        container = QWidget()
        container.setLayout(layout)
        self.setCentralWidget(container)

    def start_tuning(self):
        # 获取n_estimators范围
        n_estimators_str = self.input.text()
        n_estimators_range = [int(x) for x in n_estimators_str.split(',')]

        # 创建并启动线程
        self.worker = WorkerThread(n_estimators_range)
        self.worker.progress.connect(self.update_progress)
        self.worker.result.connect(self.show_result)
        self.worker.start()

    def update_progress(self, value):
        self.progress_bar.setValue(value)

    def show_result(self, best_params):
        self.result_label.setText(f"Best Parameters: {best_params}")
        self.update_progress(100)

if __name__ == "__main__":
    app = QApplication(sys.argv)
    main_win = MainWindow()
    main_win.show()
    sys.exit(app.exec_())
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小孟的CDN

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值