numpy实现简易版决策树算法

根据决策树“分而治之”的思想,以 Gini 准则选择最优划分,封装一个决策树分类器,并支持调节两个超参数:树的最大深度(max_depth)和叶子节点的最小样本数(min_samples_leaf)。

代码部分

import numpy as np
from collections import Counter

'''
Encapsulate the decision tree method
author:Evan
'''

class DecisionTreeClassifier:
    """A minimal decision-tree classifier that splits on the Gini criterion.

    Parameters
    ----------
    max_depth : int or None, default 2
        Maximum number of split levels; ``None`` lets the tree grow until
        every node is pure or no admissible split remains.
    min_samples_leaf : int, default 1
        Minimum number of training samples each side of a split must keep.

    Attributes
    ----------
    tree_ : Node or None
        Root of the fitted tree; ``None`` until :meth:`fit` has been called.
    """

    def __init__(self, max_depth=2, min_samples_leaf=1):
        self.tree_ = None
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf

    def fit(self, X, y):
        """Build the tree from training data and return ``self``.

        X is indexed column-wise (``X[:, d]``), so it is converted to an
        array; y must hold one label per row of X.

        Raises
        ------
        ValueError
            If y is empty (there is no label to learn).
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if len(y) == 0:
            raise ValueError('Cannot fit a DecisionTreeClassifier on empty data')
        tree = self.creat_tree(X, y)
        # creat_tree returns None when the root itself should not be split
        # (labels already pure, max_depth < 1, or no admissible split).
        # Without this fallback tree_ stayed None and predict() rejected a
        # model that had in fact been fitted.
        self.tree_ = tree if tree is not None else self._leaf(y)
        return self

    def creat_tree(self, X, y, current_depth=1):
        """Recursively grow the subtree for (X, y) at ``current_depth``.

        Returns a split Node, or None when this node should be a leaf
        (depth limit exceeded, labels already pure, or no split satisfies
        ``min_samples_leaf``). Callers turn a None result into a
        majority-label leaf.
        """
        if self.max_depth is not None and current_depth > self.max_depth:
            return None
        # Stop on a *pure node*. Testing the split's impurity (g == 0) instead
        # would also throw away perfect splits of mixed nodes and collapse
        # them into a single majority leaf.
        if len(set(y)) <= 1:
            return None

        d, v, g = try_split(X, y, self.min_samples_leaf)
        if d == -1:
            return None
        node = Node(d, v, g)

        X_left, X_right, y_left, y_right = cut(X, y, v, d)
        node.children_left = self.creat_tree(X_left, y_left, current_depth + 1)
        if node.children_left is None:
            node.children_left = self._leaf(y_left)

        node.children_right = self.creat_tree(X_right, y_right, current_depth + 1)
        if node.children_right is None:
            node.children_right = self._leaf(y_right)
        return node

    @staticmethod
    def _leaf(y):
        """Return a leaf Node labelled with the most common class in y."""
        return Node(l=Counter(y).most_common(1)[0][0])

    def predict(self, X):
        """Return an array with the predicted label for each sample (row) of X."""
        assert self.tree_ is not None, 'Call the fit() method first'

        return np.array([self._predict(x, self.tree_) for x in X])

    def _predict(self, x, node):
        """Walk from ``node`` down to a leaf following x's split tests; return its label."""
        # Iterative walk: no recursion-depth limit when max_depth is None.
        while node.label is None:
            if x[node.dim] <= node.value:
                node = node.children_left
            else:
                node = node.children_right
        return node.label

def cut(X, y, v, d):
    '''Partition (X, y) on feature column d at threshold v.

    Rows whose feature value is <= v go to the left part; rows whose value
    is > v go to the right part.

    Returns (X_left, X_right, y_left, y_right).
    '''
    feature = X[:, d]
    goes_left = feature <= v
    goes_right = feature > v
    left_X, right_X = X[goes_left], X[goes_right]
    left_y, right_y = y[goes_left], y[goes_right]
    return left_X, right_X, left_y, right_y

def try_split(X, y, min_samples_leaf):
    '''Find the feature/threshold pair with the lowest weighted Gini impurity.

    For each feature column, candidate thresholds are the midpoints between
    consecutive *distinct* sorted values. A candidate is admissible only if
    both sides keep at least ``min_samples_leaf`` samples. A split is scored
    with the size-weighted impurity
    ``(n_left * gini(y_left) + n_right * gini(y_right)) / n``.

    Returns
    -------
    (best_d, best_v, best_g)
        Feature index, threshold and weighted impurity of the best split.
        When no admissible split exists, returns ``(-1, -1, inf)``.
    '''
    n = len(X)
    # inf rather than 1: a finite sentinel silently rejects any admissible
    # split whose score is >= that value.
    best_d, best_v, best_g = -1, -1, float('inf')
    for d in range(X.shape[1]):
        column = X[:, d]
        sorted_values = column[np.argsort(column)]
        for lo, hi in zip(sorted_values[:-1], sorted_values[1:]):
            if lo == hi:
                # Equal neighbours do not define a new threshold.
                continue
            v = (lo + hi) / 2
            _, _, y_left, y_right = cut(X, y, v, d)
            n_left, n_right = len(y_left), len(y_right)
            if n_left < min_samples_leaf or n_right < min_samples_leaf:
                continue
            # Weight each child by its share of samples so that splitting off
            # a tiny pure group is not artificially favoured.
            g = (n_left * gini(y_left) + n_right * gini(y_right)) / n
            if g < best_g:
                best_d, best_v, best_g = d, v, g
    return best_d, best_v, best_g

# define node class
class Node:
    '''A decision-tree node: either an internal split or a leaf.

    Split nodes are built with ``dim`` (feature index), ``value`` (threshold)
    and ``gini`` (score of the split); leaves are built with only ``label``.
    Every field defaults to None, and both children start out empty.
    '''

    def __init__(self, d=None, v=None, g=None, l=None):
        # description of the split test
        self.dim, self.value, self.gini = d, v, g
        # predicted class; None on split nodes
        self.label = l
        # subtrees, attached later by the tree builder
        self.children_left = self.children_right = None

    def __repr__(self):
        fields = (self.dim, self.value, self.gini, self.label)
        return "Node(d={},v={},g={},l={})".format(*fields)

    
# compute gini
def gini(y):
    '''Return the Gini impurity ``1 - sum_k p_k**2`` of the labels in y.

    p_k is the fraction of samples carrying label k. An empty y has impurity
    0.0: there is nothing mixed. The previous version returned 1.0, a value
    no non-empty label set can reach.
    '''
    n = len(y)
    if n == 0:
        return 0.0
    return 1 - sum((count / n) ** 2 for count in Counter(y).values())
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值