欲与天公试比高:决策树算法及实现

恭贺我国神舟十二号载人飞船成功发射!数风流人物,还看今朝!!!

在这里插入图片描述

在这里插入图片描述

emsp;老规矩,决策树的数学推导部分再次不在赘述,因为决策树是一种非常直观的算法。其中需要注意的地方是特征选择,涉及到了信息增益。算法的描述如下图:
在这里插入图片描述
 但是在具体的实现过程中是有一内内难度的。关键是在于数据集的处理,我们仍用Mnist数据集的话,它有28*28=784个特征,但是选择一个特征后,我们如何对数据进行划分?比如选择了第一个特征,此特征下的数据取值为0~255,那我们如何划分?二分?还是三分?四分?或者第一个特征二分,第二个特征三分,或者其他的自由组合。所以这是一个挠头问题,极有可能发生过拟合。本次我们姑且选择二分,即大于128的化为一类,小于128的化为一类。

import numpy as np
from math import log
import time
def make_data_set(file):
    f=open(file)
    data_set=[]
    for line in f:
        value=line.split(',')
        tempt_list=list(map(int,value))
        data_set.append(tempt_list)
    f.close()
    return data_set
def tran_data(data_set):
    for line in data_set:
        for i in range(len(line)):
            if i==0:
                continue
                #因为第一个值是数据的标签
            else:
                if line[i]>=128:
                    line[i]=1
                    #二分
                else:
                    line[i]=0
def pre_cut(train_data_set,test_data_set):
    train=np.array(train_data_set)
    test=np.array(test_data_set)
    for i in range(train.shape[1]):
        if i==0:
            continue
        temp=set(train[:,i])
        if len(temp)==1:
            #对特征进行初步处理,因为如果测试集中某一特征值全为同一值,那么这个特征值对数据的划分是
            #无任何价值的,直接删去该列特征值即可
            train=np.delete(train,i,1)
            test=np.delete(test,i,1)
        return train.tolist(),test.tolist()
def cal_shann(train_data_set):
    pro=[0,0,0,0,0,0,0,0,0,0]
    entropy=0.0
    if len(train_data_set)==0:
        return 0
    for line in train_data_set:
        pro[line[0]]+=1
    for i in range(10):
        prob=pro[i]/len(train_data_set)
        if prob!=0:
            entropy-=prob*log(prob,2)
    return entropy
def split_data_set(train_data_set,axis,value):
    #axix是第几个特征,value是该特征对应的特征值
    split_data=[]
    for line in train_data_set:
        if line[axis]==value:
            reduced=line[:axis]
            reduced.extend(line[axis+1:])
            split_data.append(reduced)
    return split_data
    #注意:其可能返回空的列表,因为可能数据集中某一项特征值均为1或0,导致if后的语句无法执行
    
def chose_best_feature(data_set):
    base_entropy=cal_shann(data_set)
    best_info=0
    for i in range(len(data_set[0])):
        if i==0:
            continue
            #因为第一个值是数据的标签
        new_entropy=0
        for value in range(2):
            sub_data_set=split_data_set(data_set,i,value)
            if len(data_set)==0:
                continue
            prob=len(sub_data_set)/len(data_set)
            new_entropy+=cal_shann(sub_data_set)
            info=base_entropy-new_entropy
            if info > best_info:
                best_info=info
                best_feature=i
    return best_feature,best_info
    #返回最优的特征,和对应的信息增益
def get_label(label_list):
    number=[0,0,0,0,0,0,0,0,0,0]
    for  i in label_list:
        number[i]+=1
    return number.index(max(number))
    #投票表决,哪一类别的最多,则该集合数据就属于那一类
def creat_tree(data_set):
    if len(data_set)==0:
        print(233)
        return 
    ep = 0.1
    #设置信息增益的阈值
    class_list=[line[0] for line in data_set]
    if class_list.count(class_list[0])==len(class_list):
        return class_list[0]
        #说明该列表中均是同一类别的数据,则返回该类别
    if len(data_set[0])==1:
        return get_label(class_list)
        #特征选完了,只剩标签了,则投票表决
    best_feature,best_info=chose_best_feature(data_set)
    if best_info < ep:
        return get_label(class_list)
    my_tree={best_feature:{}}
    #建立字典树
    if len(split_data_set(data_set, best_feature, 0))!=0:      
        my_tree[best_feature][0] = creat_tree(split_data_set(data_set, best_feature, 0))
    if len(split_data_set(data_set, best_feature, 1))!=0:
        my_tree[best_feature][1] = creat_tree(split_data_set(data_set, best_feature, 1))
    return my_tree
def predict(test_data_per_row,tree):
    #注意是切出测试数据集中的一行来预测
    while True:
        (key,value),=tree.items()
        if type(tree[key]).__name__ == 'dict':
            feature_data=test_data_per_row[key]
            #得到该特征所对应的特征值
            del test_data_per_row[key]
            if feature_data not in value:
                return 11
                #如果该特征值不在测试集生成的树中,那么直接判为0~9之外的数,相当于直接判为错误
            else:
                tree = value[feature_data]
            #选择该特征值对应的子树
            if type(tree).__name__ == 'int':
                return tree
                #返回分类值
        else:
            return value
            #返回分类值

def cal_accuracy(test_data_set,tree):
    cnt=0
    for line in test_data_set:
        if line[0]==predict(line,tree):
            cnt+=1
    return cnt/len(test_data_set)
    
start=time.time()
train_data_set=make_data_set('D:\\bpnetwork\\mnist_train.csv')
test_data_set=make_data_set('D:\\bpnetwork\\mnist_test.csv')  
tran_data(train_data_set)
tran_data(test_data_set)
train_data_set,test_data_set=pre_cut(train_data_set,test_data_set)
tree=creat_tree(train_data_set)
accuracy=cal_accuracy(test_data_set,tree)
print('the accuracy :',accuracy)
print('the time',time.time()-start)

在这里插入图片描述
 debug好几天,感觉怪怪的,该算法还有极大的改进空间,留着后面慢慢优化。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值