# -*- coding: utf-8 -*-
"""
Created on Wed Dec 28 09:33:11 2016
@author: ZQ
"""
import numpy as np
def Infor_Ent(data):
    """Return the Shannon entropy (in bits) of the class labels in ``data``.

    Each element of ``data`` is one sample whose last item is its class
    label.  An empty ``data`` gives 0.0.
    """
    total = len(data)
    # Tally the samples per label, keeping first-seen order.
    tally = {}
    for row in data:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * np.log2(p)
    return entropy
def Infor_Gain_Label(data, axis):
    """Return the conditional class entropy for discrete attribute ``axis``.

    This is the subtrahend of the information gain
    ``Gain(D, a) = Ent(D) - sum_v |D^v| / |D| * Ent(D^v)``, where ``D^v``
    holds the rows whose column ``axis`` equals ``v``.

    ``data`` must support ``data[:, axis]`` (a 2-D numpy array without the
    header row); the last item of each row is the class label.
    """
    total = len(data)
    weighted = 0.0
    for value in set(data[:, axis]):
        subset = [row for row in data if row[axis] == value]
        weighted += len(subset) / total * Infor_Ent(subset)
    return weighted
def Infor_Gain_Num(data, axis):
    """Find the best binary split threshold for continuous attribute ``axis``.

    Following the C4.5 bi-partition, the candidate thresholds are the
    midpoints between adjacent *distinct* values after sorting.  Each
    candidate ``t`` splits the rows into ``value <= t`` and ``value > t``,
    which is the same side assignment ``creatTree`` uses (it sends
    ``value > T`` right and everything else left).  The candidate with the
    smallest weighted conditional entropy wins; ties keep the smallest ``t``.

    The numbers are stored as strings in ``data``, so they are converted with
    ``float`` first.

    Args:
        data: 2-D array of samples (no header row); last column is the class.
        axis: column index of the continuous attribute.

    Returns:
        ``(T, v_ent)``: the best threshold and its conditional entropy
        ``sum_side |D_side| / |D| * Ent(D_side)``.  When the column has fewer
        than two distinct values no split exists and ``(0, inf)`` is
        returned, so the gain ``Ent(D) - v_ent`` is ``-inf`` and the
        attribute is never chosen.
    """
    data_count = len(data)
    values = np.array([float(v) for v in data[:, axis]])
    # np.unique returns the distinct values sorted.  Midpoints taken from the
    # unsorted column can miss the true cut points, and duplicate values give
    # thresholds that leave one side empty.
    distinct = np.unique(values)
    best_T = 0
    best_ent = np.inf  # any real split beats this, even with more than 2 classes
    for low, high in zip(distinct[:-1], distinct[1:]):
        t = float((low + high) / 2)
        # `<=` keeps `low` on the left even if rounding makes t == low, so both
        # sides stay non-empty.
        left = values <= t
        n_left = int(left.sum())
        v_ent = (n_left / data_count * Infor_Ent(data[left])
                 + (data_count - n_left) / data_count * Infor_Ent(data[~left]))
        if v_ent < best_ent:
            best_T = t
            best_ent = v_ent
    return best_T, best_ent
def bestFeattosplit(data, continuous=('密度', '含糖量')):
    """Choose the attribute with the largest information gain.

    Each column is classified by its header name, so continuous and discrete
    attributes may appear in any order.  The previous version counted the
    continuous names and then assumed they were the last columns.

    Args:
        data: 2-D array whose first row is the header (attribute names, then
            the class column name); the other rows are samples with the class
            label last.
        continuous: names of continuous attributes (numbers stored as str).
            Every other attribute is treated as discrete.  Defaults to the two
            continuous attributes of the watermelon 3.0 data set.

    Returns:
        ``(bestInfoGain, bestLabel, best_T, bestLabel_i)``: the best gain, the
        attribute name, the split threshold (-1 for a discrete attribute) and
        the column index.  If no attribute beats the initial gain of -1
        (e.g. there are no attributes), ``(-1, '', -1, -1)`` is returned.
    """
    Featlabel = data[0][:-1]
    rows = data[1:]
    ent = Infor_Ent(rows)
    bestLabel = ''
    bestLabel_i = -1
    bestInfoGain = -1
    best_T = -1
    for i, label in enumerate(Featlabel):
        if label in continuous:
            T, v_ent = Infor_Gain_Num(rows, i)
        else:
            T, v_ent = -1, Infor_Gain_Label(rows, i)
        Gain_infor = ent - v_ent
        if Gain_infor > bestInfoGain:
            bestInfoGain = Gain_infor
            bestLabel = label
            best_T = T  # reset to -1 when a discrete attribute wins
            bestLabel_i = i
    return bestInfoGain, bestLabel, best_T, bestLabel_i
def _majority_class(class_list):
    """Return the most frequent label in ``class_list`` (first seen wins ties)."""
    counts = {}
    for label in class_list:
        counts[label] = counts.get(label, 0) + 1
    return max(counts, key=counts.get)


def creatTree(data):
    """Recursively build a decision tree from ``data``.

    Args:
        data: 2-D string array.  The first row is the header (attribute
            names, then the class column name); the other rows are samples
            with the class label last.

    Returns:
        A class label (leaf) or a nested dict ``{attribute: {branch: subtree}}``.
        A discrete attribute branches on each value it takes.  A continuous
        attribute (one for which ``bestFeattosplit`` returns a threshold
        ``T != -1``) branches on ``'<' + str(T)``, which holds the rows with
        value <= T, and on ``'>' + str(T)``.
        The name of each chosen attribute is printed as the tree is built.
    """
    data = np.asarray(data)
    header, rows = data[0], data[1:]
    classList = [row[-1] for row in rows]
    # All samples share one class: return it as a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left, so there is nothing to split on.  The
    # old code went on to split on index -1, i.e. on the class column itself.
    if len(header) == 1:
        return _majority_class(classList)
    _, best_Label, best_T, best_i = bestFeattosplit(data)
    if best_i == -1:
        # No attribute offered a usable split.
        return _majority_class(classList)
    is_continuous = best_T != -1
    if is_continuous:
        above = np.array([float(v) for v in rows[:, best_i]]) > best_T
        if above.all() or not above.any():
            # The threshold leaves one side empty.  Recursing on the same data
            # would never terminate (and an empty side has no class), so stop.
            return _majority_class(classList)
    tree = {best_Label: {}}
    print(best_Label)
    if is_continuous:
        # Unlike a discrete attribute, a continuous attribute may be split
        # again further down (C4.5), so its column is kept in both subsets.
        # Both subsets are strictly smaller, so the recursion terminates.
        tree[best_Label]['<' + str(best_T)] = creatTree(np.vstack([header, rows[~above]]))
        tree[best_Label]['>' + str(best_T)] = creatTree(np.vstack([header, rows[above]]))
    else:
        column = rows[:, best_i]
        for value in set(column):
            subset = np.vstack([header, rows[column == value]])
            # A discrete attribute is used up once split on: drop its column.
            tree[best_Label][value] = creatTree(np.delete(subset, best_i, axis=1))
    return tree
def loadData(path='watermelon3.0.txt', encoding=None):
    """Load a tab-separated data file into a 2-D numpy string array.

    The first field of every line (the sample-number column) is dropped, so
    the first row of the result is the header row of the file.

    Args:
        path: file to read; defaults to the original hard-coded file name.
        encoding: text encoding passed to ``open``; ``None`` keeps the
            platform default, as before.

    Returns:
        ``np.array`` of the remaining fields, one row per non-blank line.
    """
    data = []
    with open(path, encoding=encoding) as f:
        for line in f:
            line = line.strip()
            # Skip blank lines (e.g. a trailing newline at end of file).  They
            # would add an empty row and make np.array fail on ragged input.
            if not line:
                continue
            data.append(line.split('\t')[1:])
    return np.array(data)
if __name__ == '__main__':
    # Build the tree for the watermelon data set and show it; before this
    # change the finished tree was built and then discarded.
    data = loadData()
    tree = creatTree(data)
    print(tree)
中间一些部分参考了《机器学习实战》中决策树这一章节的相关代码。
个人觉得该代码还存在一些问题，希望大家多多指正。
数据如下（保存为 watermelon3.0.txt，各列之间以制表符分隔）：
编号 色泽 根蒂 敲声 纹理 脐部 触感 密度 含糖量 好瓜
1 青绿 蜷缩 浊响 清晰 凹陷 硬滑 0.697 0.46 是
2 乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 0.774 0.376 是
3 乌黑 蜷缩 浊响 清晰 凹陷 硬滑 0.634 0.264 是
4 青绿 蜷缩 沉闷 清晰 凹陷 硬滑 0.608 0.318 是
5 浅白 蜷缩 浊响 清晰 凹陷 硬滑 0.556 0.215 是
6 青绿 稍蜷 浊响 清晰 稍凹 软粘 0.403 0.237 是
7 乌黑 稍蜷 浊响 稍糊 稍凹 软粘 0.481 0.149 是
8 乌黑 稍蜷 浊响 清晰 稍凹 硬滑 0.437 0.211 是
9 乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 0.666 0.091 否
10 青绿 硬挺 清脆 清晰 平坦 软粘 0.243 0.267 否
11 浅白 硬挺 清脆 模糊 平坦 硬滑 0.245 0.057 否
12 浅白 蜷缩 浊响 模糊 平坦 软粘 0.343 0.099 否
13 青绿 稍蜷 浊响 稍糊 凹陷 硬滑 0.639 0.161 否
14 浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 0.657 0.198 否
15 乌黑 稍蜷 浊响 清晰 稍凹 软粘 0.36 0.37 否
16 浅白 蜷缩 浊响 模糊 平坦 硬滑 0.593 0.042 否
17 青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 0.719 0.103 否