数据集:
ID3算法:
ID3算法是以信息熵和信息增益为衡量标准,从而实现对数据的归纳分类的一种算法。
首先,ID3算法需要解决的问题是如何选择特征作为划分数据集的标准。在ID3算法中,选择信息增益最大的属性作为当前的特征对数据集分类。
其次,ID3算法需要解决的问题是如何判断划分的结束。分为两种情况,第一种为划分出来的类属于同一个类,第二种为已经没有属性可供再分了。此时就结束了。
通过递归的方式,得到ID3决策树模型,它是局部最优的且仅输出单棵决策树,并近似地偏好“最短的”决策树。
信息熵与信息增益:
信息熵(information entropy):
信息熵是度量样本集合纯度最常用的一种指标。假定当前的样本集合D中第k类样本所占比例为 $p_k\ (k=1,2,\dots,|\mathcal{Y}|)$,则D的信息熵定义为:$\mathrm{Ent}(D)=-\sum_{k=1}^{|\mathcal{Y}|}p_k\log_2 p_k$
由信息熵定义可知熵越大,训练集D中的样本的类别越不纯。
信息增益(information gain):
假定离散属性a有V个可能的取值 $\{a^1,a^2,\dots,a^V\}$,若使用a来对样本集D进行划分,则会产生V个分支结点,其中第v个分支结点包含D中所有在属性a上取值为 $a^v$ 的样本,记为 $D^v$。于是可计算出用属性a对样本集D进行划分所获得的“信息增益”,公式如下:$\mathrm{Gain}(D,a)=\mathrm{Ent}(D)-\sum_{v=1}^{V}\frac{|D^v|}{|D|}\mathrm{Ent}(D^v)$
由信息增益的公式可知,信息增益越大,意味着使用属性a来进行划分所获得的“纯度提升”越大。
算法流程:
创建决策树的根节点root
若所有的样本均为正例,返回单个根节点root并标记为“+”
若所有的样本均为反例,返回单个根节点root并标记为“-”
若属性值为空,返回单个根节点root并标记为大部分样本的标签
其他情况下:
选取当前例子下最大信息增益的属性A作为分类属性,令A指向根节点
For each A中的属性值Vi
增加一个关于Vi的分支
For each符合A=Vi的样本ExampleVi
若ExampleVi为空
投票选出满足其父节点训练集的大部分标签
否则
递归产生决策树
代码:
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 5 13:36:19 2018
@author: 安颖
"""
import numpy as np
from math import log
import draw_tree
# Attribute-value tables; Attri is index-aligned with the data columns 0..3.
Outlook = ["Sunny","Overcast","Rain"]
Temperature = ["Hot","Mild","Cool"]
Humidity = ["High","Normal"]
Wind = ["Strong","Weak"]
Attri = [Outlook, Temperature, Humidity, Wind]
# Data set: comma-separated rows "Outlook,Temperature,Humidity,Wind,label".
my_data = []
with open('data.txt', 'r') as data_txt:
    # iterate the file lazily; `with` closes it, so the old explicit
    # data_txt.close() after the block was redundant and is removed
    for line in data_txt:
        if not line.strip():
            continue  # tolerate trailing blank lines
        temp = line.strip().split(',')
        # int() validates the label; np.array re-stringifies every cell,
        # so downstream code compares labels as '0'/'1' strings
        my_data.append([temp[0], temp[1], temp[2], temp[3], int(temp[4])])
my_data = np.array(my_data)
#Count, per attribute value, how many samples carry each class label.
def count_label(data, attrs=None):
    """Return nested counters: result[j][k] == [n_positive, n_negative]
    for the k-th value of attribute j.

    data -- rows whose columns 0..len(attrs)-1 are attribute values and
            column 4 is the class label (1 / '1' positive)
    attrs -- attribute-value table; defaults to the module-level Attri
    """
    if attrs is None:
        attrs = Attri
    # Build the counter structure from the value table instead of
    # hard-coding its shape (the original literal duplicated Attri and
    # would silently break if an attribute gained a value).
    counts = [[[0, 0] for _ in values] for values in attrs]
    for row in data:
        for j, values in enumerate(attrs):
            for k, value in enumerate(values):
                if row[j] == value:
                    # column 4 holds the class label; '1' is positive
                    if int(row[4]) == 1:
                        counts[j][k][0] += 1
                    else:
                        counts[j][k][1] += 1
                    break
    return counts
#Information gain of one attribute over a sample set.
def cal_gain(data, label, targe_attr, attrs=None):
    """Return Gain(D, a) = Ent(D) - sum_v (|D^v|/|D|) * Ent(D^v).

    data -- rows whose column 4 is the class label ('1' / '0' strings)
    label -- per-value [positive, negative] counts as built by count_label
    targe_attr -- index of the attribute to evaluate
    attrs -- attribute-value table; defaults to the module-level Attri
    """
    if attrs is None:
        attrs = Attri
    total = len(data)
    classList = [example[4] for example in data]
    # Entropy of the whole set; one loop replaces the two duplicated
    # if-blocks of the original.
    ent_s = 0.0
    for cls in ('0', '1'):
        cnt = classList.count(cls)
        if total != 0 and cnt != 0:
            prob = cnt / total
            ent_s -= prob * log(prob, 2)
    # Subtract the weighted entropy of each value's subset.
    gain = ent_s
    for i in range(len(attrs[targe_attr])):
        pos, neg = label[targe_attr][i]
        subtotal = pos + neg  # was named `sum`, shadowing the builtin
        if subtotal == 0:
            continue  # no sample has this value -> contributes nothing
        ent = 0.0
        for cnt in (pos, neg):
            if cnt != 0:  # convention: 0 * log(0) == 0
                prob = cnt / subtotal
                ent -= prob * log(prob, 2)
        gain -= ent * (subtotal / total)
    return gain
#Choose the unused attribute with the largest information gain.
def choose_attr(data, label, attr):
    """Return the index of the best remaining attribute, or -1 when every
    attribute is already marked 'selectedAtt'.

    data -- current sample set
    label -- counts from count_label(data)
    attr -- attribute-name list; used entries hold 'selectedAtt'
    """
    gains = [0.0] * len(label)  # was named `max`, shadowing the builtin
    for i in range(len(label)):
        if attr[i] != 'selectedAtt':
            gains[i] = cal_gain(data, label, i)
        else:
            # used attributes sort below any real gain (gains are >= 0)
            gains[i] = -1
    order = np.argsort(gains)
    # BUG FIX: take the last (largest) element instead of the hard-coded
    # order[3], which silently assumed exactly four attributes.
    best = order[-1]
    if attr[best] == 'selectedAtt':
        return -1
    else:
        return best
#Check whether any attribute is still available for splitting.
def pdnull(attr):
    """Return 1 if at least one entry of attr is not yet consumed
    (i.e. not 'selectedAtt'), otherwise 0."""
    for name in attr:
        if name != 'selectedAtt':
            return 1
    return 0
#Majority vote over a list of class labels.
def vote(classList):
    """Return '1' only when positives strictly outnumber negatives;
    ties and empty lists yield '0'."""
    return '1' if classList.count('1') > classList.count('0') else '0'
#Recursive ID3 tree construction.
def id3_tree(examples,targe_attr,attr):
    """Recursively build an ID3 decision tree.

    examples -- rows whose column 4 is the class label ('1'/'0' strings)
    targe_attr -- index of the attribute to split on at this node
    attr -- attribute-name list; consumed entries become 'selectedAtt'
    Returns a leaf label string ('1'/'0') or a nested dict
    {attribute_name: {attribute_value: subtree_or_leaf}}.
    """
    # class labels of the current sample set
    classList=[example[4] for example in examples]
    # all samples positive -> positive leaf
    if classList.count('1')==len(classList):
        myTree = '1'
        return myTree
    # all samples negative -> negative leaf
    if classList.count('0')==len(classList):
        myTree = '0'
        return myTree
    # no attributes left -> majority-vote leaf
    # NOTE(review): attr never shrinks (entries are overwritten with
    # 'selectedAtt'), so this test never fires; pdnull() looks like the
    # intended check here — confirm.
    if len(attr)==0:
        myTree = vote(classList)
        return myTree
    # otherwise split on targe_attr and recurse per attribute value
    else:
        AttributeLabel = attr[targe_attr]
        myTree={AttributeLabel:{}}
        # value of the split attribute for every sample
        featValues=[example[targe_attr] for example in examples]
        for i in range(len(Attri[targe_attr])):
            attr_value = Attri[targe_attr][i]
            # no sample has this value -> majority label of the parent set
            if featValues.count(attr_value)==0 :
                myTree[AttributeLabel][attr_value] = vote(classList)
            else:
                # subset of samples whose split attribute equals this value
                sub_examples=[]
                for j in range(len(examples)):
                    if featValues[j]==Attri[targe_attr][i]:
                        sub_examples.append(examples[j])
                # copy before marking, so sibling branches still see the
                # attribute as available
                sub_attr = attr[:]
                sub_attr[targe_attr] = 'selectedAtt'
                sub_label = count_label(sub_examples)
                sub_targe_attr = choose_attr(sub_examples,sub_label,sub_attr)
                # NOTE(review): when all attributes are used up (-1) this
                # value gets no child at all; presumably it should fall
                # back to vote(classList) — confirm intended behavior.
                if sub_targe_attr != -1:
                    myTree[AttributeLabel][attr_value] = id3_tree(sub_examples,sub_targe_attr,sub_attr)
        return myTree
if __name__ == '__main__':
    # Attribute names, index-aligned with the module-level Attri table.
    Att = ["Outlook","Temperature","Humidity","Wind"]
    # Build and plot a tree for the original data set.
    label = count_label(my_data)
    targe_attr = choose_attr(my_data,label,Att)
    tree = id3_tree(my_data,targe_attr,Att)
    print(tree)
    draw_tree.createPlot(tree)
    # Test 2: append one new sample and rebuild the tree.
    data = ['Rain','Cool','Normal','Strong','0']
    print("加入新样本"+str(data)+"后:")
    new_data = np.row_stack((my_data, data))
    new_label = count_label(new_data)
    targe_attr = choose_attr(new_data,new_label,Att)
    # BUG FIX: build the second tree from new_data; the original passed
    # my_data here, so the newly added sample was silently ignored.
    new_tree = id3_tree(new_data,targe_attr,Att)
    print(new_tree)
    draw_tree.createPlot(new_tree)
画决策树:
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 5 11:59:17 2018
@author: 安颖
"""
import matplotlib.pyplot as plt
# Shared matplotlib styles: box shapes for internal/leaf nodes and the
# arrow drawn from child back toward parent ("<-").
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
#Count the leaf nodes of a nested-dict decision tree.
def getNumLeafs(myTree):
    """Return the number of leaves; a child that is not a dict is a leaf."""
    root = list(myTree.keys())[0]
    total = 0
    for child in myTree[root].values():
        # exact-type check mirrors the original type(...).__name__ test
        total += getNumLeafs(child) if type(child) is dict else 1
    return total
#Maximum depth of a nested-dict decision tree.
def getTreeDepth(myTree):
    """Return the number of decision levels below (and including) the root."""
    root = list(myTree.keys())[0]
    best = 0
    for child in myTree[root].values():
        # exact-type check mirrors the original type(...).__name__ test
        depth = 1 + getTreeDepth(child) if type(child) is dict else 1
        best = depth if depth > best else best
    return best
#Draw one annotated node with an arrow from its parent.
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    """Annotate the shared axes (createPlot.ax1) with nodeTxt at centerPt,
    arrowed from parentPt; nodeType selects the box style."""
    axes = createPlot.ax1
    axes.annotate(nodeTxt,
                  xy=parentPt, xycoords='axes fraction',
                  xytext=centerPt, textcoords='axes fraction',
                  va="center", ha="center",
                  bbox=nodeType, arrowprops=arrow_args)
#Label the edge between a parent node and a child node.
def plotMidText(cntrPt,parentPt,txtString):
    """Write txtString near the midpoint of the parent-child edge,
    nudged left proportionally to the text length."""
    x_mid = (parentPt[0] + cntrPt[0]) / 2.0 - len(txtString) * 0.002
    y_mid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(x_mid, y_mid, txtString)
def plotTree(myTree,parentPt,nodeTxt):
    """Recursively draw myTree below parentPt; nodeTxt labels the edge in.

    Layout state lives on function attributes set by createPlot:
    totalW/totalD (leaf count / depth of the whole tree) and x0ff/y0ff
    (the moving plot cursor, mutated during recursion).
    """
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)  # NOTE(review): computed but never used
    firstSides = list(myTree.keys())
    firstStr=firstSides[0]
    # center this node horizontally over the span of its leaves
    cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    # descend one layer before drawing children
    plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            # internal node: recurse
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            # leaf: advance the x cursor, draw the label and its edge text
            plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
            plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
    # restore the y cursor on the way back up
    plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD
def createPlot(inTree):
    """Render the decision tree inTree in a fresh matplotlib figure.

    Initialises the layout state consumed by plotTree (totalW/totalD,
    x0ff/y0ff) and exposes the axes as createPlot.ax1 for the helpers.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # whole-tree extent drives the per-node spacing
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # start the cursor half a slot left of the first leaf, at the top
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()