A Python Implementation of Classification Decision Trees

This post implements a classification decision tree in Python. It covers the ID3/C4.5 family of algorithms (the entropy criterion below actually uses the C4.5 gain ratio) plus a Gini criterion, along with pre-pruning and post-pruning.

The algorithms and code draw on Machine Learning in Action (《机器学习实战》), Zhou Zhihua's Machine Learning (《机器学习》), and the blog: http://blog.csdn.net/wzmsltw
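
For reference, these are the standard quantities the code below computes, following the notation in Zhou Zhihua's book. For a dataset $D$ with class proportions $p_k$, split by feature $a$ into subsets $D^v$:

$$\mathrm{Ent}(D) = -\sum_k p_k \log_2 p_k$$

$$\mathrm{Gain}(D,a) = \mathrm{Ent}(D) - \sum_v \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$

$$\mathrm{Gain\_ratio}(D,a) = \frac{\mathrm{Gain}(D,a)}{\mathrm{IV}(a)}, \qquad \mathrm{IV}(a) = -\sum_v \frac{|D^v|}{|D|}\log_2\frac{|D^v|}{|D|}$$

$$\mathrm{Gini}(D) = 1 - \sum_k p_k^2$$

ID3 splits on the feature with the largest $\mathrm{Gain}$, C4.5 on the largest $\mathrm{Gain\_ratio}$, and CART on the smallest weighted Gini index. Note that the "entropy" criterion in the code maximizes the gain ratio, so it is C4.5-style rather than plain ID3.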


The main decision tree module, tree.py

# coding: utf-8

from math import log
import json
from plot import createPlot



class DecisionTree():
    
    def __init__(self, criterion="entropy"):
        self.tree = None
        self.criterion = criterion
        
           
        
    def _is_continuous_value(self, a):
        # Check whether a value is continuous (numeric); the type-name
        # test covers both Python and numpy int/float types

        type_name = type(a).__name__.lower()
        return 'float' in type_name or 'int' in type_name
        
        
    def _calc_entropy(self, dataset):
        # Compute the Shannon entropy of the dataset
        # (class labels are assumed to be in the last column)

        classes = dataset.iloc[:, -1]
        total = len(classes)
        cls_count = {}
        for cls in classes:
            if cls not in cls_count:
                cls_count[cls] = 0
            cls_count[cls] += 1
        entropy = 0.0
        for key in cls_count:
            prob = float(cls_count[key]) / total
            entropy -= prob * log(prob, 2)
        return entropy
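
    # Worked example (hypothetical labels): for a last column of
    # ['yes', 'yes', 'no'], p(yes) = 2/3 and p(no) = 1/3, so
    # Ent = -(2/3)*log2(2/3) - (1/3)*log2(1/3) ≈ 0.918 bits.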
    
    def _calc_gini(self, dataset):
        # Compute the Gini index of the dataset
        # (class labels are assumed to be in the last column)

        classes = dataset.iloc[:, -1]
        total = len(classes)
        cls_count = {}
        for cls in classes:
            if cls not in cls_count:
                cls_count[cls] = 0
            cls_count[cls] += 1
        gini = 1.0
        for key in cls_count:
            prob = float(cls_count[key]) / total
            gini -= prob ** 2
        return gini
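
    # Worked example (same hypothetical labels ['yes', 'yes', 'no']):
    # Gini = 1 - (2/3)**2 - (1/3)**2 = 4/9 ≈ 0.444.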
    
        
    def _split_data_category(self, dataset, feature, value):
        # Split on a categorical feature:
        # extract the rows where column `feature` equals `value`,
        # and drop the `feature` column from the result

        ret = dataset[dataset[feature] == value].drop([feature], axis=1)
        return ret
            

    def _split_data_continuous(self, dataset, feature, value):
        # Binary split on a continuous feature; returns both subsets.
        # Note that the `feature` column is NOT dropped, so a continuous
        # feature can be split again further down the tree.

        ret_less = dataset[dataset[feature] <= value]
        ret_greater = dataset[dataset[feature] > value]

        return ret_less, ret_greater
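
    # Example (hypothetical column name): splitting on feature 'density'
    # at value 0.5 returns (rows with density <= 0.5, rows with
    # density > 0.5), with the 'density' column retained in both.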
    
   
    
    def _choose_best_feature_entropy(self,dataset):

        base_entropy = self._calc_entropy(dataset)
        best_info_gain_ratio = 0.0
        best_feature = ''
        best_split_dict={}  
        for col in dataset.columns[:-1]: 
            # For each feature `col`, compute the information gain ratio
            # and keep the feature with the largest one as the split feature
            
            # values of the current feature column
            feature_val_list = dataset[col]
            
            
            # Continuous feature
            if self._is_continuous_value(feature_val_list.iloc[0]):
                # generate n-1 candidate split points: the midpoints
                # of adjacent sorted unique values
                sorted_feature_list = sorted(feature_val_list.unique())
                
                split_list = []
                for j in range(len(sorted_feature_list) - 1):
                    split_list.append((sorted_feature_list[j] + sorted_feature_list[j + 1]) / 2.0)
                                
                best_split_entropy = float('inf')
                best_split = -1
                best_split_info = 0.0
                # for each candidate split point, compute the weighted
                # entropy of the resulting split and keep the best one
                for j in range(len(split_list)):
                    value = split_list[j]  
                    entropy = 0.0  
                    sub_data0,sub_data1 = self._split_data_continuous(dataset,col,value)  
                    prob0 = float(sub_data0.shape[0])/float(dataset.shape[0])  
                    entropy += prob0 * self._calc_entropy(sub_data0)  
                    prob1 = float(sub_data1.shape[0])/float(dataset.shape[0])  
                    entropy += prob1 * self._calc_entropy(sub_data1)    
                    # intrinsic value (split info) of the current binary split
                    split_info = - prob0 * log(prob0,2) - prob1 * log(prob1,2)                   
                    if entropy < best_split_entropy:
                        best_split_entropy = entropy  
                        best_split = j  
                        best_split_info = split_info

                # record the best split point of the current feature
                if len(split_list)>0:
                    best_split_dict[col] = split_list[best_split] 
                else:
                    best_split_dict[col] = None
                     
                info_gain = base_entropy - best_split_entropy 
                # gain ratio at the best split point
                
                info_gain_ratio = info_gain / best_split_info if best_split_info>0 else 0.0
                 
            # Discrete (categorical) feature
            else:
                unique_vals = set(feature_val_list)
                entropy = 0.0
                split_info = 0.0
                for val in unique_vals:
                    sub_data = self._split_data_category(dataset, col, val) 
                    prob = sub_data.shape[0] / float(dataset.shape[0])
                    entropy +=  prob * self._calc_entropy(sub_data)
                    split_info -= prob * log(prob, 2)
                info_gain = base_entropy - entropy
                info_gain_ratio = info_gain / split_info if split_info > 0 else 0.0
                           
            if info_gain_ratio > best_info_gain_ratio:
                best_info_gain_ratio = info_gain_ratio
                best_feature = col
               
               
        
        # guard against the edge case where no feature improves the split
        best_split_value = None
        if best_feature != '' and self._is_continuous_value(dataset[best_feature].iloc[0]):
            best_split_value = best_split_dict[best_feature]

        return best_feature, best_split_value
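
    # Note: classic C4.5 does not maximize the gain ratio directly. It first
    # keeps only the features whose information gain is above average, then
    # picks the largest gain ratio among those (the raw gain ratio favors
    # features with very small split info). The simpler approach above
    # maximizes the gain ratio over all features.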
    
    def _choose_best_feature_gini(self,dataset):

        best_gini = 1.0
        best_feature = ''
        best_split_dict={}  
        for col in dataset.columns[:-1]: 
            # For each feature `col`, compute the weighted Gini index and
            # keep the feature with the smallest one as the split feature
            
            # values of the current feature column
            feature_val_list = dataset[col]
            
            
            # Continuous feature
            if self._is_continuous_value(feature_val_list.iloc[0]):
                # generate n-1 candidate split points: the midpoints
                # of adjacent sorted unique values
                sorted_feature_list = sorted(feature_val_list.unique())
                
                split_list = []
                for j in range(len(sorted_feature_list) - 1):
                    split_list.append((sorted_feature_list[j] + sorted_feature_list[j + 1]) / 2.0)
                                
                best_split_gini = 1.0
                best_split = -1
                # for each candidate split point, compute the weighted
                # Gini index of the resulting split and keep the best one
                for j in range(len(split_list)):
                    value = split_list[j]  
                    _gini = 0.0  
                    sub_data0,sub_data1 = self._split_data_continuous(dataset,col,value)  
                    prob0 = float(sub_data0.shape[0])/float(dataset.shape[0])  
                    _gini += prob0 * self._calc_gini(sub_data0)  
                    prob1 = float(sub_data1.shape[0])/float(dataset.shape[0])  
                    _gini += prob1 * self._calc_gini(sub_data1)    
                    if _gini < best_split_gini:
                        best_split_gini = _gini  
                        best_split = j  
                gini = best_split_gini

                # record the best split point of the current feature
                if len(split_list)>0:
                    best_split_dict[col] = split_list[best_split] 
                else:
                    best_split_dict[col] = None   
                 
            # Discrete (categorical) feature
            else:
                unique_vals = set(feature_val_list)
                gini = 0.0
                for val in unique_vals:
                    sub_data = self._split_data_category(dataset, col, val) 
                    prob = sub_data.shape[0] / float(dataset.shape[0])
                    gini +=  prob * self._calc_gini(sub_data)
                           
            if gini < best_gini:
                best_gini = gini
                best_feature = col
                    
                
               
        
        # guard against the edge case where no feature improves the split
        best_split_value = None
        if best_feature != '' and self._is_continuous_value(dataset[best_feature].iloc[0]):
            best_split_value = best_split_dict[best_feature]

        return best_feature, best_split_value
    
    def _choose_best_feature(self, dataset):
        # Choose the best split feature according to the configured
        # criterion; the default is the information gain ratio

        if self.criterion == 'gini':
            return self._choose_best_feature_gini(dataset)
        # 'entropy' (and any unrecognized value) falls back to the gain ratio
        return self._choose_best_feature_entropy(dataset)
    
    def _class_vote(self, class_list):
        # Majority vote: return the class with the most samples

        class_count = {}
        for vote in class_list:
            if vote not in class_count:
                class_count[vote] = 0
            class_count[vote] += 1

        # max() must compare the counts, not the class labels themselves
        return max(class_count, key=class_count.get)
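
    # Example: _class_vote(['a', 'b', 'b']) counts {'a': 1, 'b': 2}
    # and returns 'b'.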
    
    
    def _testing(self, tree):
        # Accuracy of the given (sub)tree on the held-out test set
        return self._validate(tree, self.test_data)[1]
    
    def _testing_node(self, class_voted):
        # Accuracy on the test set if this node is NOT split,
        # i.e. if it simply predicts the majority class

        correct = 0.0
        nb_samples = self.test_data.shape[0]
        for i in range(nb_samples):
            if class_voted == self.test_data.iloc[i, -1]:
                correct += 1
        return float(correct) / nb_samples
    
    def _testing_node_feature(self, dataset, feature, split_value=None):
        # Accuracy on the test set if this node IS split on `feature`
        # (each branch predicts its own majority class)

        train_data_values = dataset[feature]

        correct = 0.0

        if self._is_continuous_value(train_data_values.iloc[0]):
            sub_datas = self._split_data_continuous(dataset, feature, split_value)
            test_sub_datas = self._split_data_continuous(self.test_data, feature, split_value)

            for i in range(2):
                sub_class_list = sub_datas[i].iloc[:, -1]
                class_voted = self._class_vote(sub_class_list)
                sub_class_list_test = test_sub_datas[i].iloc[:, -1]
                for test_value in sub_class_list_test.values:
                    if test_value == class_voted:
                        correct += 1.0
        else:
            all_unique_vals = set(train_data_values)

            for value in all_unique_vals:
                sub_data = self._split_data_category(dataset, feature, value)
                sub_class_list = sub_data.iloc[:, -1]
                class_voted = self._class_vote(sub_class_list)

                test_sub_data = self._split_data_category(self.test_data, feature, value)
                test_sub_class_list = test_sub_data.iloc[:, -1]

                for test_value in test_sub_class_list.values:
                    if test_value == class_voted:
                        correct += 1.0
        return correct / float(self.test_data.shape[0])
        
    
    
    def create_tree(self, dataset):

        class_list = dataset.iloc[:, -1]
        class_voted = self._class_vote(class_list)

        # All samples belong to the same class: stop splitting
        if len(set(class_list)) == 1:
            return class_voted

        # No features left to split on: return the majority class
        if dataset.shape[1] == 1:
            return class_voted

        best_feature, best_split_value = self._choose_best_feature(dataset)

        # No feature yields any improvement: stop and vote
        if best_feature == '':
            return class_voted

        # Pre-pruning: keep this node as a leaf unless splitting it
        # improves the accuracy on the test set
        if self.prune == 'prev' and \
           self._testing_node(class_voted) >= \
           self._testing_node_feature(dataset, best_feature, best_split_value):
            return class_voted

        my_tree = {'nodes': {best_feature: {}},
                   'node_data': {
                       'entropy': self._calc_entropy(dataset),
                       'gini': self._calc_gini(dataset),
                       'samples': len(dataset),
                       'is_continuous': best_split_value is not None,
                       'split_value': best_split_value
                   }}

        if best_split_value is not None:
            # Continuous feature: two branches, '<=' and '>'
            sub_data_less, sub_data_greater = self._split_data_continuous(dataset, best_feature, best_split_value)
            my_tree['nodes'][best_feature][">"] = self.create_tree(sub_data_greater)
            my_tree['nodes'][best_feature]["<="] = self.create_tree(sub_data_less)

        else:
            # Categorical feature: one branch per value seen in training,
            # plus an '__other__' fallback (the majority class) for values
            # that only appear at prediction time
            feature_values = dataset[best_feature]
            unique_vals = set(feature_values)
            for val in unique_vals:
                sub_data = self._split_data_category(dataset, best_feature, val)
                my_tree['nodes'][best_feature][val] = self.create_tree(sub_data)
            my_tree['nodes'][best_feature]['__other__'] = class_voted

        # Post-pruning: collapse this subtree into a leaf if it does not
        # beat the majority-class prediction on the test set
        if self.prune == 'post':
            if self._testing(my_tree) <= self._testing_node(class_voted):
                return class_voted

        return my_tree
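
    # The tree is a nested dict. A small hypothetical example:
    # {'nodes': {'density': {'<=': 'no',        # leaf: a class label
    #                        '>': {...}}},      # or a nested subtree
    #  'node_data': {'entropy': 0.99, 'gini': 0.49, 'samples': 17,
    #                'is_continuous': True, 'split_value': 0.38}}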
    
    def fit(self, train_data, test_data=None, prune=None):
        # Fit the decision tree
        # @param train_data: training set (pandas DataFrame)
        # @param test_data: test set used for pruning; if None, no pruning
        # @param prune: pruning strategy, 'prev' (pre-pruning),
        #               'post' (post-pruning) or None
        # @return: the tree that was built

        assert type(train_data).__name__ == 'DataFrame', \
            'the training set must be a pandas DataFrame'
        assert (test_data is None) or type(test_data).__name__ == 'DataFrame', \
            'the test set must be None or a pandas DataFrame'

        # test set used for pruning
        self.test_data = test_data
        self.prune = prune

        # pruning requires a test set
        if self.test_data is None:
            self.prune = None
        
        
        # build the tree and return it
        self.tree = self.create_tree(train_data)
        
        return self.tree
    
    def plot(self):
        createPlot(self.tree)
    
    def print_tree(self):
        return  json.dumps(self.tree, ensure_ascii=False, indent=4)
    
    
    def _predict_one(self, tree, vec):

        # a non-dict node is a leaf: return the class label
        if type(tree).__name__ != 'dict':
            return tree

        nodes = tree.get('nodes')
        node_data = tree.get('node_data')
        for feat in nodes:
            if node_data['is_continuous']:
                # continuous feature: follow the '<=' or '>' branch
                split_value = node_data['split_value']
                if vec[feat] <= split_value:
                    sub_tree = nodes[feat]['<=']
                    return self._predict_one(sub_tree, vec)
                else:
                    sub_tree = nodes[feat]['>']
                    return self._predict_one(sub_tree, vec)
            else:
                # categorical feature: follow the matching branch, or fall
                # back to '__other__' (the training-time majority class)
                # for a value never seen in training
                if vec[feat] in nodes[feat]:
                    return self._predict_one(nodes[feat][vec[feat]], vec)
                else:
                    return nodes[feat]['__other__']
                
    
    def predict(self, dataset):

        assert self.tree is not None, 'fit the tree first by calling fit()'

        type_name = type(dataset).__name__
        assert type_name in ['DataFrame', 'Series'], \
            'only pandas DataFrame or Series inputs are supported'

        # collect the predictions in a list
        predict_result = []

        if type_name == 'Series':
            predict_result.append(self._predict_one(self.tree, dataset))
            return predict_result
        
        if type_name == 'DataFrame':
            nb_samples = dataset.shape[0]
            for idx in range(nb_samples):
                vec = dataset.iloc[idx,:]
                predict_result.append(self._predict_one(self.tree, vec))
            
            return predict_result
        
    
    def _validate(self, tree, dataset):

        # number of samples
        nb_samples = dataset.shape[0]

        # no samples: return an empty result
        if nb_samples == 0:
            return [], 0

        # predicted labels
        predict_result = []
        # ground-truth labels
        real_result = list(dataset.iloc[:, -1].values)
        
        # predict sample by sample
        for idx in range(nb_samples):
            vec = dataset.iloc[idx,:]
            predict_result.append(self._predict_one(tree, vec))           
        
        # compute the accuracy
        correct_count = 0
        for i in range(nb_samples):
            if predict_result[i] == real_result[i]:
                correct_count += 1
        
        correct_ratio = float(correct_count) / nb_samples
            
        return predict_result, correct_ratio
        
        
    def validate(self, dataset):

        assert self.tree is not None, 'fit the tree first by calling fit()'

        type_name = type(dataset).__name__
        assert type_name in ['DataFrame'], 'only pandas DataFrame inputs are supported'

        return self._validate(self.tree, dataset)
                
        
              
        
                
        
    
    



The tree-plotting module, plot.py

# coding: utf-8

import matplotlib.pyplot as plt  
from pylab import mpl
  
mpl.rcParams['font.sans-serif'] = ['SimHei']  # default font, so Chinese labels render correctly

mpl.rcParams['axes.unicode_minus'] = False  # render the minus sign '-' correctly in saved figures


decisionNode=dict(boxstyle="sawtooth",fc="0.8")  
leafNode=dict(boxstyle="round4",fc="0.8")  
arrow_args=dict(arrowstyle="<-")  
  
  
# count the leaf nodes of the tree
def getNumLeafs(myTree):  
    numLeafs=0  
    firstStr=list(myTree['nodes'])[0]  
    secondDict=myTree['nodes'][firstStr]
    for key in secondDict.keys():  
        if type(secondDict[key]).__name__=='dict':  
            numLeafs+=getNumLeafs(secondDict[key])  
        else: numLeafs+=1  
    return numLeafs  
  
# compute the maximum depth of the tree
def getTreeDepth(myTree):  
    maxDepth=0  
    firstStr=list(myTree['nodes'])[0]  
    secondDict=myTree['nodes'][firstStr]

    for key in secondDict.keys():  
        if type(secondDict[key]).__name__=='dict':  
            thisDepth=1+getTreeDepth(secondDict[key])  
        else: thisDepth=1  
        if thisDepth>maxDepth:  
            maxDepth=thisDepth  
    return maxDepth  
  
# draw a node
def plotNode(nodeTxt,centerPt,parentPt,nodeType):  
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',  
    xytext=centerPt,textcoords='axes fraction',va="center", ha="center",  
    bbox=nodeType,arrowprops=arrow_args)  
  
# draw the label text on a branch
def plotMidText(cntrPt,parentPt,txtString):  
    lens=len(txtString)  
    xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002  
    yMid=(parentPt[1]+cntrPt[1])/2.0  
    createPlot.ax1.text(xMid,yMid,txtString)  
      
def plotTree(myTree,parentPt,nodeTxt):  
    numLeafs=getNumLeafs(myTree)  
    depth=getTreeDepth(myTree)  
    firstStr=list(myTree['nodes'])[0]  
    cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)  
    plotMidText(cntrPt,parentPt,nodeTxt)  
    
    node_str = firstStr
    node_str += "\nGini:%.2f" % myTree['node_data']['gini']
    node_str += "\nEntropy:%.2f" % myTree['node_data']['entropy']
    node_str += "\nSamples:%d" % myTree['node_data']['samples']
    if myTree['node_data']['is_continuous']:
        node_str += "\nSplitValue:%.2f" % myTree['node_data']['split_value']
    
    plotNode(node_str,cntrPt,parentPt,decisionNode)  
    secondDict=myTree['nodes'][firstStr]  
    plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD  
    for key in secondDict.keys():  
        if type(secondDict[key]).__name__=='dict':  
            plotTree(secondDict[key],cntrPt,str(key))  
        else:  
            plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW  
            plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)  
            plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))  
    plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD  
  
def createPlot(inTree):  
    fig=plt.figure(1,facecolor='white')  
    fig.clf()  
    axprops=dict(xticks=[],yticks=[])  
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)  
    plotTree.totalW=float(getNumLeafs(inTree))  
    plotTree.totalD=float(getTreeDepth(inTree))  
    plotTree.x0ff=-0.5/plotTree.totalW  
    plotTree.y0ff=1.0  
    plotTree(inTree,(0.5,1.0),'')  
    plt.show()  
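
plot.py can also be used on its own: createPlot takes the nested dict that DecisionTree.fit() returns. A minimal sketch of that (my own example, not part of the original post), assuming dataset_xg3.csv is available as in the test script below:

# coding: utf-8

import pandas as pd

from tree import DecisionTree
from plot import createPlot

train_data = pd.read_csv('dataset_xg3.csv')

model = DecisionTree(criterion='entropy')
fitted = model.fit(train_data)  # no test set, so no pruning

# createPlot works directly on the nested dict returned by fit()
createPlot(fitted)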



Example usage and test script

# coding: utf-8

import numpy as np
import pandas as pd

from tree import DecisionTree



dataset = pd.read_csv('dataset_xg3.csv')

# Fixed train/test split: the first 11 samples for training,
# the remaining 6 for pruning/validation
train_sampler = np.array(range(11))
test_sampler  = np.array([11, 12, 13, 14, 15, 16])

train_data = dataset.take(train_sampler)
test_data = dataset.take(test_sampler)



tree = DecisionTree(criterion="gini")

print(tree.fit(train_data,test_data,prune='prev'))

tree.plot()
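
dataset_xg3.csv is the watermelon 3.0 dataset from Zhou Zhihua's book and is not bundled with this post. As a self-contained alternative (my own sketch, not from the original sources), the iris data from sklearn also works, since every feature is continuous and the label sits in the last column, which is the layout DecisionTree expects:

# coding: utf-8

import pandas as pd
from sklearn.datasets import load_iris

from tree import DecisionTree

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target

# shuffle, then hold out the last 30 rows for pruning/validation
# (the 30/120 split is an arbitrary choice for this sketch)
df = df.sample(frac=1, random_state=0).reset_index(drop=True)
train_data = df.iloc[:-30]
test_data = df.iloc[-30:]

tree = DecisionTree(criterion='entropy')
tree.fit(train_data, test_data, prune='post')

predictions, accuracy = tree.validate(test_data)
print('accuracy on the held-out set: %.3f' % accuracy)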




