# tree.py
from math import log
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
import pickle
"""对数据集进行标注:
年龄:0-青年,1-中年,2-老年
有工作:0-否,1-是
有自己的房子:0-否,1-是
信贷情况:0-一般,1-好,2-非常好
类别(是否给贷款):no-否,yes-是
"""
"""
函数说明:创建数据集
Parameters:
无
returns:
dataset - 数据集
labels - 标签
"""
def creatdataset():
    """Build the toy loan-application dataset.

    Returns:
        samples       - list of rows [age, has_job, owns_house, credit, label]
        feature_names - names (in Chinese) for the four feature columns
    """
    samples = [
        [0, 0, 0, 0, 'no'],
        [0, 0, 0, 1, 'no'],
        [0, 1, 0, 1, 'yes'],
        [0, 1, 1, 0, 'yes'],
        [0, 0, 0, 0, 'no'],
        [1, 0, 0, 0, 'no'],
        [1, 0, 0, 1, 'no'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [2, 0, 1, 2, 'yes'],
        [2, 0, 1, 1, 'yes'],
        [2, 1, 0, 1, 'yes'],
        [2, 1, 0, 2, 'yes'],
        [2, 0, 0, 0, 'no'],
    ]
    feature_names = ['年龄', '有工作', '有自己的房子', '信贷情况']
    return samples, feature_names
"""
函数说明:计算给定数据集的经验熵(香农熵)H(D)
Parameters:
dataset - 数据集
Returns:
shannonEnt - 经验熵
"""
def calcShannonEnt(dataset):
    """Compute the empirical (Shannon) entropy H(D) of the class labels.

    Parameters:
        dataset - samples whose last column is the class label
    Returns:
        the entropy -sum(p * log2(p)) over the label distribution
    """
    total = len(dataset)
    counts = {}
    for row in dataset:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    return -sum((c / total) * log(c / total, 2) for c in counts.values())
"""
函数说明:按照给定特征划分数据集(获取各个特征的子集)
Parameters:
dataSet - 待划分的数据集
axis - 划分数据集的特征
value - 需要返回的特征的值
Returns:
retDataset - 划分后的子数据集(去掉 axis 列)
"""
def splitDataSet(dataset, axis, value):
    """Select the rows whose feature `axis` equals `value`, dropping that column.

    Parameters:
        dataset - samples to partition
        axis    - index of the feature column to split on
        value   - feature value the kept rows must match
    Returns:
        the matching rows, each with column `axis` removed
    """
    return [row[:axis] + row[axis + 1:] for row in dataset if row[axis] == value]
"""
函数说明:选择最优特征
Parameters:
dataSet - 数据集
Returns:
bestFeature - 最优特征
"""
def chooseBestFeatureToSplit(dataset):
    """Pick the feature index with the largest information gain (ID3 criterion).

    Parameters:
        dataset - samples whose last column is the class label
    Returns:
        the best feature's column index, or -1 when no feature yields
        a positive information gain
    """
    base_entropy = calcShannonEnt(dataset)
    feature_count = len(dataset[0]) - 1
    best_index = -1
    best_gain = 0.0
    for idx in range(feature_count):
        # conditional entropy H(D|feature idx), weighted over its values
        values = {row[idx] for row in dataset}
        conditional_entropy = 0.0
        for val in values:
            subset = splitDataSet(dataset, idx, val)
            weight = len(subset) / float(len(dataset))
            conditional_entropy += weight * calcShannonEnt(subset)
        gain = base_entropy - conditional_entropy
        if gain > best_gain:
            best_gain = gain
            best_index = idx
    return best_index
"""
函数说明:统计classList中出现次数最多的元素(类标签)
Parameters:
classList - 类标签
Returns:
sortedClassCount[0][0] - 出现次数最多的元素(类标签)
"""
def majorityCnt(classList):
    """Return the most frequent class label.

    Ties are broken in favour of the label seen first, matching the
    stable-sort behaviour of the original implementation.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    # max scans in insertion order, so the first label reaching the top count wins
    return max(tally, key=tally.get)
"""
函数说明:创建决策树
Parameters:
dataSet - 训练数据集
labels - 分类属性标签
featlabels - 存储选择的最优特征
Returns:
mytree - 决策树
"""
# 递归终止条件:1.数据集中所有样本均属于同一类;
# 2.所有特征均遍历完,此处选择数据集中样本数最多是作为特征
def createTree(dataset,labels,featlabels):
    """Recursively build an ID3 decision tree as nested dicts.

    Parameters:
        dataset    - training samples; last column is the class label
        labels     - feature names for the remaining columns. NOTE: this list
                     is MUTATED — the chosen feature is deleted at each level.
        featlabels - output accumulator; the chosen split features are
                     appended in selection order (used later by classify)
    Returns:
        mytree - nested dict {feature_name: {feature_value: subtree_or_label}},
                 or a bare class label for a leaf
    """
    classList = [example[-1] for example in dataset]
    ## recursion termination
    # case 1: every sample has the same class -> leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # case 2: all features consumed -> majority class becomes the leaf
    if (len(dataset[0])==1) or (len(labels)==0):
        return majorityCnt(classList)
    # split on the feature with the highest information gain
    bestfeat = chooseBestFeatureToSplit(dataset)
    bestfeature = labels[bestfeat]
    mytree = {bestfeature:{}}
    featlabels.append(bestfeature)
    del (labels[bestfeat])
    bestfeaturelist = [vector[bestfeat] for vector in dataset]
    unique_bestfeature = set(bestfeaturelist)
    for value in unique_bestfeature:
        # copy so each sibling branch recurses with the same remaining labels
        sublabels = labels[:]
        mytree[bestfeature][value] = createTree(splitDataSet(dataset,bestfeat,value),sublabels,featlabels)
    return mytree
"""
函数说明:获取决策树叶子结点的数目
Parameters:
myTree - 决策树
Returns:
numLeafs - 决策树的叶子结点的数目
"""
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    Parameters:
        myTree - {feature_name: {value: subtree_or_label}}
    Returns:
        the number of leaves (non-dict children) in the tree
    """
    root = next(iter(myTree))
    total = 0
    for child in myTree[root].values():
        # exact-type check, mirroring the original's type-name comparison
        total += getNumLeafs(child) if type(child) is dict else 1
    return total
"""
函数说明:获取决策树的层数
Parameters:
myTree - 决策树
Returns:
maxDepth - 决策树的层数
"""
def getTreeDepth(mytree):
    """Return the number of decision levels in the tree.

    Parameters:
        mytree - {feature_name: {value: subtree_or_label}}
    Returns:
        the maximum depth: 1 for a tree whose children are all leaves,
        1 + child depth otherwise (0 if the root has no branches)
    """
    root = next(iter(mytree))
    branch_depths = [
        1 + getTreeDepth(child) if type(child) is dict else 1
        for child in mytree[root].values()
    ]
    # default=0 matches the original's initial maxDepth when there are no branches
    return max(branch_depths, default=0)
"""
函数说明:存储决策树
Parameters:
inputTree - 决策树
filename - 存储文件
Returns:
无
"""
def storeTree(inputTree, filename):
    """Serialize the decision tree to `filename` with pickle.

    Parameters:
        inputTree - the nested-dict decision tree to persist
        filename  - path of the (binary) output file
    """
    fw = open(filename, 'wb')
    try:
        pickle.dump(inputTree, fw)
    finally:
        fw.close()
# Train once at import time and persist the tree so main.py (which does
# `from tree import *`) can reload it and reuse `filename` / `featlabels`.
dataset,labels = creatdataset()
featlabels = []  # filled by createTree with the split features, in order
mytree = createTree(dataset,labels,featlabels)
filename = 'classifierStorage.txt'  # pickle output (binary despite .txt suffix)
storeTree(mytree,filename)
# main.py
from tree import *
"""
函数说明:使用决策树进行分类
Parameters:
inputTree - 已经生成的决策树
featlabels - 存储选择的最优特征标签
testVec - 测试数据列表,顺序对应最优特征标签
Returns:
classLabel - 分类结果
"""
def classify(inputTree, featlabels, testVec):
    """Classify `testVec` by walking the decision tree.

    Parameters:
        inputTree  - nested-dict decision tree {feature_name: {value: subtree_or_label}}
        featlabels - feature names in the order createTree selected them
        testVec    - feature values indexed consistently with featlabels
    Returns:
        the predicted class label, or None when the tree has no branch
        for the sample's value of the current feature
    Raises:
        ValueError - if the root feature name is missing from featlabels
    """
    firststr = next(iter(inputTree))
    secondDict = inputTree[firststr]
    featIndex = featlabels.index(firststr)
    # Bug fix: the original left `classlabel` unbound (UnboundLocalError)
    # when no branch matched testVec[featIndex]; default to None instead.
    classlabel = None
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classlabel = classify(secondDict[key], featlabels, testVec)
            else:
                classlabel = secondDict[key]
            break  # dict keys are unique, so the first match is the only one
    return classlabel
"""
函数说明:加载决策树
Parameters:
filename - 文件
Returns:
pickle.load(fr) - 决策树字典
"""
def getTree(filename):
    """Load a pickled decision tree from `filename`.

    Parameters:
        filename - path of the pickle file written by storeTree
    Returns:
        the decision-tree dict stored in the file

    Fix: the original opened the file without ever closing it; the `with`
    block guarantees the handle is released even if pickle.load raises.
    NOTE(review): pickle.load must only be used on trusted files.
    """
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
if __name__ == '__main__':
    # `filename` and `featlabels` come from `from tree import *`: tree.py runs
    # its training script at import time and leaves both at module level.
    mytree = getTree(filename)
    # testVec values are ordered to match featlabels, i.e.
    # [owns_house, has_job] for the toy loan dataset
    testVec = [0,1]
    classlable = classify(mytree,featlabels,testVec)
    if classlable == 'no':
        print('不放贷')
    elif classlable == 'yes':
        print('放贷')