树模型分裂节点可视化

环境

  • 不管用哪种可视化的模型,都需要配置graphviz图形驱动用来对结果进行展示,其他的安装包如下所示:
import pydotplus
from sklearn import tree
from IPython.display import Image

二叉树可视化代码

  • 具体的操作代码如下所示:

import os

from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(train_x, train_y)

import pydotplus
from IPython.display import Image
dot_data = tree.export_graphviz(clf, out_file=None, 
                         feature_names=df_all.columns.tolist()[:-1],  
                         class_names=['0','1','2'], 
                         filled=True, rounded=True, special_characters=False)  
graph = pydotplus.graph_from_dot_data(dot_data)
graph.set('label', 'Tree 30 from  Xgboost Tree')
graph.set('labelloc', 't')
# 遍历节点
# 获取图中所有节点
nodes = graph.get_node_list()
for index, node in enumerate(nodes):
    # 判断节点名称是否为'node',即右上方背景框
    if node.get_name().startswith('node'):
        # 设置节点样式和属性
        
        node.set('style', 'filled')
        # node.set('color', 'none')
Image(graph.create_png())
graph.write_png('decision_tree.png')

xgboost可视化代码

def model_plot_trees(clf, features, **n):
    """绘制树模型"""
    # 保存特征名称到fmap文件,用于图形绘制
    with open('xgb.fmap', 'w', encoding="utf-8") as fmap:
        for k, ft in enumerate(features):
            fmap.write(''.join(str([k, ft, 'p']) + '\n'))

    # 构建条件节点和叶子节点的形状、填充颜色
    c_node_params = {'shape': 'box',
                     'style': 'filled,rounded',
                     'fillcolor': '#78bceb'
                     }
    l_node_params = {'shape': 'box',
                     'style': 'filled',
                     'fillcolor': '#e48038'
                     }

    # 树模型绘制:有向图
    # 绘制和保存第num_trees+1棵树,num_trees为树的序号
    digraph = xgb.to_graphviz(clf, num_trees=30, condition_node_params=c_node_params,
                              leaf_node_params=l_node_params, fmap='xgb.fmap')
    # digraph.format = 'png'
    digraph.view('./oil_xgb_trees')

    # 分别绘制子图,不保存
    for i in range(n.get('n')):
        xgb.plot_tree(clf, num_trees=i, condition_node_params=c_node_params,
                      leaf_node_params=l_node_params, fmap='xgb.fmap')
        plt.show()
model_plot_trees(xgb_model, df_all.columns.tolist()[:-1])
  • 10
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一只红花猪

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值