参考教程:https://github.com/apachecn/AiLearning/ 讲解的很详细。
下面的篇幅包含我本人在学习期间学习的代码(参考了部分上述教程)
Decision Tree
这也是一个比较简单的机器学习算法,用于分类或回归(此处考虑的是分类)。其原理总结起来就是:根据训练数据构建一个用于分类的多叉树,对于测试数据,直接按照构造的多叉树来走到叶子节点(叶子节点表示类别信息)。
所以现在分为下面几点介绍:
- 结构
决策树是一个类似于流程图的树结构:可以从图中看到,有三个要素:内部节点、分支,叶子节点,分别表示某个 f e a t u r e feature feature、 v a l u e value value、 l a b e l label label。
- 构建算法
每次选取当前最好的特征作为根结点来划分数据集,分支数量为该特征的可取值数量,对于每个分支同样递归处理。
那递归出口是什么呢?
- 情况一:如果当前分支下的所有样本的 l a b e l label label都相同了,那该分支接下来是值为 l a b e l label label的叶节点。
- 情况二:虽然当前分支下所有样本的
l
a
b
e
l
label
label没有一致,但已经没有特征可以继续用于划分了,那该分支接下来是值为
m
_
l
a
b
e
l
m\_label
m_label的叶节点(该分支下所有样本属于
m
_
l
a
b
e
l
m\_label
m_label是最多的)。
好像有点拗口……所以写了一段伪代码解释一下:
create(data):
#情况一
if label(data) is one value:#suppose label(data) denotes the vector that is composed by the labels of samples in data.
return label(data)[0]
#情况二
if data.column==1:
return most(label(data))
#除了情况二和情况三,就递归创建决策树
bestFeature=chooseBestFeature(data)
resultTree={bestFeature:{}}
for value in data[bestFeature]:
subData=splitdata(data,bestFeature,value)
resultTree[bestFeature][value]=create(subData)
所以综合来看,构建决策树只有一个最模糊也是最重要的一个要素:如何选择最好的特征?该算法认为能包含最多信息的特征就是最好的特征,在划分数据集前后信息发生的变化称为信息增益,获得信息增益最高的特征就是最好的选择。
- 信息增益与香农熵
香农熵计算公式:
H = − ∑ i = 1 n p ( x i ) ∗ l o g ( p ( x i ) ) H= -\sum_{i=1}^np(x_i)*log(p(x_i)) H=−i=1∑np(xi)∗log(p(xi))
其中 p ( x i ) p(x_i) p(xi)表示属于分类 i i i的概率。熵越大,随机变量的不确定性就越大。
信息增益则是选取特定特征后,该特征的取值的条件下,信息熵的变化(我个人的理解)。在《统计学习方法》中,也有该对信息增益的定义:
具体实现可以看下面的代码哦~
import numpy as np
import math
import operator
import decisionTreePlot as dtPlot
def ShannonEnt(data):
#计算香农熵的函数
total_num=len(data)
#统计每个label的个数
label_to_num={}
for i in range(total_num):
label=data[i][-1]
label_to_num[label]=label_to_num.get(label,0)+1
result=0
for key in label_to_num:
pkey=label_to_num[key]/total_num #p(xi)
log2pkey=-math.log(pkey,2)#log2 p(xi)
result+=log2pkey*pkey
return result
def splitDataSet(data,index,value):
#划分数据集的函数 取出data中index列的值为value的数据组成一个subdataset
#由于在决策树的算法中,对某个属性进行划分数据集后,该属性不会在后续划分数据集中起到作用
#所以这里就直接将subdataset中该列去除掉
subdataset=[]
for i in data:
if i[index]==value:
vector=i[:index]
vector.extend(i[index+1:])
subdataset.append(vector)
return subdataset
def chooseBestFeature(data):
#选择划分数据集的最佳特征
feature_num=len(data[0])-1#因为最后一列是label,不是特征
maxShannonInc=0
old=ShannonEnt(data)
beatFeature=-1
num_list=len(data)
for i in range(feature_num):
#对每个特征进行划分数据集,然后计算划分前后数据集的香农熵变化程度
feature_list=[ff[i] for ff in data]
feature_list=set(feature_list)
new=0
#计算按照某个特征划分后的熵
#https://blog.csdn.net/It_BeeCoder/article/details/79554388
for value in feature_list:
subDataSet=splitDataSet(data,i,value)
pvalue=len(subDataSet)/num_list
new+=pvalue*ShannonEnt(subDataSet)
if old-new>maxShannonInc:
maxShannonInc=old-new
beatFeature=i
return beatFeature
def createTree(data,label):
#用递归的方法构建决策树
label_list=[ i[-1] for i in data]
flag=0
#如果所有样本的label都相同
for i in label_list:
if i!=label_list[0]:
flag=1
if flag==0:
return label_list[0]
#如果只剩下一个feature了,然而数据集仍然不止一类
if(len(data[0])==1):
#取出现次数最多的Label作为该分支下的Label
label_num={}
for i in label_list:
label_num[i]=label_num.get(i,0)+1
sortlabel=sorted(label_num.items(),key=operator.itemgetter(1),reverse=True)
#对label_num以第二维(num)进行降序
return sortlabel[0][0]
bestFeature=chooseBestFeature(data)
valuelist=[i[bestFeature] for i in data]
retTree={label[bestFeature]:{}}
for value in valuelist:
subdata=splitDataSet(data,bestFeature,value)
retTree[label[bestFeature]][value]=createTree(subdata,label)
return retTree
def classify(tree,label,testdata):
#获取tree的根结点
root = list(tree.keys())[0]
label_index=label.index(root)
cur_value=testdata[label_index]
branches=tree[root]
nexttree=branches[cur_value]
#print(type(nexttree))
if isinstance(nexttree, dict):
label_=classify(nexttree,label,testdata)
else:
label_=nexttree
return label_
if __name__=='__main__':
data=open('data/机器学习/3.DecisionTree/lenses.txt')
#数据集介绍:包含24个样本,每个样本有4个属性和一个label,label在最后一列
data=[i.strip().split('\t') for i in data.readlines()]
#print(splitDataSet(data,2,'no'))
chooseBestFeature(data)
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree=createTree(data,lensesLabels)
print(classify(lensesTree, lensesLabels, data[0]))
print(classify(lensesTree,lensesLabels,data[1]))
dtPlot.createPlot(lensesTree)