决策树的优缺点
优点:
1.计算复杂度不高(对比KNN),顾运算较快
2.结果容易可视化(即书中可视化部分的代码)
3.对缺失值不敏感,能处理不相关特征的数据
4.适合处理数值型和标称型数据(什么是数值型和标称型?:https://www.jianshu.com/p/500c2918723f)
缺点:
1.不支持在线学习。即在新样本导入的时候,需要重建决策树。
2.容易过拟合。但是决策森林可以有效减少过拟合.
决策树的思想
通过不断计算数据集的熵(即数据的无序程度,熵越高越无序)来划分数据集,达到出口后停止分隔,最终得到一棵决策树。
熵的计算公式:见书
决策树的构造流程(递归)
1.递归出口:
所有类标签完全相同,或用完了所有特征,也无法将数据集划分为仅包含唯一类别的的分组
2.步骤一:
遍历特征,找出最佳分隔数据集的特征
3.步骤二:
根据该特征的各种取值,建立子节点。(如当前最佳分隔特征为特征0, 其取值可能为1, 2, 3,则建立三个子节点)
步骤三:
直到递归出口。需要注意的是,如果遇到了“用完了所有特征,也无法将数据集划分为仅包含唯一类别的的分组”的情况,就取出现次数最多的类别。
决策树的构建(代码)
计算香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob*log(prob, 2)
return shannonEnt
该代码块比较容易理解,即计算数据集中每个分类结果出现的概率,再根据计算熵的公式进行计算
划分数据集
由上述 决策树的过程 得知,我们需要划分数据集,根据划分后的熵来确定最佳分隔。以下是划分数据集的代码
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
该函数有三个参数,数据集, 划分数据在数据集中的位置,划分数据的值
通俗理解就是,找到数据集中第axis个特征值为value的部分,并删除第axis个特征,得到一个新的数据集
例如数据集[[1,1,‘yes’], [1,1,‘yes’], [1,0,‘no’],[0,1,‘no’]]
我们运行splitDataSet(dataset, 1, 1)
得到的结果为[[1, ‘yes’], [1, yes], [0, no]]
选择最好的划分方式的函数
def chooseBestFeatureToSplit(dataSet):
numFeature = len(dataSet[0])-1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeature):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)#新数据集长短不一,需要分权求和
infoGain = baseEntropy - newEntropy#信息增益,即熵减
if infoGain > bestInfoGain:# 即新计算出来的熵小于原数据的熵,即无序度减少
bestInfoGain = infoGain
bestFeature = i
return bestFeature
遍历所有特征划分数据集,对每种特征划分后的数据集分权求和,找到熵减最大的特征,即最佳特征。
得到出现次数最多的分类的代码
当决策树遇到第二种递归出口时,就需要找到出现次数最多的分类,将其视为当前分类
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
构建决策树的代码
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
递归构建决策树,用到了上述的各个函数 。
用构建好的决策树模型分类数据集
模型已经训练好.输入数据集即可进行分类。
需要注意的是,我们并不知道属性在数据集中的对应位置。所以要同时输入属性标签,将标签字符串转化为索引。
def classify(tree, featLabels, testVec):
firstStr = list(tree.keys())[0]
secondDict = tree[firstStr]
featIndex = featLabels.index(firstStr) #第一个key在标签中的下标
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
这也是个递归函数。当递归进行到叶子节点时,结束并返回结果。
存取决策树的函数
def storeTree(tree, fileName):
fw = open(fileName, 'wb')
pickle.dump(tree, fw)
fw.close()
def grabTree(fileName):
fr = open(fileName, 'rb')
return pickle.load(fr)
这里要注意,书上的代码为py2,py3进行了一些改动,存取时都需要加上一个参数
具体来说,就是保存决策树和读取决策树。需要注意的是py2中原书中代码确实无误,但是在py3中运行时需要注意,写入数据不能只用‘w’,必须用‘wb’,因为Python3给open函数添加了名为encoding的新参数,而这个新参数的默认值却是‘utf-8’。这样在文件句柄上进行read和write操作时,系统就要求开发者必须传入包含Unicode字符的实例,而不接受包含二进制数据的bytes实例。同样的, 读出数据时也需‘rb’模式。
实战:预测隐形眼镜类型
fr = open("D:\machine learning actual combat\lenses.txt")
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)
print(lensesTree)
lenses.txt为本书提供的数据集,可到github下载
或参考此博客:https://blog.csdn.net/sinat_29957455/article/details/79123394
可视化部分由于作者的偷懒,并没有进行代码的实现
大二小白一枚,欢迎指正也欢迎讨论