hello,上篇文章实现的是svm
本文实现决策树算法。主要是依照周志华那本机器学习的书籍,进行实现。
其中红色部分我未写。因为我在选择属性划分的时候,不可能导致数据集为空。所以这部分对我来说没有必要了
这里是用字典来存储树,开始还准备用list,发现字典要好些,但是我没有画图。如果需要就自己去画吧,我就不画了。
具体代码如下:
# coding=utf-8 import pprint import uniout import math from collections import Counter ''' @author :chenyuqing @mail :chen_yu_qin_g@163.com ''' from numpy import * def load_data(path): ''' :param path:传递路径,返回样例的数据和标签,格式采用矩阵,便于进行矩阵运算 :return: ''' data_set=[] label_set=[] file_object=open(path) for line in file_object.readlines(): lineArr = line.strip().split(',') label_set.append(lineArr[-1]) #最后一列默认为标记 data_set.append(lineArr) #这里提取三个部分:属性集、标签集、数据集 attri_set=array(data_set[0])[1:-1] data_set=array(data_set)[1:,1:-1] label_set=array(label_set)[1:] return data_set,label_set,attri_set def is_same_class(label_set): ''' :param label_set: :return:#判断节点中是否是同一类 ''' if len(unique(label_set))==1: return True return False def info_entropy(label_set): ''' :param data_set: :return:计算信息熵,信息熵越小,纯度越高,全部为一个类型,则信息熵为0 ''' info_entro=0 #设置初始信息熵为0 class_set=unique(label_set) #得到有几类 for item in class_set: tempsult=0 for sample_item in label_set: if(item==sample_item): tempsult=tempsult+1 ratio=float(tempsult)/len(label_set) info_entro =info_entro+ratio*math.log(ratio,2) info_entro=0-info_entro return info_entro def info_gain(data_set,label_set,attri_set): ''' :param data_set: :param label_set: :param attri_set: :return:计算信息增益, 返回信息增益最大的那个属性 ''' data_set=array(data_set) #必须将list转换成array root_info_entro=info_entropy(label_set=label_set) #计算根节点的信息熵 m,n=shape(data_set) result_gain=[] for i in range(n): #对给出的数据,每个特征计算信息增益 i_gain=0 #设定初始的信息增益为0 col_data=data_set[:,i] #提取第i列数据 col_attr=unique(col_data) #一个特征有几个属性 for item_attr in col_attr: label_set_attr=[] #初始化当前的标签结果为空 for j in range(len(col_data)): if(item_attr==col_data[j]): label_set_attr.append(label_set[j]) attr_entropy=info_entropy(label_set_attr) ratio=float(len(label_set_attr))/len(col_data) i_gain=i_gain+ratio*attr_entropy i_gain=root_info_entro-i_gain result_gain.append(i_gain) print result_gain return result_gain.index(max(result_gain)) def is_repeat(data_set): ''' :param data_set: :return:判断数组是否是重复元素 ''' data_set=array(data_set) #必须将list转换成array m,n=shape(data_set) for i in range(n): if len(unique(data_set[:,i]))<>1: return False return True def TreeGenerate(data_set,label_set,attri_set): ''' :param data_set: :param label_set: :param attri_set: :return:返回生成的决策树,采用list来表示决策树 ''' data_set=array(data_set) #必须将list转换成array if(is_same_class(label_set=label_set)): return list(unique(label_set)) #这里转换成list,主要是为了好看 if(len(attri_set)==0 or is_repeat(data_set)):#如果属性集合为空或者数据集合一致,将节点标记为叶子节点,标记为做多的类 return list(Counter(label_set).keys()[0]) attr_num=info_gain(data_set,label_set,attri_set)#得到那一列作为最优划分 col_set=data_set[:,attr_num] col_uni=unique(col_set) #最优划分属性的每个取值都要添加 print "本次选择的划分属性为:%s"%attri_set[attr_num] bestFeatLabel = attri_set[attr_num] myTree = {bestFeatLabel:{}} m,n =shape(data_set) attri_set_temp=attri_set[attri_set<>attri_set[attr_num]] #属性这一列剔除掉划分属性 for item in col_uni: data_set_temp=[] label_set_temp=[] for i in range(m): if(item==data_set[i,attr_num]): data_set_temp.append([data_set[i][k] for k in range(len(data_set[i])) if k!=attr_num ]) #将这个样本放入到下一个数据集中,样本也要把这列删除 label_set_temp.append(label_set[i]) myTree[bestFeatLabel][item]=TreeGenerate(data_set=data_set_temp,label_set=label_set_temp,attri_set=attri_set_temp) return myTree if __name__ == '__main__': print("------------my desion tree-----------") path=u"./西瓜数据集2.0.txt" data_set,label_set,attri_set=load_data(path=path) result=TreeGenerate(data_set=data_set,label_set=label_set,attri_set=attri_set) print("-------------the result tree is ------------") pprint.pprint(result)
西瓜数据集2.0如下:
编号,色泽,根蒂,敲声,纹理,脐部,触感,好瓜 1,青绿,蜷缩,浊响,清晰,凹陷,硬滑,是 2,乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,是 3,乌黑,蜷缩,浊响,清晰,凹陷,硬滑,是 4,青绿,蜷缩,沉闷,清晰,凹陷,硬滑,是 5,浅白,蜷缩,浊响,清晰,凹陷,硬滑,是 6,青绿,稍蜷,浊响,清晰,稍凹,软粘,是 7,乌黑,稍蜷,浊响,稍糊,稍凹,软粘,是 8,乌黑,稍蜷,浊响,清晰,稍凹,硬滑,是 9,乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,否 10,青绿,硬挺,清脆,清晰,平坦,软粘,否 11,浅白,硬挺,清脆,模糊,平坦,硬滑,否 12,浅白,蜷缩,浊响,模糊,平坦,软粘,否 13,青绿,稍蜷,浊响,稍糊,凹陷,硬滑,否 14,浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,否 15,乌黑,稍蜷,浊响,清晰,稍凹,软粘,否 16,浅白,蜷缩,浊响,模糊,平坦,硬滑,否 17,青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,否
得到的结果如下:
-------------the result tree is ------------
{'纹理': {'模糊': ['否'],
'清晰': {'根蒂': {'硬挺': ['否'],
'稍蜷': {'色泽': {'乌黑': {'触感': {'硬滑': ['是'],
'软粘': ['否']}},
'青绿': ['是']}},
'蜷缩': ['是']}},
'稍糊': {'触感': {'硬滑': ['否'],
'软粘': ['是']}}}}
Process finished with exit code 0