Assignment 4
Problem: derive the final decision tree using the ID3 algorithm. A program implementation is acceptable, as is a manual derivation (the manual derivation steps must be shown).
About the ID3 algorithm: the instructor introduced it briefly in class; for a gentle introduction I recommend reading an article on this classification algorithm, the decision-tree ID3 algorithm, which is easy to understand and to get started with.
Solution:
In the training samples above, 6 of the outcomes are yes and 6 are no, so the entropy of the training set is:
$$E(S) = -\frac{6}{12}\log_2\frac{6}{12} - \frac{6}{12}\log_2\frac{6}{12} = 1.0$$
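As a quick numeric check (a throwaway sketch, separate from the assignment code further below):

```python
import math

# Entropy of a 6-yes / 6-no training set; prints 1.0.
p_yes, p_no = 6 / 12, 6 / 12
print(-p_yes * math.log2(p_yes) - p_no * math.log2(p_no))
```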
Next, compute the information gain of each attribute.
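In each table below, the samples are split by the attribute's values and the yes/no outcomes are counted per value; the line under each table reports the weighted (conditional) entropy of the split and the resulting information gain, following the standard ID3 definition:

$$\operatorname{Gain}(S, A) = E(S) - \sum_{v \in \operatorname{Values}(A)} \frac{|S_v|}{|S|}\, E(S_v)$$

where $S_v$ is the subset of samples taking value $v$ for attribute $A$.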
| Alternative available? | # yes | # no |
| --- | --- | --- |
| yes | 2 | 4 |
| no | 4 | 2 |

Conditional entropy: 0.9182; information gain: 0.0818
| Hungry? | # yes | # no |
| --- | --- | --- |
| yes | 5 | 2 |
| no | 1 | 4 |

Conditional entropy: 0.8042; information gain: 0.1958
| Price | # yes | # no |
| --- | --- | --- |
| low | 3 | 4 |
| medium | 2 | 0 |
| expensive | 1 | 2 |

Conditional entropy: 0.8042; information gain: 0.1958
| Restaurant type | # yes | # no |
| --- | --- | --- |
| Chinese | 2 | 2 |
| Italian | 1 | 1 |
| French | 1 | 1 |
| fast food | 2 | 2 |

Conditional entropy: 1.0; information gain: 0.0
| Number of patrons | # yes | # no |
| --- | --- | --- |
| none | 0 | 2 |
| some | 4 | 0 |
| full | 2 | 4 |

Conditional entropy: 0.4591; information gain: 0.5409
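As a worked example for this attribute: the "none" and "some" subsets are pure (entropy 0), so only the "full" subset (2 yes, 4 no) contributes to the conditional entropy:

$$E(S \mid \text{patrons}) = \frac{2}{12}\cdot 0 + \frac{4}{12}\cdot 0 + \frac{6}{12}\left(-\frac{2}{6}\log_2\frac{2}{6} - \frac{4}{6}\log_2\frac{4}{6}\right) \approx 0.4591$$

$$\operatorname{Gain}(S, \text{patrons}) = 1.0 - 0.4591 = 0.5409$$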
| Waiting time (min) | # yes | # no |
| --- | --- | --- |
| 0–10 | 4 | 2 |
| 10–30 | 1 | 1 |
| 30–60 | 1 | 1 |
| >60 | 0 | 2 |

Conditional entropy: 0.7924; information gain: 0.2076
As the tables show, the number-of-patrons attribute has the largest information gain, so it is chosen as the test attribute at the root node; restaurant type has an information gain of 0 and provides no classification information at all. The root split already settles two branches: "none" contains only no examples and "some" contains only yes examples, so only the "full" branch remains impure. Applying the same procedure recursively to each impure leaf node produces the final decision tree.
Implementation:
Language: Python. IDE: PyCharm.
Source code (for reference only; my skills are limited, so please point out any mistakes):
```python
import math as m
import operator


def createDataSet():
    # Columns: choice (alternative available? 1=yes/0=no), hungry (1=yes/0=no),
    # price (0=low/1=medium/2=expensive), types (0-3 encode the four restaurant
    # types), people (0=none/1=some/2=full), waitmin (0: 0-10, 1: 10-30,
    # 2: 30-60, 3: >60 minutes); the last column is the class label.
    dataSet = [[1, 1, 2, 2, 1, 0, 'yes'],
               [1, 1, 0, 0, 2, 2, 'no'],
               [0, 0, 0, 3, 1, 0, 'yes'],
               [1, 1, 0, 0, 2, 1, 'yes'],
               [1, 0, 2, 2, 2, 3, 'no'],
               [0, 1, 1, 1, 1, 0, 'yes'],
               [0, 0, 0, 3, 0, 0, 'no'],
               [0, 1, 1, 0, 1, 0, 'yes'],
               [0, 0, 0, 3, 2, 3, 'no'],
               [1, 1, 2, 1, 2, 1, 'no'],
               [1, 0, 0, 0, 0, 0, 'no'],
               [0, 1, 0, 3, 2, 2, 'yes'],
               ]
    return dataSet


def calcShannonEnt(dataSet):
    # Shannon entropy of the class-label distribution in dataSet.
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    ent = 0.0
    for label in labelCounts:
        prob = float(labelCounts[label]) / numEntries
        ent -= prob * m.log(prob, 2)
    return ent


def splitDataSet(dataSet, axis, value):
    # Rows whose feature `axis` equals `value`, with that column removed.
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    # Pick the feature with the largest information gain (ID3 criterion).
    baseEntropy = calcShannonEnt(dataSet)
    numFeatures = len(dataSet[0]) - 1
    bestInfoGain = 0.0
    bestFeature = 0
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    # Majority vote, used when no features remain but labels still disagree.
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # All samples share one class: return a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the label column is left: return the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy so recursion does not corrupt sibling branches
        subDataSet = splitDataSet(dataSet, bestFeat, value)
        myTree[bestFeatLabel][value] = createTree(subDataSet, subLabels)
    return myTree


if __name__ == '__main__':
    dataSet = createDataSet()
    labels = ['choice', 'hungry', 'price', 'types', 'people', 'waitmin']
    labelsForCreateTree = labels[:]
    Tree = createTree(dataSet, labelsForCreateTree)
    print(Tree)
```
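As a cross-check against the hand calculation, a small helper (my own addition, not part of the assignment code; it assumes the functions above are in scope) prints the information gain of every attribute on the first split:

```python
def infoGain(dataSet, axis):
    # Information gain of splitting dataSet on feature column `axis`.
    base = calcShannonEnt(dataSet)
    cond = 0.0
    for value in set(example[axis] for example in dataSet):
        sub = splitDataSet(dataSet, axis, value)
        cond += len(sub) / float(len(dataSet)) * calcShannonEnt(sub)
    return base - cond


for i, name in enumerate(['choice', 'hungry', 'price', 'types', 'people', 'waitmin']):
    print(name, round(infoGain(createDataSet(), i), 4))
```

Up to rounding in the last digit, these values should match the gain column of the tables above, with people largest at about 0.5409 and types at 0.0; accordingly, print(Tree) yields a nested dict rooted at 'people' whose "none" and "some" branches are immediate leaves.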