sklearn决策树可视化

过去,关于sklearn决策树可视化的教程大部分都是基于Graphviz(一个图形可视化软件)的。

Graphviz的安装比较麻烦,并不是通过pip install就能搞定的,因为要安装底层的依赖库。

现在,自版本0.21以后,scikit-learn也自带可视化工具了,它就是sklearn.tree.plot_tree()

假设决策树模型(clf)已经训练好了,画图的代码如下:

def tree1(clf):
    fig = plt.figure()
    tree.plot_tree(clf)
    fig.savefig(os.path.join(fig_dir, "tree1.png"))

没有设置图像的相关参数,画出的树结构看不清树节点的信息。
在这里插入图片描述

设置字体大小,把文字调大一点:

def tree2(clf):
    fig = plt.figure()
    tree.plot_tree(clf, fontsize=8)
    fig.savefig(os.path.join(fig_dir, "tree2.png"))

文字是放大了,树节点也随着增大了,但是画面很拥挤。
在这里插入图片描述
那把画布调大一点:

def tree3(clf):
    fig = plt.figure(figsize=(35, 10))
    tree.plot_tree(clf, fontsize=8)
    fig.savefig(os.path.join(fig_dir, "tree3.png"))

大功告成!
在这里插入图片描述

下面的代码包含数据读取、模型训练和画图,有注释,就不展开了。

关注【小猫AI】公众号,回复tree可以获取训练模型的数据哦。

# -*- coding: utf-8 -*-
"""
Description : sklearn决策树可视化(scikit-learn==0.24.2)。
Authors     : wapping
CreateDate  : 2022/2/7
"""
import os
import pandas as pd
from sklearn import tree
from matplotlib import pyplot as plt


def read_data(fp):
    """加载训练数据。"""
    data = pd.read_csv(fp, header=None)
    x = data[[0, 1]]    # 第0,1列为特征
    y = data[[2]]       # 第2列为标签
    return x, y


def tree1(clf):
    # 没有设置图像的相关参数,画出的树结构看不清树节点的信息
    fig = plt.figure()
    tree.plot_tree(clf)
    fig.savefig(os.path.join(fig_dir, "tree1.png"))


def tree2(clf):
    # 设置字体大小,树节点放大了,但是很拥挤
    fig = plt.figure()
    tree.plot_tree(clf, fontsize=8)
    fig.savefig(os.path.join(fig_dir, "tree2.png"))


def tree3(clf):
    # 同时设置字体大小和图像的大小,树结构正常显示
    fig = plt.figure(figsize=(35, 10))
    tree.plot_tree(clf, fontsize=8)
    fig.savefig(os.path.join(fig_dir, "tree3.png"))


if __name__ == '__main__':
    fig_dir = "data/plot_tree"      # 保存图片的目录
    data_path = "data/plot_tree_data.csv"   # 训练树模型的数据
    os.makedirs(fig_dir, exist_ok=True)

    # 读取训练数据
    x, y = read_data(data_path)

    # 训练决策树分类器
    clf = tree.DecisionTreeClassifier(min_samples_leaf=100, random_state=666)
    clf = clf.fit(x, y)

    # 画树结构并保存图片
    tree1(clf)
    tree2(clf)
    tree3(clf)
  • 12
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
sklearn中贝叶斯分类模型的可视化可以通过使用matplotlib库来实现。 以高斯朴素贝叶斯分类器(GaussianNB)为例,可以使用以下代码来可视化分类结果: ```python import numpy as np import matplotlib.pyplot as plt from sklearn.naive_bayes import GaussianNB from sklearn.datasets import make_classification from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score # 生成模拟数据集 X, y = make_classification(n_samples=1000, n_features=2, n_informative=2, n_redundant=0, random_state=42) # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) # 创建高斯朴素贝叶斯分类器 gnb = GaussianNB() # 拟合训练集 gnb.fit(X_train, y_train) # 预测测试集 y_pred = gnb.predict(X_test) # 计算准确率 accuracy = accuracy_score(y_test, y_pred) print("Accuracy:", accuracy) # 可视化分类结果 x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1)) Z = gnb.predict(np.c_[xx.ravel(), yy.ravel()]) Z = Z.reshape(xx.shape) plt.contourf(xx, yy, Z, alpha=0.4) plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8) plt.title("GaussianNB Classification") plt.xlabel("Feature 1") plt.ylabel("Feature 2") plt.show() ``` 在上述代码中,首先使用make_classification函数生成一个二维的模拟数据集,然后将其划分为训练集和测试集。接着创建高斯朴素贝叶斯分类器,并拟合训练集。使用预测函数predict对测试集进行预测,并计算准确率。最后,使用meshgrid和contourf函数可视化分类结果,使用scatter函数绘制数据点。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值