决策树
决策树是一树状结构,它的每一个树结点可以是叶节点,对应着某一类,也可以对应着一个划分,将该节点对应的样本集划分成若干个子集,每个子集对应一个节点。对一个分类问题,从已知类标记的训练元组学习并构造出决策树是一个从上至下,分而治之的过程。
ID3算法
ID3算法是一种基于信息熵的决策树分类算法,它选择当前样本集中具有最大信息增益值的属性作为测试属性;样本集的划分则依据测试属性的取值进行,测试属性有多少不同取值就将样本集划分为多少子样本集,同时,决策树上相应于该样本集的节点长出新的叶子节点。ID3算法根据信息理论,采用划分后样本集的不确定性作为衡量划分好坏的标准,用信息增益值度量:信息增益值越大,不确定性越小。因此,ID3算法在每个非叶节点选择信息增益最大的属性作为测试属性。该属性使得对结果划分中的样本分类所需的信息最小,并反映划分的最小随机性。
ID3算法具体流程
ID3算法的具体详细实现步骤如下。
- 对当前样本集合,计算所有属性的信息增益;
- 选择信息增益最大的属性作为测试属性,把测试属性取值相同的样本划为同一个子样本集;
- 若子样本集的类别属性只含有单个属性, 则分支为叶子节点, 判断其属性值并标上相应的符号, 然后返回调用处;否则对子样本集递归调用本算法。
代码实现(ID3)
准备
Graphviz安装
默认安装
下载地址: https://graphviz.org/download/
检查环境变量
dot -version
graphviz库安装
pip install graphviz
其他库安装
pip install pydotplus # 读取树结构数据
pip install six
pip install sklearn
pip install pandas
数据集:有关酒吧饭馆事例列表.xls
代码
import pydotplus # 读取树结构数据
from six import StringIO
from sklearn.tree import export_graphviz
from sklearn.tree import DecisionTreeClassifier as DTC
import pandas as pd
# 数据初始化
inputfile = 'data/有关酒吧饭馆事例列表.xls'
data = pd.read_excel(inputfile, index_col=u'例子')
# 数据是类别标签,将它转换成为数据
# 城镇:Y:1 ; N:0
# 有无大学: Y:1 ; N:0
# 居住区类型: M:3 ; L:2 ; S:1 ; N:0
# 有无工业区: Y:1 ; N:0
# 交通条件: A:2 ; P:1 ; G:0
# 学校数量: M:2 ; L:1 ; S:0
# 类别: +:1 ; -:0
data[data == u'Y'] = 1
data[data == u'N'] = 0
data[data == u'+'] = 1
data[data == u'-'] = 0
data[data[["居住区类型"]] == u'M'] = 3
data[data[["居住区类型"]] == u'L'] = 2
data[data[["居住区类型"]] == u'S'] = 1
data[data[["交通条件"]] == u'A'] = 2
data[data[["交通条件"]] == u'P'] = 1
data[data[["交通条件"]] == u'G'] = 0
data[data[["学校数量"]] == u'M'] = 2
data[data[["学校数量"]] == u'L'] = 1
data[data[["学校数量"]] == u'S'] = 0
print(data)
# x = data.iloc[:, :6].astype(int)
# y = data.iloc[:, 6].astype(int)
x = data[["城镇", "有无大学", "居住区类型", "有无工业区", "交通条件", "学校数量"]].astype(int)
y = data[["类别"]].astype(int)
dtc = DTC(criterion='entropy') # 建立决策树模型,基于信息熵
dtc.fit(x, y) # 训练模型
# 导入相关函数,可视化决策树
# 输出结果为dot文件,需安装Grapviz才能将它转换为pdf或png格式
x = pd.DataFrame(x)
# 存储成dot格式
# x = pd.DataFrame(x)
# with open("tree.dot", 'w') as f:
# f = export_graphviz(dtc, feature_names=x.columns, out_file=f)
# 存储为pdf格式,已解决中文显示问题
# 初始化dot文件
dot_data = StringIO()
export_graphviz(dtc, out_file=dot_data,
feature_names=x.columns,
filled=True, rounded=True, special_characters=True)
# 手动修改树结构的字体类型
dot_data_val = dot_data.getvalue()
dot_data_val = dot_data_val.replace('helvetica', 'SimSun')
graph = pydotplus.graph_from_dot_data(dot_data_val)
# 保存图像到pdf文件
graph.write_pdf('test.pdf')