1.1决策树的定义
决策树(Decision Tree)是从一组无次序、无规则,但有类别标号的样本集中推导出的、树形表示的分类规则。
一般的,一棵决策树包含一个根结点、若干个内部结点(中间结点)和若干个叶子结点。
树的叶子结点表示类别标号,即分类属性的取值,也可以说是决策结果;树的内部结点为条件属性,每个结点包含的样本集合根据属性测试结果被划分到子结点中;根结点包含样本全集。从树根到叶子结点的一条路径称为一条决策规则,它可以对未知数据进行分类或预测。每条有向边都用其出点的属性值标记。
通常,一个属性有多少种取值,就从该结点引出多少条有向边,每一条边代表属性的一种取值。
树深度是树根到树叶的最大层数,通常作为决策树模型复杂度的一种度量。
1.2决策树的优缺点
决策树的优点:
决策树算法中学习简单的决策规则建立决策树模型的过程非常容易理解;
决策树模型可以可视化,非常直观;
应用范围广,可用于分类和回归,而且非常容易做多类别的分类;
能够处理数值型和连续的样本特征。
决策树的缺点:
很容易在训练数据中生成复杂的树结构,造成过拟合(overfitting)。剪枝可以缓解过拟合的负作用,常用方法是限制树的高度、叶子节点中的最少样本数量。
1.3决策树实例
#数据加载
def loadData( ):
data = pd.read_csv("titanic.csv")
data["Age"] = data["Age"].fillna(data["Age"].mean()) #针对Age字段,采用均值进行填充
data = data.dropna()
data.isna().sum()
dataset = data.values.tolist()
#四个属性
labels=['Pclass','Age','Sex','Survived']
return dataset,labels
#计算给定数据的香农熵
def calShannonEnt(dataset):
numEntries = len(dataset) #获得数据集函数
labelCounts={} #用于保存每个标签出现的次数
for data in dataset:
#提取标签信息
classlabel = data[-1]
if(classlabel not in labelCounts.keys()): #如果标签未放入统计次数的字典,则添加进去
labelCounts[classlabel]=0
labelCounts[classlabel]+=1 #标签计数
shannonEnt=0.0 #熵初始化
for key in labelCounts:
p = float(labelCounts[key])/numEntries #选择该标签的概率
shannonEnt-= p*np.log2(p)
return shannonEnt #返回经验熵
#根据某一特征划分数据集
def splitDataset(dataset,axis,value):
# dataset 待划分的数据集 axis 划分数据集的特征 value 返回数据属性值为value
retDataSet = [] #创建新的list对象
for featVec in dataset: #遍历元素
if featVec[axis]==value: #符合条件的抽取出来
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
# 统计出现次数最多的元素(类标签)
def majorityCnt(classList):
classCount={} #统计classList中每个类标签出现的次数
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) #根据字典的值降序排列
return sortedClassCount[0][0] #返回出现次数最多的类标签
def createTree(dataset,labels):#数据集和标签列表
classList =[example[-1] for example in dataset]#数据所属类得值
if classList.count(classList[0])==len(classList):#条件1:classList只剩下一种值
return classList[0]
if len(dataset[0])==1:#条件2:数据dataset中属性已使用完毕,但没有分配完毕
return majorityCnt(classList)#取数量多的作为分类
bestFeat = chooseBestFeatureToSplit(dataset)#选择最好的分类点,即香农熵值最小的
labels2 = labels.copy()#复制一分labels值,防止原数据被修改。
bestFeatLabel = labels2[bestFeat]
myTree = {bestFeatLabel:{}}#选取获取的最好的属性作为
del(labels2[bestFeat])
featValues = [example[bestFeat] for example in dataset]#获取该属性下的几类值
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels2[:]#剩余属性列表
myTree[bestFeatLabel][value] = createTree(splitDataset(dataset,bestFeat,value),subLabels)
return myTree
if __name__ == '__main__':
dataSet, labels = loadData( )
print("数据集信息熵:"+str(calShannonEnt(dataSet)))
mytree = createTree(dataSet,labels)
print(mytree)
绘制:
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.feature_extraction import DictVectorizer
from sklearn.tree import export_graphviz
import matplotlib.pyplot as plt
# 加载数据
data = pd.read_csv('titanic.csv')
# print(data.head())
# print(data.shape) # (891, 12)
# 获取数据
x = data[['Pclass','Age','Sex']]
y = data['Survived']
# print(x.head())
# print(y.head())
# 缺失值处理
x['Age'].fillna(x['Age'].mean(),inplace=True)
# 特征处理
x['Sex'] = np.array([0 if i == 'male' else 1 for i in x['Sex']]).T
# 打印查看是否替换成功
# print(x.head())
# 划分数据集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2,random_state=22)
# 创建模型
model = DecisionTreeClassifier(criterion='entropy',max_depth=3)
# 3 准确率: 0.7653631284916201
# 4 准确率: 0.7653631284916201
# 5 准确率: 0.7430167597765364
# 6 准确率: 0.7486033519553073
# 8 准确率: 0.7541899441340782
# 12 准确率: 0.770949720670391
model.fit(x_train,y_train)
# 评估
score = model.score(x_test,y_test)
pred = model.predict(x_test)
print('准确率:',score)
print('预测值:',pred)
# 可视化方法一:
# export_graphviz(model, out_file="tree.dot", feature_names=['pclass','age','sex'])
# 更简单的方法:
plt.figure(figsize=(20,20))
feature_name = ['pclass','age','sex']
class_name = ['survived','death']
tree.plot_tree(model,feature_names=feature_name,class_names=class_name)
plt.savefig('tree.png')
效果截图: