决策树
构造决策树
1. 找到决定性特征
2. 如果某个分支下的数据属于同一类,则无需继续分类;如果分支下的数据属于不同类,则重复划分数据集。
如何划分数据集
信息增益:划分数据集前后的信息变化。
信息增益最高的特征就是最好的选择。
香农熵越高,混合的数据也越多。
示例一:区分鱼类和非鱼类
(1) 创建数据集
import numpy as np
def createDataset():
    """Build the toy fish-identification data set.

    Returns:
        samples: list of samples, each [no-surfacing, flippers, class label].
        featureNames: names of the two feature columns, in column order.
    """
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ["no surface", "flippers"]
    return samples, featureNames

myDat,labels = createDataset()
(2)计算给定数据集的香农熵
import math
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy (base 2) of a data set's class labels.

    The class label is taken to be the last element of each sample.

    Args:
        dataSet: list of samples, each ending with its class label.

    Returns:
        Shannon entropy of the label distribution as a float;
        0.0 for an empty data set (original raised ZeroDivisionError).
    """
    from collections import Counter

    total = len(dataSet)
    if total == 0:
        # Guard: an empty data set has no label distribution.
        return 0.0
    labelCounts = Counter(sample[-1] for sample in dataSet)
    entropy = 0.0
    for count in labelCounts.values():
        prob = count / total
        # Accumulate -p*log2(p) per class.
        entropy -= prob * math.log2(prob)
    return entropy
calcShannonEnt(myDat)  # entropy of the 2-yes / 3-no toy set, about 0.971
(3) 划分数据集
def splitDataSet(dataSet, feature, value):
    """Select the samples whose given feature equals `value`.

    feature - index of the feature column used for the split
    value   - required value of that feature

    Returns a new list of samples with the split column removed;
    the input samples are left untouched.
    """
    return [
        sample[:feature] + sample[feature + 1:]
        for sample in dataSet
        if sample[feature] == value
    ]
splitDataSet(myDat,0,0)  # -> [[1, 'no'], [1, 'no']] (column 0 removed)
(4) 选择最好的数据集划分方式
选择划分的best feature:信息增益最高
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Information gain is the base entropy of the whole data set minus the
    weighted entropy of the subsets produced by splitting on a feature.
    Returns -1 when no feature yields a positive gain.
    """
    featureCount = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestGain, bestIndex = 0, -1
    for idx in range(featureCount):
        # Every distinct value of this feature defines one branch.
        distinctValues = {sample[idx] for sample in dataSet}
        splitEntropy = 0.0
        for val in distinctValues:
            subset = splitDataSet(dataSet, idx, val)
            weight = len(subset) / len(dataSet)
            splitEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestGain:
            bestGain, bestIndex = gain, idx
    return bestIndex
chooseBestFeatureToSplit(myDat)  # -> 0 ("no surface" has the highest gain)
(5) 创建树
import operator
def majorityCnt(classList):
    """Return the class label that occurs most often in classList.

    Ties are resolved in favour of the label encountered first
    (same behaviour as a stable descending sort on the counts).
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    # max() returns the first maximal item in insertion order.
    return max(tally.items(), key=operator.itemgetter(1))[0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: list of samples, each ending with its class label.
        labels: feature names for the remaining columns. BUG FIX: the
            original deleted from this list in place (`del labels[bestFeat]`),
            corrupting the caller's label list; a copy is used instead.

    Returns:
        A nested dict {featureName: {featureValue: subtree-or-label}},
        or a bare class label for a pure branch / exhausted features.
    """
    classList = [example[-1] for example in dataSet]
    # All samples share one class: this branch is pure, stop splitting.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the label column remains: fall back to the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Remove the chosen feature from a COPY so the caller's list is intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = {example[bestFeat] for example in dataSet}
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
createTree(myDat,labels)  # build the decision tree for the toy fish data set
示例二:分类器的测试与存储
测试:使用决策树对新数据进行分类。
def classify(inputTree, featLabels, testVec):
    """Classify a feature vector with a trained decision tree.

    Args:
        inputTree: nested dict produced by createTree().
        featLabels: feature names, in the same order as testVec's columns.
        testVec: feature values of the sample to classify.

    Returns:
        The predicted class label, or None when testVec carries a feature
        value the tree has never seen (original raised UnboundLocalError).
    """
    # BUG FIX: the original read the global `myTree` here instead of the
    # `inputTree` parameter, so the argument was silently ignored.
    firstFea = list(inputTree.keys())[0]
    firstIndex = featLabels.index(firstFea)
    secondDict = inputTree[firstFea]
    classLabel = None  # default for feature values with no matching branch
    for key in secondDict.keys():
        if testVec[firstIndex] == key:
            if isinstance(secondDict[key], dict):
                # Internal node: descend into the subtree.
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                # Leaf node: this is the class label.
                classLabel = secondDict[key]
    return classLabel
存储:pickle模块
def storeTree(inputTree, filename):
    """Serialize a decision tree to `filename` with pickle.

    BUG FIX: pickle writes bytes, so the file must be opened in binary
    mode ('wb'); the original text mode ('w') raises TypeError on dump.
    The `with` block guarantees the file is closed even on error.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load a pickled decision tree from `filename`.

    BUG FIX: pickle data is binary, so the file must be opened in 'rb'
    mode; the original text-mode open fails on decode. The `with` block
    also closes the file handle, which the original leaked.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)