"""
决策树 Python实现
决策树无非分为三步:
- 特征选择
- 决策树的生成
- 决策树的剪枝
决策树由一组组的if-then组成,根据计算信息增益来选择最优的特征作为根节点
数据集特征如下:
年龄:0代表青年,1代表中年,2代表老年;
有工作:0代表否,1代表是;
有自己的房子:0代表否,1代表是;
信贷情况:0代表一般,1代表好,2代表非常好;
类别(是否给贷款):no代表否,yes代表是。
"""
from math import log
import operator
"""
Function Description:创建测试数据集
Parameter:None
Return:
dataSet:数据集
labels:标签
Modify:
2019-07-19
"""
def createDataSet():
    """Build the small loan-application data set used to exercise the tree code.

    Every sample has the layout ``[age, has_job, owns_house, credit, label]``:
    age is 0 (young), 1 (middle-aged) or 2 (old); has_job and owns_house are
    0 (no) or 1 (yes); credit is 0 (fair), 1 (good) or 2 (very good); label is
    ``'yes'`` or ``'no'`` (whether the loan is granted).

    Returns:
        dataSet: a newly built list of 15 samples, each a 5-element list.
        labels: a newly built list with the four feature names, in column order.
    """
    # Kept as immutable rows here; callers receive fresh, mutable lists.
    samples = (
        (0, 0, 0, 0, 'no'),
        (0, 0, 0, 1, 'no'),
        (0, 1, 0, 1, 'yes'),
        (0, 1, 1, 0, 'yes'),
        (0, 0, 0, 0, 'no'),
        (1, 0, 0, 0, 'no'),
        (1, 0, 0, 1, 'no'),
        (1, 1, 1, 1, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (1, 0, 1, 2, 'yes'),
        (2, 0, 1, 2, 'yes'),
        (2, 0, 1, 1, 'yes'),
        (2, 1, 0, 1, 'yes'),
        (2, 1, 0, 2, 'yes'),
        (2, 0, 0, 0, 'no'),
    )
    featureNames = ('年龄', '有工作', '有房', '信贷情况')
    return [list(sample) for sample in samples], list(featureNames)
"""
Function Description:计算给定数据集的经验熵
Parameters:
dataSet
Returns:
shannonEnt:经验熵
Modify:
2019-07-19
"""
def calcShannonEnt(dataSet):
    """Compute the empirical (Shannon) entropy of the class labels, in bits.

    Parameters:
        dataSet: sequence of samples; the last element of each sample is
            taken as its class label.

    Returns:
        shannonEnt: the entropy as a float; 0.0 when ``dataSet`` is empty.
    """
    total = len(dataSet)

    # Tally how often each class label occurs, e.g. {'no': 6, 'yes': 9}.
    tally = {}
    for sample in dataSet:
        label = sample[-1]
        tally[label] = tally.get(label, 0) + 1

    # H = -sum(p * log2(p)) over the observed labels.
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(float(p), 2)
    return entropy
"""
Function Description:按给定特征划分数据集
Parameters:
dataSet
axis:划分数据集的特征的索引值
value:划分数据集的特征的值
Returns:
划分后的数据集
Modify:
2019-07-19
"""
def splitDateSet(dataSet, axis, value):
    """Select the samples whose feature ``axis`` equals ``value``.

    Parameters:
        dataSet: sequence of samples (lists).
        axis: index of the feature column to test.
        value: feature value a sample must have to be kept.

    Returns:
        A new list containing a copy of every matching sample with column
        ``axis`` removed. The input samples are not modified.
    """
    matching = (sample for sample in dataSet if sample[axis] == value)
    subset = []
    for sample in matching:
        # Everything before the tested column, then everything after it.
        trimmed = sample[:axis]
        trimmed.extend(sample[axis + 1:])
        subset.append(trimmed)
    return subset
"""
Function Description:选择最好的特征分类
Parameters:
dataSet
Returns:
bestFeature:信息增益最大的特征的索引值
Modify:
2019-07-19
"""
def chooseBestFeatureToSplit(dateSet):
    """Pick the feature whose split yields the largest information gain.

    For every feature column the data set is partitioned by that column's
    distinct values, the weighted entropy of the partitions (the conditional
    entropy) is computed, and the gain is the base entropy minus that value.

    Parameters:
        dateSet: non-empty sequence of samples; the last element of each
            sample is the class label, all preceding elements are features.

    Returns:
        bestFeature: index of the feature with the largest strictly positive
        information gain, or -1 when no feature has a positive gain.

    Raises:
        IndexError: if ``dateSet`` is empty.
    """
    # Number of feature columns (the last column is the class label).
    numFeatures = len(dateSet[0]) - 1
    # Entropy of the data actually passed in, not of any module-level global.
    baseEntropy = calcShannonEnt(dateSet)
    totalCount = float(len(dateSet))
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values of the i-th feature (column i, not column 0).
        uniqueVals = {example[i] for example in dateSet}
        # Conditional entropy of the labels given feature i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDateSet(dateSet, i, value)
            prob = len(subDataSet) / totalCount
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # Strict comparison: on ties the earliest feature is kept.
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
"""
Function Description:统计分类标签中出现最多的元素
Parameters:
classList:分类标签列表
Returns:
sortedClassCount[0][0]:出现次数最多的元素
Modify:
2019-07-19
"""
def majorityCnt(classList):
    """Return the class label that occurs most often in ``classList``.

    Parameters:
        classList: non-empty sequence of hashable class labels.

    Returns:
        The most frequent label. When several labels share the highest
        count, the one that appears first in ``classList`` is returned.

    Raises:
        IndexError: if ``classList`` is empty.
    """
    classCount = {}
    for vote in classList:
        # Start from 0 only the first time a label is seen; resetting it on
        # every vote would leave every label with a count of 1.
        classCount[vote] = classCount.get(vote, 0) + 1
    # Sort by count, descending; the sort is stable, so ties keep first-seen order.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
"""
Function Description:决策树的生成
Parameters:
dataSet:训练数据集
labels:特征标签列表
featLabels:最优特征标签
Returns:
myTree:决策树
Modify:
2019-07-19
"""
def createTree(dataSet, labels, featLabels):
    """Recursively build a decision tree as nested dictionaries.

    Parameters:
        dataSet: non-empty sequence of samples; the last element of each
            sample is the class label, all preceding elements are features.
        labels: feature names, one per feature column of ``dataSet``.
            This sequence is not modified.
        featLabels: output list; the name of every feature chosen for a
            split is appended to it in the order the splits are made.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_name: {feature_value: subtree_or_label, ...}}``.

    Raises:
        IndexError: if ``dataSet`` is empty.
    """
    classList = [example[-1] for example in dataSet]
    # All samples share one class: this branch is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the label column is left: fall back to the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature gives a positive information gain. Indexing with
    # it would pick the last label and split on the class column, so stop
    # here with the majority class instead.
    if bestFeat == -1:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    # Names of the remaining features, built as a new sequence so that the
    # caller's list and the sibling branches all see consistent indices.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVls = {example[bestFeat] for example in dataSet}
    # One subtree per distinct value of the chosen feature.
    for value in uniqueVls:
        myTree[bestFeatLabel][value] = createTree(
            splitDateSet(dataSet, bestFeat, value), subLabels, featLabels)
    return myTree
"""
Function Description:使用决策树进行分类
Parameters:
inputTree:已生成的决策树
featLabels:最优特征标签
testVec:测试数据列表
Returns:
classLabel:分类结果
Modify:
2019-07-20
"""
def classify(inputTree, featLabels, testVec):
    """Classify one sample by walking a nested-dictionary decision tree.

    Parameters:
        inputTree: tree of the form
            ``{feature_name: {feature_value: subtree_or_label, ...}}``.
        featLabels: feature names; the position of a name in this list is
            the position of that feature's value in ``testVec``.
        testVec: feature values of the sample to classify.

    Returns:
        classLabel: the label stored at the leaf reached for ``testVec``.

    Raises:
        ValueError: if a node's feature name is not in ``featLabels``
            (raised by ``list.index``), or if the sample's value for a node's
            feature has no branch in the tree.
    """
    # The single key of a node is the feature it tests.
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    for key, branch in secondDict.items():
        if featValue == key:
            # A dict is an inner node to descend into; anything else is a leaf.
            if isinstance(branch, dict):
                return classify(branch, featLabels, testVec)
            return branch
    # No branch matched: report it instead of failing on an unbound local.
    raise ValueError(
        "no branch for value %r of feature %r in the decision tree"
        % (featValue, firstStr))
if __name__ == '__main__':
    # Build the sample data, grow the tree, then classify one test sample.
    dataSet, featureNames = createDataSet()
    usedFeatures = []
    tree = createTree(dataSet, featureNames, usedFeatures)
    # The test values are read in the order the features appear in usedFeatures.
    verdict = classify(tree, usedFeatures, [0, 1])
    print('放贷' if verdict == 'yes' else "不放贷")