The complete program for the ID3 algorithm:
from math import log
import operator

def createDataSet():
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    features = ['no surfacing', 'flippers']
    return dataSet, features

# Compute the Shannon entropy of a data set
def calcShannonEnt(dataset):
    numEntries = len(dataset)              # number of records
    labelCounts = {}                       # record count per class label
    for featVec in dataset:
        currentLabel = featVec[-1]         # the class label is the last column
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        if prob != 0:
            shannonEnt -= prob * log(prob, 2)
    return shannonEnt

# Return the records whose feature `feat` equals `values`, with that column removed
def splitDataSet(dataset, feat, values):
    retDataSet = []
    for featVec in dataset:
        if featVec[feat] == values:
            reducedFeatVec = featVec[:feat]
            reducedFeatVec.extend(featVec[feat + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

# Majority vote: assign a leaf to the class with the most records
def classify(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Find the best attribute to split on
def findBestSplit(dataset):
    numFeatures = len(dataset[0]) - 1      # number of features
    baseEntropy = calcShannonEnt(dataset)  # entropy before splitting
    bestInfoGain = 0.0
    bestFeat = -1
    for i in range(numFeatures):
        featValues = [example[i] for example in dataset]
        uniqueFeatValues = set(featValues)
        newEntropy = 0.0
        for val in uniqueFeatValues:
            subDataSet = splitDataSet(dataset, i, val)
            prob = len(subDataSet) / float(len(dataset))
            newEntropy += prob * calcShannonEnt(subDataSet)
        if (baseEntropy - newEntropy) > bestInfoGain:  # keep the largest information gain
            bestInfoGain = baseEntropy - newEntropy
            bestFeat = i
    return bestFeat

def treeGrowth(dataSet, features):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # all records share one class
        return classList[0]
    if len(dataSet[0]) == 1:               # no features left: majority vote decides
        return classify(classList)
    bestFeat = findBestSplit(dataSet)
    bestFeatLabel = features[bestFeat]
    myTree = {bestFeatLabel: {}}
    featValues = [example[bestFeat] for example in dataSet]
    uniqueFeatValues = set(featValues)
    del features[bestFeat]
    for values in uniqueFeatValues:
        subDataSet = splitDataSet(dataSet, bestFeat, values)
        myTree[bestFeatLabel][values] = treeGrowth(subDataSet, features)
    print(myTree)
    features.insert(bestFeat, bestFeatLabel)
    return myTree
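To run the listing end to end, a minimal driver (not part of the original program; added here only for illustration) is enough:

if __name__ == '__main__':
    dataSet, features = createDataSet()
    myTree = treeGrowth(dataSet, features)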
In the program above, calcShannonEnt() computes the entropy of a data set. The larger the entropy, the less regular the data; the smaller the entropy, the more structure the data exhibits. The formula is:
H(t) = -∑ᵢ p(i|t)·log₂ p(i|t) (the entropy formula, summed over the classes i at node t)
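As a quick sanity check (a hand computation on the sample data from createDataSet(), not part of the original text): the data set has 2 'yes' and 3 'no' records, so

H = -(2/5)·log₂(2/5) - (3/5)·log₂(3/5) ≈ 0.529 + 0.442 ≈ 0.971

which is exactly what calcShannonEnt(dataSet) returns (0.9709...).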
findBestSplit() then looks for the split with the largest information gain:
Δ = I(parent) - ∑ᵥ (N(v)/N)·I(v) (the information-gain formula, where I(·) is the entropy, N is the number of records at the parent node, and N(v) is the number at child node v)
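Worked through on the sample data (hand-computed, rounded figures, not from the original text): splitting on 'no surfacing' gives Δ = 0.971 - (3/5·0.918 + 2/5·0) ≈ 0.420, while splitting on 'flippers' gives Δ = 0.971 - (4/5·1.0 + 1/5·0) ≈ 0.171, so findBestSplit() picks feature 0:

dataSet, features = createDataSet()
print(findBestSplit(dataSet))   # prints 0, i.e. 'no surfacing'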
The leaf-labeling function classify() assigns a leaf node to the class holding the majority of its records:
leaf.label = argmaxᵢ p(i|t)
(where p(i|t) is the fraction of training records at node t that belong to class i)
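A quick usage example (made-up input, not from the original text):

print(classify(['yes', 'no', 'no']))   # prints 'no', the majority class at this leaf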
The recursive call used to grow the decision tree:
myTree[bestFeatLabel][values] = treeGrowth(subDataSet,features)
Printing myTree with print() on each return traces the recursion, showing subtrees being assembled bottom-up:
{'no surfacing': {0: 'no'}}
{'flippers': {0: 'no'}}
{'flippers': {0: 'no', 1: 'yes'}}
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
The final result is:
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
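The nested dictionary reads as: test 'no surfacing' first, and on value 1 test 'flippers'. A small traversal helper (hypothetical; no predict() appears in the program above) shows how this structure classifies a record:

def predict(tree, features, sample):
    # hypothetical helper: walk the nested dict until a class label is reached
    if not isinstance(tree, dict):
        return tree
    featLabel = next(iter(tree))                   # feature tested at this node
    featValue = sample[features.index(featLabel)]  # that feature's value in the sample
    return predict(tree[featLabel][featValue], features, sample)

# e.g. predict(myTree, ['no surfacing', 'flippers'], [1, 0]) returns 'no'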
The ID3 algorithm can be summarized in the following steps:
- For every attribute not yet used, compute the entropy of the split it induces on the samples
- Select the attribute with the smallest resulting entropy (equivalently, the largest information gain)
- Generate a tree node that splits on this attribute
Advantages: it is simple to implement; the resulting rules, drawn as a diagram, are clear and easy to understand; and it classifies well.
Disadvantages: it only handles discrete attribute values (the attributes in this program are binary; continuous attributes require C4.5), and when selecting the best attribute to split on it tends to favor attributes with many distinct values.
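That last bias is easy to demonstrate (a made-up experiment, not in the original text): give every record a unique 'id' attribute and information gain will always prefer it, since every one-record subset has zero entropy:

dataSet, features = createDataSet()
for i, record in enumerate(dataSet):
    record.insert(0, i)            # prepend a unique id to each record
features.insert(0, 'id')
print(findBestSplit(dataSet))      # prints 0: the 'id' column wins with gain 0.971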