python实现决策树

 决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表某个可能的属性值,而每个叶节点则对应从根节点到该叶节点所经历的路径所表示的对象的值。

详细关于决策树的讨论,请自行google。

一、找到最优分割位置

1、针对样本数据,需要在其不同的维度(d)上根据特定数据(v)进行分割

#X:样本数据
#y:样本属性
#d:维度
#v:分割标准
def cut(X , y , d , v):
    ind_left  = (X[:,d] <= v)
    ind_right = (X[:,d] > v)
    return (X[ind_left] , X[ind_right] , y[ind_left] , y[ind_right])

2、将样本数据排序

sorted_index= np.argsort(X[:,d])

3、找出中间点

v = (X[sorted_index[i] , d ] + X[sorted_index[i+1] , d ]) / 2

4、按照中间点进行分割

X_left , X_right , y_left , y_right = cut(X , y , d , v)

5、计算基尼系数

    gini_cur = gini(y_left , y_right )

6、找到基尼系数最小的分割位置(维度,分割值)

if gini_cur < gini_best :
   best_g = gini_cur
   best_d = d
   best_v = v

二、创建决策树
1、找到原始数据的最优分割点(对于第一次,找的结果是根节点的分割情况)

d , v , g  = try_split(X , y)

2、将找的结果保存在结点Node中

node = Node(d,v,g)

3、根据最优点将数据分割

X_left , X_right , y_left , y_right = cut(X , y , d , v)

4、递归查找下一个结点

node.child_left  = create_tree(X , y)
node.child_right = create_tree(X , y)

最后对上述过程汇总:
1、实现计算基尼系数
这里写图片描述

from collections import Counter

#y:样本数据的标签
def gini(y):
    counter = Counter(y)
    result = 0
    for v in counter.values():
        result += (v / len(y))**2
    return (1 - result )

2、根据维度(d)和值(v)对数据进行分割

#X:样本数据
#y:样本数据的标签
#d:维度
#v:分割数据
def cut(X , y , d , v):
    ind_left  = (X[:,d] <= v)
    ind_right = (X[:,d] > v)
    return (X[ind_left] , X[ind_right] , y[ind_left] , y[ind_right])

3、查找最优分割点

import numpy as np
#X:样本数据
#y:样本数据的标签
def try_split(X , y):
    best_g = 1
    best_d = -1
    best_v = -1
    for d in range(X.shape[1]):
        #将数据排序
        sorted_index = np.argsort(X[:,d])
        for i in range(len(X) - 1):
            if (X[sorted_index[i],d] == X[sorted_index[i + 1],d]):
                continue
            #计算两点之间的平均值
            v = (X[sorted_index[i],d] + X[sorted_index[i + 1],d]) / 2
            #根据d  v将X  y分割
            X_left , X_right , y_left , y_right = cut(X , y , d , v)
            #计算基尼系数
            gini_cur = gini(y_left) + gini(y_right)
            #计算最优分割点
            if gini_cur < best_g:
                best_g = gini_cur
                best_v = v
                best_d = d
      return (best_d,best_v,best_g)

4、结点,保存分割信息

class Node():
    def __init__(self,d=None,v=None,g=None,l=None):
        self.dim   = d
        self.value = v
        self.gini  = g
        self.label = l

        self.child_left = None
        self.child_rignt = None
    def __repr__(self):
        return 'Node(d={},v={},g={})'.format(self.dim,self.value,self.gini)

5、创建决策树

def create_tree(X , y):
    #查找最优分割点
    d,v,g = try_split(X , y)

    #不用再分
    if (d==-1) or (g==0):
        return None

    #实例化结点
    node = Node(d,v,g)

    #按照最优点把数据分割
    X_left , X_right , y_left , y_right = cut(X , y , d , v)

    #递归子结点(左)
    node.child_left = create_tree(X_left , y_left)
    #左边分割完了,保存label
    if node.child_left is None:
        #label
        label = Counter(y_left).most_common(1)[0][0]
        node.label = Node(l = label)

    #递归子结点(右)
    node.child_right = create_tree(X_left , y_left)
    #右边分割完了,保存label
    if node.child_right is None:
        #label
        label = Counter(y_right).most_common(1)[0][0]
        node.label = Node(l = label)

   return node

6、绘制决策树
这里写图片描述

def show_tree(node):
    if node is None:
        return ''

    result += '{} [label="{}"]\n'.format(id(node),node)
    if node.child_left is not None:
        result += '{} [label="{}"]\n'.format(id(node.child_left),node.child_left)
        result += '{}->{}\n'.format(id(node),id(node.child_left))
        result += show_tree(node.child_left)

    if node.child_right is not None:
        result += '{} [label="{}"]\n'.format(id(node.child_right),node.child_right)
        result += '{}->{}\n'.format(id(node),id(node.child_right))
        result += show_tree(node.child_right)
    return result
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值