一.决策树模型
决策树是一种通过对特征属性属性分类对样本进行分类的树形结构,包括边和三类节点
- 根节点:决策树的起源,进行分类的第一个特征属性,只有出边没有入边;
- 内部节点:正在进行分类的特征属性,有一条入边,至少有一条出边;
- 叶节点:分类结束的特征属性,有入边,没有出边;
上面就是一个简单的二叉决策树
二.决策树学习
决策树学习本质就是对样本总结一个分类规则,既能对已知样本进行合理的拟合分类,也能对未知样本进行正确的分类预测,这其中就有两个关键:
- 选择哪个特征属性进行分类?
- 到什么时候停止分类?
三.决策树算法
- 特征属性选择
在ID3中,特征属性的选择是由目标函数决定的,目标函数代表的是特征属性的混乱程度(也就是特征属性越混乱越不好分类,该特征属性的分类顺序越靠后),这个目标函数就是信息增益,信息增益是由熵计算出来的:
熵:Entropy(t)=−∑kp(ck|t)logp(ck|t)
信息增益:
Δ=H(c)−∑i=1nN(ai)NH(c|ai)
=H(c)−∑inp(ai)H(c|ai)
=H(c)−H(c|A)
其中H(c)表示父节点的熵值,H(c|A)表示该父节点下特征属性的熵的加权和
- 决策树生成
1.如果节点满足停止分裂的条件,则将其设为叶节点;
2.如果节点不满足停止分裂的条件,则选择信息增益最大的属性分裂;
3.重复1-2动作直到分类完成
C4.5与ID3的区别就是属性的选择条件为信息增益比
四.算法实现
# -*- coding=utf-8 -*-
"""
author:xuwf
created:2017-2-10
purpose:实现决策树算法C4.5
"""
from numpy import *
'''装载文本'''
def load():
data=loadtxt('d:/myProject/data/destree.txt',delimiter=',',dtype=str)
return data
'''熵计算'''
def calcShan(dataset):
rows=dataset.shape[0]
shannon=0.0 #熵值
dataCount=dataInfo(dataset)
for i in dataCount:
temp=dataCount[i]/float(rows) #计算概率
shannon-=temp*(log2(dataCount[i])-log2(rows)) #计算熵
return shannon
'''统计数据'''
def dataInfo(dataset):
info={}
for i in dataset:
info[i]=info.get(i,0)+1 #计算每个数据出现的次数(dict.get方法获取)
return info
'''最大增益比计算'''
def calcGain(dataset):
rows,cols=dataset.shape
shan=calcShan(dataset[:,cols-1]) #计算总熵
shanPro=zeros(cols-1) #每个属性的熵增益比
for i in range(cols-1):
dataCount=dataInfo(dataset[:,i])
for j in dataCount.keys():
temp=dataCount[j]/float(rows)
shanPro[i]+=temp*calcShan(dataset[dataset[:,i]==j,cols-1])
shanPro[i]=(shan-shanPro[i])/shan
return argsort(-shanPro)[0]
'''计算分支熵'''
def calcShanBranch(dataset,i):
rows,cols=dataset.shape
shanBranch={}
dataCount=dataInfo(dataset[:,i])
for j in dataCount.keys():
shanBranch[j]=calcShan(dataset[dataset[:,i]==j,cols-1])
return shanBranch
'''建树干'''
def buildTree(dataset,label,tree):
sortRatio=calcGain(dataset) #找出熵增益最大的那个下标,从这个下标开始分类
shanBranch=calcShanBranch(dataset,sortRatio) #计算分支的熵,判断是否为同一类
tree={label[sortRatio]:{}}
for i in shanBranch.keys():
dataNew=dataset[dataset[:,sortRatio]==i,:]
dataNew=dataNew[:,dataset[0,:]!=dataset[0,sortRatio]]
labelNew=label[label!=label[sortRatio]]
if shanBranch[i]==0:
#如果分支的熵为0,说明该分支下的类别为同一类,该分支分类完成
tree[label[sortRatio]][i]=dataset[dataset[:,sortRatio]==i,-1][0]
else:
if label.shape[0]==2:
#如果到了最后的一组就返回数据
return tree
else:
#没有到最后一组就继续循环
tree[label[sortRatio]][i]=buildTree(dataNew,labelNew,tree)
return tree
'''主函数'''
def main():
data=load()
label=data[0,:]
data=data[1:,:]
tree={}
tree=buildTree(data,label,tree)
return tree
if __name__=='__main__':
main()
五.画图生成决策树
# -*- coding=utf-8 -*-
"""
author:wfxu
create:2017-02-22
purpose:实现plt画图
"""
import matplotlib.pyplot as plt
import destree
def cyc(tree,x,y,coord,ax):
for i in tree:
x1=x #起点坐标
y1=y
ax.text(x1,y1,i,fontsize=20,va='center',ha='center',bbox=dict(facecolor='red', alpha=0.5)) #起点
x2=x-1 #终点坐标
y2=y-1
for j in tree[i]:
if (x2,y2) in coord:
x2+=1
ax.text((x1+x2)/2.0,(y1+y2)/2.0,j,fontsize=10,va='center',ha='center') #中间点
ax.arrow(x1,y1,x2-x1,y2-y1,width=0.05,length_includes_head=True,overhang=0.3,head_length=0.1)
#print (j,(x1+x2)/2.0,(y1+y2)/2.0),(tree[i][j],x2,y2)
x2+=1
if tree[i][j] in ('yes','no'):
#判断有没有属性点重合的,如果有就向右移动一位
coord.append((x2-1,y2))
ax.text(x2-1,y2,tree[i][j],fontsize=20,va='center',ha='center',bbox=dict(facecolor='blue', alpha=0.5)) #终点
else:
cyc(tree[i][j],x2-1,y2,coord,ax)
def main():
fig=plt.figure()
ax=fig.add_subplot(111)
ax.plot((5,1.5),(3.5,0.5),'b+')
tree=destree.main()
x=3
y=3
coord=[]
cyc(tree,x,y,coord,ax)
fig.show()
if __name__=='__main__':
main()
决策树效果
用plt画的图:
六.总结
1.这次用的数据比较单一,所以最后的结果都是“纯净的”,如果样本空间很大,那么最后的分裂停止条件就应该是一个阈值(比如80%以上为yes就视为分裂完成),这样的结果就是预测结果会有偏差
2.如果预测的结果又偏差,就需要进行剪枝,这里并没有剪枝,希望下次可以拿一个大的样本进行测试
3.这里画图是在坐标轴上画的,看起来很压抑,希望下次有方法可以在白画布上画出决策树
如果大家有什么更好的方法或意见请分享至下面的留言!