关于sklearn决策树手动指定节点进行剪枝调整的实现

一、决策树剪枝


决策树的剪枝方式有两种:预剪枝和后剪枝。对于后剪枝,Python 的 sklearn 提供了 CCP 代价复杂度剪枝法(Cost Complexity Pruning),具体使用代码如下:
 

 # -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
 
# -------- Data preparation ----------------------------
iris = load_iris()  # bundled iris data set
X, y = iris.data, iris.target

# -------- Grow the reference tree ---------------------
# ccp_alpha=0 (the default): no cost-complexity pruning, so the whole
# pruning path can be computed from this tree.
clf = tree.DecisionTreeClassifier(
    min_samples_split=10, random_state=0, ccp_alpha=0
).fit(X, y)

# -------- Cost-complexity pruning path ----------------
pruning_path = clf.cost_complexity_pruning_path(X, y)

# -------- Report the path -----------------------------
print("\n====CCP路径=================")
for caption, key in (("ccp_alphas:", "ccp_alphas"), ("impurities:", "impurities")):
    print(caption, pruning_path[key])

# -------- Refit with ccp_alpha=0.1 (post-pruning) -----
clf = tree.DecisionTreeClassifier(
    min_samples_split=10, random_state=0, ccp_alpha=0.1
).fit(X, y)

# -------- Verify: recompute the total leaf impurity ---
# Each leaf contributes impurity * samples_in_leaf / total_samples; the sum
# can be compared with the `impurities` values printed above.
fitted_tree = clf.tree_
is_leaf = fitted_tree.children_left == -1
leaf_terms = fitted_tree.impurity[is_leaf] * fitted_tree.n_node_samples[is_leaf] / len(y)
tree_impurities = leaf_terms.sum()

print("\n==设置alpha=0.1剪枝后的树纯度:=========\n", tree_impurities)

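该方法在树构建完成后对树进行剪枝简化,使以下损失函数最小化:

$$R_\alpha(T) = R(T) + \alpha\,|\widetilde{T}|$$

其中 $R(T)$ 为所有叶节点按样本权重加权的不纯度之和(即上面代码中 `tree_impurities` 的计算方式),$|\widetilde{T}|$ 为叶节点个数,$\alpha$ 即参数 `ccp_alpha`。但该方法具体会裁剪掉哪些分枝无法人为指定,因此需要一种能够对指定分枝进行裁剪的方法。目前 sklearn 并未提供该功能,下面从源码层面进行修改。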

二、决策树源码分析

在 GitHub 上下载 scikit-learn 1.1.X 的源码,决策树的实现位于 scikit-learn-1.1.X/sklearn/tree/_tree.pyx,采用 Cython 实现,相关代码分析如下:

# Post-pruning with cost complexity pruning (CCP): decide which nodes to
# collapse, then copy the surviving nodes of orig_tree into tree.
def _build_pruned_tree_ccp(
    Tree tree, # OUT
    Tree orig_tree,
    DOUBLE_t ccp_alpha):
    """Build a pruned tree from the original tree using cost complexity
    pruning.

    The values and nodes from the original tree are copied into the pruned
    tree.

    Two steps are performed:

    1. ``_cost_complexity_prune`` with an ``_AlphaPruner`` fills
       ``leaves_in_subtree`` (a 0/1 mask over ``orig_tree``'s nodes marking
       the leaves of the pruned tree) and makes the pruner count the nodes
       that remain.
    2. ``_build_pruned_tree`` copies the nodes reachable from the root,
       stopping at the masked leaves, into ``tree``.

    Parameters
    ----------
    tree : Tree
        Location to place the pruned tree
    orig_tree : Tree
        Original tree
    ccp_alpha : positive double
        Complexity parameter. The subtree with the largest cost complexity
        that is smaller than ``ccp_alpha`` will be chosen. By default,
        no pruning is performed.
    """

    cdef:
        SIZE_t n_nodes = orig_tree.node_count
        unsigned char[:] leaves_in_subtree = np.zeros(
            shape=n_nodes, dtype=np.uint8)

    pruning_controller = _AlphaPruner(ccp_alpha=ccp_alpha)

    # Run CCP to decide which nodes become leaves of the pruned tree; this
    # writes leaves_in_subtree and sets pruning_controller.capacity.
    _cost_complexity_prune(leaves_in_subtree, orig_tree, pruning_controller)

    # Copy the pruned tree; capacity (number of kept nodes, counted by
    # _AlphaPruner.after_pruning) is the initial allocation size.
    _build_pruned_tree(tree, orig_tree, leaves_in_subtree,
                       pruning_controller.capacity)

cdef class _CCPPruneController:
    """Base class used by build_pruned_tree_ccp and ccp_pruning_path
    to control pruning.

    ``_cost_complexity_prune`` calls these hooks:

    * ``save_metrics`` once before any pruning and once after every pruned
      branch,
    * ``stop_pruning`` with the weakest link's effective alpha before that
      branch is pruned,
    * ``after_pruning`` once, with the final ``in_subtree`` mask.

    The defaults here never stop early and record nothing; subclasses
    override the hooks they need.
    """
    cdef bint stop_pruning(self, DOUBLE_t effective_alpha) nogil:
        """Return 1 to stop pruning and 0 to continue pruning"""
        # Default: prune all the way down to the root.
        return 0

    cdef void save_metrics(self, DOUBLE_t effective_alpha,
                           DOUBLE_t subtree_impurities) nogil:
        """Save metrics when pruning"""
        # Default: nothing is recorded.
        pass

    cdef void after_pruning(self, unsigned char[:] in_subtree) nogil:
        """Called after pruning"""
        # Default: nothing to do with the final subtree mask.
        pass


cdef class _AlphaPruner(_CCPPruneController):
    """Use alpha to control when to stop pruning.

    Pruning continues while the weakest link's effective alpha is
    ``<= ccp_alpha``. Afterwards ``capacity`` holds the number of nodes that
    remain in the pruned tree.
    """
    # Complexity parameter supplied by the caller.
    cdef DOUBLE_t ccp_alpha
    # Number of nodes kept in the pruned tree; filled by after_pruning() and
    # used by _build_pruned_tree_ccp as the initial allocation size.
    cdef SIZE_t capacity

    def __cinit__(self, DOUBLE_t ccp_alpha):
        self.ccp_alpha = ccp_alpha
        self.capacity = 0

    cdef bint stop_pruning(self, DOUBLE_t effective_alpha) nogil:
        # The subtree on the previous iteration has the greatest ccp_alpha
        # less than or equal to self.ccp_alpha
        return self.ccp_alpha < effective_alpha

    cdef void after_pruning(self, unsigned char[:] in_subtree) nogil:
        """Updates the number of leaves in subtree"""
        # NOTE: despite the docstring above, this counts every node (internal
        # and leaf) whose in_subtree flag is set, i.e. the pruned tree's size.
        # capacity is not reset here; it starts from 0 in __cinit__.
        for i in range(in_subtree.shape[0]):
            if in_subtree[i]:
                self.capacity += 1

# Stack entry for the depth-first walk in _cost_complexity_prune that
# records each node's parent.
cdef struct CostComplexityPruningRecord:
    SIZE_t node_idx  # node id in the original tree
    SIZE_t parent    # id of its parent (_TREE_UNDEFINED for the root)

cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT
                            Tree orig_tree,
                            _CCPPruneController controller):
    """Perform cost complexity pruning.

    This function takes an already grown tree, `orig_tree` and outputs a
    boolean mask `leaves_in_subtree` which are the leaves in the pruned tree.
    During the pruning process, the controller is passed the effective alpha and
    the subtree impurities. Furthermore, the controller signals when to stop
    pruning.

    Outline (weakest-link pruning):

    1. ``r_node[i]``: impurity of node i weighted by its share of the root's
       weighted sample count, i.e. the node's cost if it were a leaf.
    2. A depth-first walk records every node's parent and marks the original
       leaves in ``leaves_in_subtree``.
    3. For each internal node, ``r_branch`` accumulates ``r_node`` over the
       leaves below it and ``n_leaves`` counts those leaves.
    4. While the root is still internal: pick the internal node with the
       smallest ``(r_node - r_branch) / (n_leaves - 1)`` (ties go to the
       lowest node id because of the strict ``<``), ask the controller whether
       to stop, otherwise collapse that node into a leaf and update its
       ancestors.

    Parameters
    ----------
    leaves_in_subtree : unsigned char[:]
        Output for leaves of subtree. Only some entries are written, so the
        caller passes a zero-filled array of length ``orig_tree.node_count``
        (``_build_pruned_tree_ccp`` does).
    orig_tree : Tree
        Original tree
    controller : _CCPPruneController
        Cost complexity controller
    """

    cdef:
        SIZE_t i
        SIZE_t n_nodes = orig_tree.node_count
        # prior probability using weighted samples
        DOUBLE_t[:] weighted_n_node_samples = orig_tree.weighted_n_node_samples
        DOUBLE_t total_sum_weights = weighted_n_node_samples[0]
        DOUBLE_t[:] impurity = orig_tree.impurity
        # weighted impurity of each node
        DOUBLE_t[:] r_node = np.empty(shape=n_nodes, dtype=np.float64)

        SIZE_t[:] child_l = orig_tree.children_left
        SIZE_t[:] child_r = orig_tree.children_right
        SIZE_t[:] parent = np.zeros(shape=n_nodes, dtype=np.intp)

        stack[CostComplexityPruningRecord] ccp_stack
        CostComplexityPruningRecord stack_record
        int rc = 0  # not used in this function
        SIZE_t node_idx
        stack[SIZE_t] node_indices_stack

        # number of leaves below each node (stays 0 for leaves)
        SIZE_t[:] n_leaves = np.zeros(shape=n_nodes, dtype=np.intp)
        # summed r_node of the leaves below each node
        DOUBLE_t[:] r_branch = np.zeros(shape=n_nodes, dtype=np.float64)
        DOUBLE_t current_r
        SIZE_t leaf_idx
        SIZE_t parent_idx

        # candidate nodes that can be pruned
        unsigned char[:] candidate_nodes = np.zeros(shape=n_nodes,
                                                    dtype=np.uint8)
        # nodes in subtree
        unsigned char[:] in_subtree = np.ones(shape=n_nodes, dtype=np.uint8)
        DOUBLE_t[:] g_node = np.zeros(shape=n_nodes, dtype=np.float64)
        SIZE_t pruned_branch_node_idx
        DOUBLE_t subtree_alpha
        DOUBLE_t effective_alpha
        SIZE_t child_l_idx
        SIZE_t child_r_idx
        SIZE_t n_pruned_leaves
        DOUBLE_t r_diff
        DOUBLE_t max_float64 = np.finfo(np.float64).max

    # find parent node ids and leaves
    with nogil:

        # Cost of each node if it were a leaf: impurity scaled by the node's
        # fraction of the root's weighted sample count.
        for i in range(r_node.shape[0]):
            r_node[i] = (
                weighted_n_node_samples[i] * impurity[i] / total_sum_weights)

        # Push the root node
        ccp_stack.push({"node_idx": 0, "parent": _TREE_UNDEFINED})

        # Depth-first walk: record parents and mark the original leaves.
        while not ccp_stack.empty():
            stack_record = ccp_stack.top()
            ccp_stack.pop()

            node_idx = stack_record.node_idx
            parent[node_idx] = stack_record.parent

            if child_l[node_idx] == _TREE_LEAF:
                # ... and child_r[node_idx] == _TREE_LEAF:
                leaves_in_subtree[node_idx] = 1
            else:
                ccp_stack.push({"node_idx": child_l[node_idx], "parent": node_idx})
                ccp_stack.push({"node_idx": child_r[node_idx], "parent": node_idx})

        # computes number of leaves in all branches and the overall impurity of
        # the branch. The overall impurity is the sum of r_node in its leaves.
        for leaf_idx in range(leaves_in_subtree.shape[0]):
            if not leaves_in_subtree[leaf_idx]:
                continue
            r_branch[leaf_idx] = r_node[leaf_idx]

            # bubble up values to ancestor nodes
            # (leaf_idx is reused as the walking cursor below; as in Python,
            # reassigning it does not change the range iteration)
            current_r = r_node[leaf_idx]
            while leaf_idx != 0:
                parent_idx = parent[leaf_idx]
                r_branch[parent_idx] += current_r
                n_leaves[parent_idx] += 1
                leaf_idx = parent_idx

        # Every internal node starts as a pruning candidate.
        for i in range(leaves_in_subtree.shape[0]):
            candidate_nodes[i] = not leaves_in_subtree[i]

        # save metrics before pruning
        controller.save_metrics(0.0, r_branch[0])

        # while root node is not a leaf
        while candidate_nodes[0]:

            # computes ccp_alpha for subtrees and finds the minimal alpha
            # (candidates are internal, so n_leaves[i] >= 2 and the divisor is
            # never zero).
            # NOTE(review): if every subtree_alpha were NaN/inf,
            # pruned_branch_node_idx would not be set in this pass.
            effective_alpha = max_float64
            for i in range(n_nodes):
                if not candidate_nodes[i]:
                    continue
                subtree_alpha = (r_node[i] - r_branch[i]) / (n_leaves[i] - 1)
                if subtree_alpha < effective_alpha:
                    effective_alpha = subtree_alpha
                    pruned_branch_node_idx = i

            # The controller decides whether this weakest link is pruned.
            if controller.stop_pruning(effective_alpha):
                break

            node_indices_stack.push(pruned_branch_node_idx)

            # descendants of branch are not in subtree
            while not node_indices_stack.empty():
                node_idx = node_indices_stack.top()
                node_indices_stack.pop()

                if not in_subtree[node_idx]:
                    continue # branch has already been marked for pruning
                candidate_nodes[node_idx] = 0
                leaves_in_subtree[node_idx] = 0
                in_subtree[node_idx] = 0

                if child_l[node_idx] != _TREE_LEAF:
                    # ... and child_r[node_idx] != _TREE_LEAF:
                    node_indices_stack.push(child_l[node_idx])
                    node_indices_stack.push(child_r[node_idx])
            # The loop above also cleared the pruned node itself; restore it
            # as a leaf of the remaining subtree.
            leaves_in_subtree[pruned_branch_node_idx] = 1
            in_subtree[pruned_branch_node_idx] = 1

            # updates number of leaves
            # (a branch with k leaves becomes one leaf: ancestors lose k - 1)
            n_pruned_leaves = n_leaves[pruned_branch_node_idx] - 1
            n_leaves[pruned_branch_node_idx] = 0

            # computes the increase in r_branch to bubble up
            r_diff = r_node[pruned_branch_node_idx] - r_branch[pruned_branch_node_idx]
            r_branch[pruned_branch_node_idx] = r_node[pruned_branch_node_idx]

            # bubble up values to ancestors
            # (the root's parent is _TREE_UNDEFINED, which ends the walk)
            node_idx = parent[pruned_branch_node_idx]
            while node_idx != _TREE_UNDEFINED:
                n_leaves[node_idx] -= n_pruned_leaves
                r_branch[node_idx] += r_diff
                node_idx = parent[node_idx]

            # r_branch[0] is now the total leaf cost of the pruned tree.
            controller.save_metrics(effective_alpha, r_branch[0])

        controller.after_pruning(in_subtree)

# Stack entry for the depth-first copy performed by _build_pruned_tree.
cdef struct BuildPrunedRecord:
    SIZE_t start    # node id in the original tree
    SIZE_t depth    # depth of that node
    SIZE_t parent   # id of the parent in the NEW tree (_TREE_UNDEFINED for root)
    bint is_left    # whether the node is its parent's left child

# Copy orig_tree into tree, turning every node flagged in leaves_in_subtree
# into a leaf (its descendants are not copied).
cdef _build_pruned_tree(
    Tree tree, # OUT: receives the pruned copy
    Tree orig_tree, # source tree (only read)
    const unsigned char[:] leaves_in_subtree,
    SIZE_t capacity):
    """Build a pruned tree.

    Build a pruned tree from the original tree by transforming the nodes in
    ``leaves_in_subtree`` into leaves.

    Nodes are visited depth-first from the root, left child before right
    child, and appended to ``tree`` in that order. A node flagged in
    ``leaves_in_subtree`` is added as a leaf and its children are not
    visited, so flags on nodes below it have no effect.

    Each node's value row is copied with ``orig_tree.value_stride`` into
    ``tree.value`` at the same stride, so ``tree`` must have been created with
    the same n_outputs / n_classes layout as ``orig_tree``; otherwise the copy
    writes to the wrong offsets.

    Parameters
    ----------
    tree : Tree
        Location to place the pruned tree
    orig_tree : Tree
        Original tree
    leaves_in_subtree : unsigned char memoryview, shape=(node_count, )
        Boolean mask for leaves to include in subtree
    capacity : SIZE_t
        Number of nodes to initially allocate in pruned tree

    Raises
    ------
    MemoryError
        If ``tree._add_node`` reports failure (returns SIZE_MAX).
    """
    # Pre-allocate; NOTE(review): Tree._add_node presumably grows the arrays
    # if capacity turns out too small - confirm in Tree's implementation.
    tree._resize(capacity)

    cdef:
        SIZE_t orig_node_id
        SIZE_t new_node_id
        SIZE_t depth
        SIZE_t parent
        bint is_left
        bint is_leaf

        # value_stride for original tree and new tree are the same
        SIZE_t value_stride = orig_tree.value_stride
        SIZE_t max_depth_seen = -1
        int rc = 0
        Node* node
        double* orig_value_ptr
        double* new_value_ptr

        stack[BuildPrunedRecord] prune_stack
        BuildPrunedRecord stack_record

    with nogil:
        # push root node onto stack
        prune_stack.push({"start": 0, "depth": 0, "parent": _TREE_UNDEFINED, "is_left": 0})

        while not prune_stack.empty():
            stack_record = prune_stack.top()
            prune_stack.pop()

            orig_node_id = stack_record.start
            depth = stack_record.depth
            parent = stack_record.parent
            is_left = stack_record.is_left

            is_leaf = leaves_in_subtree[orig_node_id]
            node = &orig_tree.nodes[orig_node_id]

            # Copy the split and statistics; the node becomes a leaf in the new
            # tree when is_leaf is set.
            new_node_id = tree._add_node(
                parent, is_left, is_leaf, node.feature, node.threshold,
                node.impurity, node.n_node_samples,
                node.weighted_n_node_samples)

            if new_node_id == SIZE_MAX:
                rc = -1
                break

            # copy value from original tree to new tree
            orig_value_ptr = orig_tree.value + value_stride * orig_node_id
            new_value_ptr = tree.value + value_stride * new_node_id
            memcpy(new_value_ptr, orig_value_ptr, sizeof(double) * value_stride)

            if not is_leaf:
                # Push right child on stack
                prune_stack.push({"start": node.right_child, "depth": depth + 1,
                                  "parent": new_node_id, "is_left": 0})
                # push left child on stack
                # (pushed last so it is popped, and numbered, first)
                prune_stack.push({"start": node.left_child, "depth": depth + 1,
                                  "parent": new_node_id, "is_left": 1})

            if depth > max_depth_seen:
                max_depth_seen = depth

        if rc >= 0:
            tree.max_depth = max_depth_seen
    # Raise outside the nogil section.
    if rc == -1:
        raise MemoryError("pruning tree")

基于上述 CCP 代码,去掉 CCP 的计算过程,改为直接指定需要剪枝的节点,并复用 _build_pruned_tree 方法,实现如下的 prune_tree 方法:

cdef struct PruningRecord:
    # Record type from the first version of prune_tree. The current
    # implementation walks the tree with a stack of plain node ids instead;
    # the struct is kept so any other code referring to it still compiles.
    SIZE_t node_idx
    SIZE_t parent
    SIZE_t leaves

def prune_tree(
    Tree tree, # OUT
    Tree orig_tree,
    list leaves_redoces):
    """Copy ``orig_tree`` into ``tree``, turning the requested nodes into leaves.

    This is the manual counterpart of ``_build_pruned_tree_ccp``. Instead of
    letting cost complexity pruning choose the branches, the caller names the
    nodes to collapse. The leaf mask and node count are computed here, and the
    copy is delegated to ``_build_pruned_tree``.

    Parameters
    ----------
    tree : Tree
        Empty output tree. It must be created with the same ``n_features``,
        ``n_classes`` (class *counts* per output, not class labels) and
        ``n_outputs`` as ``orig_tree``, because ``_build_pruned_tree`` copies
        each node's value row using ``orig_tree.value_stride``.
    orig_tree : Tree
        Fitted source tree; it is only read.
    leaves_redoces : list of int
        Node ids of ``orig_tree`` that become leaves; everything below them is
        dropped. Ids of nodes that already are leaves, or that lie below
        another requested node, are accepted and have no further effect.

    Raises
    ------
    ValueError
        If ``orig_tree`` has no nodes, or a node id is outside
        ``[0, orig_tree.node_count)``.
    """
    cdef:
        SIZE_t n_nodes = orig_tree.node_count
        SIZE_t[:] child_l
        SIZE_t[:] child_r
        # 1 for every node the caller asked to collapse
        unsigned char[:] requested
        # 1 for every node that is a leaf of the pruned tree
        unsigned char[:] leaves_in_subtree
        stack[SIZE_t] node_stack
        SIZE_t node_idx
        # number of nodes kept in the pruned tree (same quantity that
        # _AlphaPruner.after_pruning computes for the CCP path)
        SIZE_t capacity = 0

    if n_nodes == 0:
        raise ValueError("orig_tree has no nodes; fit the estimator first")

    child_l = orig_tree.children_left
    child_r = orig_tree.children_right
    requested = np.zeros(shape=n_nodes, dtype=np.uint8)
    leaves_in_subtree = np.zeros(shape=n_nodes, dtype=np.uint8)

    # Validate explicitly: with Cython's default wraparound a negative id
    # would silently select a node counted from the end of the array.
    for node in leaves_redoces:
        if not 0 <= node < n_nodes:
            raise ValueError(
                "node index %r is out of range for a tree with %d nodes"
                % (node, n_nodes))
        requested[node] = 1

    with nogil:
        # Depth-first walk from the root that stops at requested nodes and at
        # original leaves, so nodes below a collapsed node are never visited
        # and never counted in capacity.
        node_stack.push(0)
        while not node_stack.empty():
            node_idx = node_stack.top()
            node_stack.pop()
            capacity += 1
            if requested[node_idx] or child_l[node_idx] == _TREE_LEAF:
                leaves_in_subtree[node_idx] = 1
            else:
                node_stack.push(child_l[node_idx])
                node_stack.push(child_r[node_idx])

    _build_pruned_tree(tree, orig_tree, leaves_in_subtree, capacity)

由于修改的是 .pyx 文件,需要从源码重新编译生成对应的扩展模块,命令如下:

python setup.py build_ext --inplace

将编译生成的 sklearn\tree\_tree.cp38-win_amd64.pyd 文件复制到当前 Python 环境中,替换同名文件,然后重启 Python 进行测试:

#-*- coding: UTF-8 -*-
'''
Created on 2021年3月4日

@author: xch
'''
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.tree._tree import prune_tree,Tree
import numpy as np
from sklearn.metrics import precision_score,recall_score,accuracy_score,f1_score
import pandas as pd
# Load the Iris data set.
iris = load_iris()

X = iris.data
y = iris.target
feature_names = iris.feature_names
print(f"特征:{iris.feature_names}")

# Create and train a fully grown decision tree classifier.
tree_classifier = DecisionTreeClassifier(random_state=0)
tree_classifier.fit(X, y)

# Evaluate on the training data.
# NOTE: for a single-label multiclass problem, micro-averaged precision,
# recall and F1 all equal accuracy; use average='macro' to see per-class
# differences.
y_predict = tree_classifier.predict(X)
acc = accuracy_score(y, y_predict)
p_score = precision_score(y, y_predict, average='micro')
r_score = recall_score(y, y_predict, average='micro')
f_score = f1_score(y, y_predict, average='micro')
# A space separates every metric (it was missing before "f_score=").
print(f"准确率={acc} 精确度={p_score} 召回度={r_score} f_score={f_score}")
tree.plot_tree(tree_classifier, filled=True)
plt.show()
# Manually prune the decision tree.
def xprune_tree(tree_model, recodes=None):
    """Collapse selected nodes of a fitted DecisionTreeClassifier into leaves, in place.

    Parameters
    ----------
    tree_model : DecisionTreeClassifier
        Fitted classifier. Its ``tree_`` attribute is replaced by the pruned tree.
    recodes : iterable of int, optional
        Node ids of ``tree_model.tree_`` that become leaves. Defaults to ``[7]``.
    """
    if recodes is None:
        recodes = [7]
    # Tree takes the NUMBER of classes per output, not the class labels.
    # Passing ``classes_`` (e.g. [0, 1, 2] for iris) makes the new tree's
    # per-node value layout differ from the original, while the pruning code
    # copies values with the original tree's value_stride. That corrupts the
    # copied values and writes past the new buffer, which breaks
    # export/plotting of the pruned tree.
    n_classes = np.atleast_1d(tree_model.n_classes_).astype(np.intp)
    pruned_tree = Tree(tree_model.n_features_in_, n_classes, tree_model.n_outputs_)
    # prune_tree declares its last parameter as a Cython ``list``.
    prune_tree(pruned_tree, tree_model.tree_, list(recodes))
    tree_model.tree_ = pruned_tree


# Prune the fitted classifier in place (default node selection of xprune_tree).
xprune_tree(tree_classifier)

# Export the pruned tree in Graphviz DOT format, then draw it with matplotlib.
dot_path = "iris.dot"
with open(dot_path, 'w') as dot_file:
    tree.export_graphviz(tree_classifier, out_file=dot_file)
tree.plot_tree(tree_classifier, filled=True)
plt.show()

决策树已按预期完成剪枝,但使用其他工具时无法正常显示,原因尚不清楚。由于不太熟悉 Cython,希望有经验的朋友帮忙看一下。目前剪枝后模型的参数也随之变化,可以正常用于预测和计算准确率,也能导出 PMML,原有决策树的功能都得以保留。

  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值