思路
类似于医生看病询问问题,通过一个个问题的询问,层层深入,来确定和定位所属的患病类型(模仿人类专家),这也就意味着无法像k-近邻算法一样进行连续型数值计算,只能进行标称型(离散型)或者可以转化为标称型的数值计算。但对比于k邻近算法,决策树可以给出数据的内在含义,数据形式易于理解,理解数据中蕴含的知识信息。
缺点:可能产生过度匹配问题(?)
代码实现
计算香农熵 以求得最佳的分类树策略 信息增益
H = -∑ p(Xi)log2(p(Xi))
from math import log
import operator
# 计算香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
# 如果标签不存在
if currentLabel not in labelCounts.keys():
# 先初始化
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob*log(prob, 2)
return shannonEnt
# 按照指定特征(axis)划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
# 算出最大信息增益的分类方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0])-1 # 获取除标签外可用的特征值列数
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1 # 最好的分类对应的列
for i in range(numFeatures):
featList = [example[i] for example in dataSet] # 获得每一行的第i列,避免在源数据上操作
uniqueVals = set(featList) # 置为唯一
newEntropy = 0.0
for value in uniqueVals: # 某一列的含有的value遍历
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob*calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain): # 信息增益大,就替换
bestInfoGain = infoGain
bestFeature = i
return bestFeature
多数表决
当标签耗尽时依旧无法分类时,选择该数据集中标签数最多的一个标签作为该数据的分类标签:
# 多数表决
# 当特征值都使用完后依然存在不相同的标签,采用多数表决,即哪类标签最多,算哪类
def majorityCnt(dataSet):
classCount = {}
for vote in dataSet: # 此时数据集中只剩下标签了,所有特征值都分类耗尽
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(
classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
构造决策树
利用上述构造的函数,递归的构造决策树
每次选择最好的分类策略
# 创建树
# labels为每一列对应的特征值名称
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet] # 获取标签
if classList.count(classList[0]) == len(classList): # 标签/类别相同
return classList[0] # 返回标签即可,不需要继续判断了
if len(dataSet[0]) == 1: # 特征耗尽但没有分完
return majorityCnt(classList) # 按照标签最多的返回
# 否则说明还有特征值可以进行分类
bestFeat = chooseBestFeatureToSplit(dataSet) # 获得最好的分类特征列
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
# del(labels[bestFeat])
# 获得该列所有的特征值,比如:0,1,然后循环递归构造树
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues) # 使其唯一 集合
for value in uniqueVals:
subLabels = labels[:bestFeat]
subLabels.extend(labels[bestFeat+1:])
myTree[bestFeatLabel][value] = createTree(
splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
执行分类
def classify(inputTree, featLabels, testVec):
# firstStr = inputTree.keys()[0] TypeError: 'dict_keys' object is not subscriptable(可下标)
# 在python3中应该如下
firstStr = list(inputTree.keys())[0]
# print(firstStr + '1')
secondDict = inputTree[firstStr]
# print(featLabels)
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if(key == testVec[featIndex]):
# AttributeError: type object 'dict' has no attribute '_name_'
# if type(secondDict[key])._name_ == 'dict': 这是p2中的用法
# 直接这样写
# Python3中类型对象“ str ”没有“_name_”属性,所以我们需要将属性去掉
if type(secondDict[key]) == dict:
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
序列化树
将树序列化,避免每次分类都需要重新构造树
def storeTree(inputTree, filename):
import pickle
# 使用 fw = open(filename, 'w') 会报错 TypeError: write() argument must be str, not bytes
# 要使用二进制存入读出
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename, 'rb')
return pickle.load(fr)
使用
def createDataSet():
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [
1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
myData, labels = createDataSet()
myTree = createTree(myData, labels)
# print(labels) # 此时的labels在构造决策树时被删掉一部分了 83行
print(myData)
print(120)
print(myTree)
print(classify(myTree, labels, [1, 0]))
storeTree(myTree, 'classFileStorage.txt')
print(grabTree('classFileStorage.txt'))
输出
no为最终结果
体会
该算法就是模仿人类进行一步步判断决策来确定哪一种类型,进行分类,但需要事先构造好的决策树,所以训练数据(经验)对于算法的准确率有很大的因素