原理:
- 决策树其实就是一个if-else集合,通过一系列的分支将原始数据划分成最后的若干类。
- 每个节点都是一个数据集合,对这个数据集合选择一个最优的划分特征,根据该特征来对集合进行划分,生成若干子集合。
- 选择最优特征用到了信息学中香农熵的概念,其中,信息增益=原始熵-划分之后的熵。
- 递归地构建决策树。
注意点:
- 本文所用的决策树要求数据集为Python中的列表变量。
- 每一条实例数据的最后一列元素为当前实例的类别标签。
- 最后生成的决策树为字典变量。
- 代码基于《机器学习实战》一书,本文只实现了决策树的构建,没有考虑剪枝,实际会存在过拟合的问题。
代码:
from math import *
import operator
def createDataSet():
    """Build the toy fish-identification data set from *Machine Learning in Action*.

    Returns:
        tuple: (dataSet, labels) — each row of ``dataSet`` ends with its
        class label; ``labels`` names the two feature columns.
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ['no surfacing', 'flippers']
    return samples, featureNames
def calShannonEnt(dataSet):
    """Compute the Shannon entropy (in bits) of *dataSet*.

    The last element of each instance is taken as its class label.

    Args:
        dataSet: list of instances; the final column is the class label.

    Returns:
        float: entropy in bits; 0.0 for an empty data set (the original
        divided by zero in that case).
    """
    numEntries = len(dataSet)
    if numEntries == 0:  # guard against ZeroDivisionError on empty input
        return 0.0
    # Count occurrences of each class label; dict.get avoids the
    # membership-test-then-insert dance of the original.
    labelCounts = {}
    for featVec in dataSet:
        label = featVec[-1]
        labelCounts[label] = labelCounts.get(label, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count * 1.0 / numEntries  # * 1.0 keeps true division on Py2 too
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
def splitDataSet(dataSet, axis, value):
    """Select the rows whose feature at index *axis* equals *value*.

    The matching rows are returned with that feature column removed,
    so the subset is ready for the next level of splitting.

    Args:
        dataSet: list of instances (last column = class label).
        axis: index of the feature to filter on.
        value: feature value a row must have to be kept.

    Returns:
        list: new rows (originals are not modified).
    """
    return [row[:axis] + row[axis + 1:]
            for row in dataSet
            if row[axis] == value]
def chooseBestFeatureToSplit(dataSet):
    """Pick the feature giving the largest information gain.

    Information gain = base entropy - weighted entropy after splitting.

    Args:
        dataSet: list of instances; last column is the class label.

    Returns:
        int: index of the best feature, or -1 if no split improves
        on the base entropy.
    """
    featureCount = len(dataSet[0]) - 1
    baseEntropy = calShannonEnt(dataSet)
    bestGain, bestFeature = 0.0, -1
    for featIdx in range(featureCount):
        distinctValues = {row[featIdx] for row in dataSet}
        splitEntropy = 0.0
        for val in distinctValues:
            subset = splitDataSet(dataSet, featIdx, val)
            weight = len(subset) * 1.0 / len(dataSet)
            splitEntropy += weight * calShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestGain:
            bestGain, bestFeature = gain, featIdx
    return bestFeature
def mayorityCnt(classList):
    """Return the most common class label in *classList* (majority vote).

    Bug fix: the original called dict.iteritems(), which was removed in
    Python 3; dict.items() works on both Python 2 and 3.

    Args:
        classList: non-empty list of class labels.

    Returns:
        The label with the highest count.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    Bug fix: the original did ``del(labels[bestFeat])`` on the caller's
    list, corrupting it for later use (e.g. by classify()). We now copy
    *labels* first, so the caller's list is left intact.

    Args:
        dataSet: list of instances; last column is the class label.
        labels: feature names aligned with the feature columns.

    Returns:
        dict: nested tree {featureName: {featureValue: subtree-or-leaf}},
        or a bare class label when the node is a leaf.
    """
    classList = [row[-1] for row in dataSet]
    # Leaf: every instance has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Leaf: no features left to split on -> majority vote.
    if len(dataSet[0]) == 1:
        return mayorityCnt(classList)
    labels = labels[:]  # work on a copy; never mutate the caller's list
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]  # sub-labels must align with the reduced rows
    featValues = set(row[bestFeat] for row in dataSet)
    for value in featValues:
        subLabels = labels[:]  # fresh copy per branch
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
# Demo: build the tree from the toy data set.
# Bug fix: featLabels must be a COPY -- the book's createTree deletes
# entries from the list it is given, so aliasing (featLabels = labels)
# left featLabels corrupted and classify() crashed with ValueError.
# print(...) is also valid Python 2, so this stays 2/3 compatible.
myData, labels = createDataSet()
featLabels = labels[:]
print(myData)
tree = createTree(myData, labels)
print(tree)
def classify(tree, labels, testVec):
    """Classify *testVec* by walking the decision tree.

    Bug fixes: ``tree.keys()[0]`` fails on Python 3 (dict views are not
    indexable) -- ``next(iter(tree))`` works on both 2 and 3; and the
    original raised UnboundLocalError when the feature value had no
    branch in the tree -- we now return None for that case.

    Args:
        tree: nested dict decision tree.
        labels: feature names, indexed the same way as testVec.
        testVec: feature values of the instance to classify.

    Returns:
        The predicted class label, or None if testVec's value has no
        matching branch.
    """
    firstStr = next(iter(tree))  # root node's feature name
    secondDict = tree[firstStr]
    featIndex = labels.index(firstStr)
    classLabel = None  # explicit "no matching branch" result
    for key in secondDict:
        if testVec[featIndex] == key:
            subtree = secondDict[key]
            if isinstance(subtree, dict):  # internal node: recurse
                classLabel = classify(subtree, labels, testVec)
            else:  # leaf: the class label itself
                classLabel = subtree
    return classLabel
# Demo: classify one test instance with the tree built above.
# print(...) form is valid on Python 2 and required on Python 3.
testVec = [1, 1]
print(testVec)
print(classify(tree, featLabels, testVec))
def storeTree(tree, fileName):
    """Serialize *tree* to *fileName* with pickle.

    Bug fixes: pickle requires a binary-mode file on Python 3, so open
    with 'wb' (harmless on Python 2); the context manager guarantees the
    file is closed even if pickle.dump raises.

    Args:
        tree: the (picklable) decision-tree dict to persist.
        fileName: path of the file to write.
    """
    import pickle
    with open(fileName, 'wb') as fw:
        pickle.dump(tree, fw)
def grabTree(fileName):
    """Load and return a pickled decision tree from *fileName*.

    Bug fixes: pickle requires a binary-mode file on Python 3, so open
    with 'rb'; the original never closed the file handle -- the context
    manager closes it deterministically.

    Args:
        fileName: path of a file previously written by storeTree().

    Returns:
        The unpickled tree object.
    """
    import pickle
    with open(fileName, 'rb') as fr:
        return pickle.load(fr)
# Demo: round-trip the tree through a pickle file.
# print(...) form is valid on Python 2 and required on Python 3.
storeTree(tree, 'hahaha.txt')
print(grabTree('hahaha.txt'))