《统计学习方法》第五章
import math
import operator
def createDataSet():
    """Build the small categorical training table used by the decision-tree demo.

    Returns:
        A ``(dataSet, labels)`` pair. ``dataSet`` is a list of 15 rows, each a
        list of four feature values followed by the class value in the last
        position. ``labels`` lists the four feature names in column order.
    """
    # Column names: age, has a job, owns a house, credit standing.
    feature_names = ["年龄", "有工作", "有自己的房子", "信贷情况"]
    # One tuple per sample; the fifth entry is the class value.
    samples = (
        ("青年", "否", "否", "一般", "否"),
        ("青年", "否", "否", "好", "否"),
        ("青年", "是", "否", "好", "是"),
        ("青年", "是", "是", "一般", "是"),
        ("青年", "否", "否", "一般", "否"),
        ("中年", "否", "否", "一般", "否"),
        ("中年", "否", "否", "好", "否"),
        ("中年", "是", "是", "好", "是"),
        ("中年", "否", "是", "非常好", "是"),
        ("中年", "否", "是", "非常好", "是"),
        ("老年", "否", "是", "非常好", "是"),
        ("老年", "否", "是", "好", "是"),
        ("老年", "是", "否", "好", "是"),
        ("老年", "是", "否", "非常好", "是"),
        ("老年", "否", "否", "一般", "否"),
    )
    # Hand back fresh mutable lists on every call.
    rows = [list(sample) for sample in samples]
    return rows, feature_names
def entropy(dataSet):
    """Compute the base-2 Shannon entropy of the class column of ``dataSet``.

    Args:
        dataSet: List of rows; the last element of each row is its class value.

    Returns:
        The entropy in bits. An empty ``dataSet`` yields ``0``.
    """
    total = len(dataSet)
    # Tally how many rows carry each class value.
    tally = {}
    for row in dataSet:
        klass = row[-1]
        tally[klass] = tally.get(klass, 0) + 1
    # Accumulate -p * log2(p) over the observed class values.
    result = 0
    for count in tally.values():
        share = float(count) / total
        result -= share * math.log(share, 2)
    return result
def majorityCnt(classList):
    """Return the value that occurs most often in ``classList``.

    When several values share the highest count, the one seen first in
    ``classList`` is returned. An empty ``classList`` raises ``IndexError``.
    """
    tally = {}
    for decision in classList:
        tally[decision] = tally.get(decision, 0) + 1
    by_count = operator.itemgetter(1)
    # The sort is stable, so equal counts keep their first-seen order.
    ranked = sorted(tally.items(), key=by_count, reverse=True)
    winner = ranked[0][0]
    return winner
def splitDataSet(dataSet, axis, value):
    """Select the rows whose column ``axis`` equals ``value``, minus that column.

    Args:
        dataSet: List of rows (each row is expected to be a list).
        axis: Index of the column to test and then drop.
        value: Value that column must equal for a row to be kept.

    Returns:
        A new list of new rows; the rows of ``dataSet`` are left untouched.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]
def chooseBestFeat(dataSet):
    """Pick the feature column whose split gives the largest information gain.

    Args:
        dataSet: Non-empty list of rows; every element but the last is a
            feature value and the last is the class value.

    Returns:
        Index of the best feature column, or ``-1`` when no column has a gain
        strictly greater than zero. Earlier columns win ties.
    """
    feature_count = len(dataSet[0]) - 1
    base = entropy(dataSet)
    total = float(len(dataSet))
    winner, top_gain = -1, 0
    for column in range(feature_count):
        observed = [row[column] for row in dataSet]
        # Conditional entropy: subset entropies weighted by subset size.
        conditional = 0
        for value in set(observed):
            subset = splitDataSet(dataSet, column, value)
            conditional += (len(subset) / total) * entropy(subset)
        gain = base - conditional
        if gain > top_gain:
            winner, top_gain = column, gain
    return winner
def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: Non-empty list of rows; each row holds the feature values
            followed by the class value in its last position.
        labels: Feature names, one per feature column of ``dataSet``. The
            sequence is only read; it is never modified.

    Returns:
        A class value when the node is a leaf, otherwise a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # Every sample has the same class: this node is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeat(dataSet)
    # chooseBestFeat returns -1 when no feature has positive information gain.
    # Using -1 as an index would select the class column itself, so stop here
    # with a majority-vote leaf instead.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Build the reduced name list without touching the caller's ``labels``.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = [example[bestFeat] for example in dataSet]
    branches = {}
    for value in set(featValues):
        subset = splitDataSet(dataSet, bestFeat, value)
        branches[value] = createTree(subset, subLabels)
    return {bestFeatLabel: branches}
if __name__ == "__main__":
    # Demo run: build the sample table, grow the tree, then print it.
    dataSet, labels = createDataSet()
    tree = createTree(dataSet, labels)
    print(tree)
得到决策树结果:
{'有自己的房子': {'否': {'有工作': {'否': '否', '是': '是'}}, '是': '是'}}
将结果可视化:
决策树代码参考自:https://blog.csdn.net/csqazwsxedc/article/details/65697652
可视化代码参考自:https://blog.csdn.net/u012421852/article/details/79801466