"""Decision tree (ID3) — the core idea is straightforward; the main work
is implementing the N-ary tree.

1. Node selection: pick the feature with the largest Shannon information
   gain as the split at each node.
2. Child subtrees are generated from the distinct values of the chosen
   feature.
3. Leaf nodes hold the final classification result; classifying a test
   sample means walking the tree until a leaf is reached.
"""
import numpy as np
import math
import tree_plotter
def createDataSet():
    '''
    Build the loan-application training set.

    Each row is one sample: four integer-encoded feature values followed
    by the class label ('yes' / 'no') in the last column.
    :return: (data_set, labels) — the samples and the feature names
    '''
    samples = [
        [0, 0, 0, 0, 'no'],
        [0, 0, 0, 1, 'no'],
        [0, 1, 0, 1, 'yes'],
        [0, 1, 1, 0, 'yes'],
        [0, 0, 0, 0, 'no'],
        [1, 0, 0, 0, 'no'],
        [1, 0, 0, 1, 'no'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [2, 0, 1, 2, 'yes'],
        [2, 0, 1, 1, 'yes'],
        [2, 1, 0, 1, 'yes'],
        [2, 1, 0, 2, 'yes'],
        [2, 0, 0, 0, 'no'],
    ]
    # Feature names: age, has-job, owns-house, credit rating.
    feature_names = ['年龄', '有工作', '有自己的房子', '信贷情况']
    return samples, feature_names
def computeShannonEntroy(dataSet):
    '''
    Compute the Shannon entropy of a data set's class labels.

    The class label is assumed to be the last element of every row.
    :param dataSet: list of sample rows
    :return: entropy in bits, H = -sum(p * log2(p)) over the labels
    '''
    total = len(dataSet)
    # Tally how many rows carry each class label.
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    # Accumulate -p*log2(p) per label.
    entropy = 0.0
    for count in tally.values():
        prob = count / total
        entropy -= prob * math.log(prob, 2)
    return entropy
def splitDataset(dataSet, index, value):
    '''
    Select the rows whose feature at `index` equals `value`, returning
    them with that feature column removed.

    :param dataSet: data set to split
    :param index: column index of the splitting feature
    :param value: feature value to keep
    :return: new list of reduced rows (input is not modified)
    '''
    # row[:index] + row[index+1:] drops exactly the split column.
    return [row[:index] + row[index + 1:]
            for row in dataSet if row[index] == value]
def chooseBestFeatureIndex(dataSet, features):
    '''
    Select the feature with the largest information gain (ID3 criterion).

    :param dataSet: data set to evaluate
    :param features: feature-name list (only its length is used here)
    :return: column index of the best feature, or -1 when no split
             yields a positive gain
    '''
    totalRows = len(dataSet)
    baseEntropy = computeShannonEntroy(dataSet)
    bestGain = 0
    bestIndex = -1
    for col in range(len(features)):
        # Weighted entropy of the partitions induced by this column.
        conditionalEntropy = 0.0
        for value in {row[col] for row in dataSet}:
            subset = splitDataset(dataSet, col, value)
            weight = len(subset) / totalRows
            conditionalEntropy += weight * computeShannonEntroy(subset)
        gain = baseEntropy - conditionalEntropy
        # Strictly greater: earlier columns win ties, matching max-gain search.
        if gain > bestGain:
            bestGain, bestIndex = gain, col
    return bestIndex
def chooseMajorClass(classList):
    '''
    Majority vote: return the most frequent class label.

    Used as the leaf value when no features remain but the samples still
    carry mixed labels.
    :param classList: list of class labels
    :return: the label with the highest count (ties broken by first
             appearance, via dict insertion order)
    '''
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    return max(tally, key=tally.get)
def buildDTree(dataSet, features):
    '''
    Recursively build an ID3 decision tree as nested dicts:
    {feature_name: {feature_value: subtree_or_leaf, ...}}.

    :param dataSet: training samples; last column is the class label
    :param features: feature names aligned with the data columns
                     (no longer mutated — see note below)
    :return: nested-dict decision tree, or a class-label leaf

    Fix: the original code did `del features[bestFeatureIndex]`, mutating
    the caller's list. Because the SAME list object was then shared by
    every sibling branch's recursive call, deletions made inside one
    branch's subtree corrupted the feature/column alignment seen by the
    next branch. We now derive a reduced copy per recursive call instead.
    '''
    classList = [item[-1] for item in dataSet]
    # Pure node: every sample has the same label — return it as a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Features exhausted but labels still mixed — majority-vote leaf.
    if len(features) == 0:
        return chooseMajorClass(classList)
    bestFeatureIndex = chooseBestFeatureIndex(dataSet, features)
    bestFeature = features[bestFeatureIndex]
    DTree = {bestFeature: {}}
    bestFeatureUniqueValue = set(item[bestFeatureIndex] for item in dataSet)
    for value in bestFeatureUniqueValue:
        # Fresh reduced feature list per branch: keeps siblings independent
        # and leaves the caller's `features` untouched.
        subFeatures = features[:bestFeatureIndex] + features[bestFeatureIndex + 1:]
        newDataSet = splitDataset(dataSet, bestFeatureIndex, value)
        DTree[bestFeature][value] = buildDTree(newDataSet, subFeatures)
    return DTree
def classify(DTree, features, testDataSet):
    '''
    Classify one test sample by walking the decision tree.

    :param DTree: decision tree (nested dicts) or a bare leaf label
    :param features: feature names aligned with the sample's columns
    :param testDataSet: one sample's feature values
    :return: the class label at the reached leaf
    '''
    node = DTree
    # Descend until the node is no longer a dict, i.e. a leaf label.
    while isinstance(node, dict):
        feature = next(iter(node))
        branch = testDataSet[features.index(feature)]
        node = node[feature][branch]
    return node
if __name__ == '__main__':
    # Build the tree from a copy of the feature list (buildDTree may
    # consume it), keeping the original names for classification.
    data_set, feature_names = createDataSet()
    decision_tree = buildDTree(data_set, feature_names.copy())
    print(decision_tree)
    # tree_plotter.create_plot(decision_tree)
    sample = [0, 1, 0, 0]
    print(feature_names)
    print(classify(decision_tree, feature_names, sample))