得到XGBoost训练出的模型中的决策树的清晰PDF图像

得到XGBoost训练出的模型中的决策树的清晰PDF图像

直接调用

xgb.plot_tree(bst, num_trees=0)

生成的图像是png格式,过于模糊,想要得到矢量的PDF,以及适合插入到markdown(Typora)中的SVG格式,有一种简单的方式是修改XBGBoost源码,需要修改的文件是C:\Users\Username\AppData\Local\Programs\Python\Python38\Lib\site-packages\xgboost\plotting.py,将最后一个地方的函数添加下面两行

g.render('XGBoost_tree'+str(num_trees), format='pdf', cleanup=True)  # 2020-12-19
g.render('XGBoost_tree'+str(num_trees), format='svg', cleanup=True)  # 2020-12-19

改成下面这样

def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs):
    """Plot specified tree.

    Parameters
    ----------
    booster : Booster, XGBModel
        Booster or XGBModel instance
    fmap: str (optional)
       The name of feature map file
    num_trees : int, default 0
        Specify the ordinal number of target tree
    rankdir : str, default "TB"
        Passed to graphiz via graph_attr
    ax : matplotlib Axes, default None
        Target axes instance. If None, new figure and axes will be created.
    kwargs :
        Other keywords passed to to_graphviz

    Returns
    -------
    ax : matplotlib Axes

    """
    try:
        from matplotlib import pyplot as plt
        from matplotlib import image
    except ImportError as e:
        raise ImportError('You must install matplotlib to plot tree') from e

    if ax is None:
        _, ax = plt.subplots(1, 1)

    g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
                    **kwargs)
    g.render('XGBoost_tree'+str(num_trees),
         format='pdf', cleanup=True)  # 2020-12-19
    g.render('XGBoost_tree'+str(num_trees),
         format='svg', cleanup=True)  # 2020-12-19
    s = BytesIO()
    s.write(g.pipe(format='png'))
    s.seek(0)
    img = image.imread(s)

    ax.imshow(img)
    ax.axis('off')
    return ax

这样就会在调用

xgb.plot_tree(bst, num_trees=0)

的目录下生成相应的PDF文件和SVG文件。

以上です。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值