本博客构造ID3决策树的方法步骤如下:
1、确定决策树的输出形式 data_item_list:[{'node1': None, 'weight': None, 'node2': None}, ...],其中每个元素 data_item={'node1': None, 'weight': None, 'node2': None} 表示树中的一条边:node1为父节点,node2为子节点,weight为分支权重或属性分类(即父节点属性在该分支上的取值)
2、确定两个data_item之间的连接关系,其数据结构形式为data_node_list:[data_node, ...],其中data_node={'node1': None, 'weight': None, 'node2': None, 'node1_name': None, 'node2_name': None},node1为父节点,node2为子节点,weight为分支权重或属性分类,node1_name和node2_name是为构建dot文件而添加的节点唯一标识名。
3、将data_node_list转换成为dot文件格式并保存
4、在终端输入dot “dot文件名” -Tpng -o “保存的图片名”.png,如:dot MyPicture.dot -Tpng -o mytree_png.png
代码如下:
built_tree.py
import pandas as pd
import numpy as np
def getEnt(data_pd):
    """Return the base-2 Shannon entropy of the last column of ``data_pd``.

    The last column is treated as the class label. Every distinct label
    value contributes ``-p * log2(p)``, where ``p`` is its relative
    frequency in that column.
    """
    labels = data_pd[data_pd.keys()[-1]]
    total = len(labels)
    # Start the running sum at 0.0 so a frame without rows yields 0.0.
    return sum(
        (-np.log2(count / total) * count / total for count in labels.value_counts()),
        0.0,
    )
def get_son_data(data_pd, lable=None):
    """Split ``data_pd`` into one sub-frame per distinct value of column ``lable``.

    Returns a pair ``(son_data_list, son_data_dict_list)``:

    * ``son_data_list`` holds the sub-frames, each with column ``lable`` removed;
    * ``son_data_dict_list`` holds, in the same order, dicts of the form
      ``{'node': lable, 'weight': value, 'data': sub_frame}``.

    Values are visited in the order in which they first appear in the column.
    """
    son_data_list = []
    son_data_dict_list = []
    for value in data_pd[lable].unique():
        # Keep the rows carrying this value, then drop the column just split on.
        subset = data_pd[data_pd[lable] == value].drop(labels=lable, axis=1)
        son_data_list.append(subset)
        son_data_dict_list.append({'node': lable, 'weight': value, 'data': subset})
    return son_data_list, son_data_dict_list
def get_parent_node(data_array):
    """Pick the attribute with the largest information gain and split on it.

    Every column except the last is a candidate attribute; the last column
    is the class label. For each candidate the frame is partitioned with
    ``get_son_data`` and the information gain is computed as the entropy of
    the whole frame minus the size-weighted entropy of the partitions.

    Returns the ``son_data_dict_list`` produced by ``get_son_data`` for the
    winning attribute (one dict per attribute value). Ties are resolved in
    favour of the left-most column.

    Raises:
        ValueError: if ``data_array`` has no attribute column to split on.
    """
    base_entropy = getEnt(data_array)
    total_rows = len(data_array)
    best_gain = None
    best_split = None
    for column in data_array.keys()[:-1]:
        son_data_list, son_data_dict_list = get_son_data(data_pd=data_array, lable=column)
        # Size-weighted entropy of the partitions induced by this column.
        weighted_entropy = sum(
            getEnt(data_pd=subset) * len(subset) / total_rows
            for subset in son_data_list
        )
        gain = base_entropy - weighted_entropy
        # Track the gain itself (not the weighted entropy) so later columns
        # are compared against the best gain seen so far. Starting from
        # ``None`` guarantees a column is chosen even when every gain is 0.
        if best_gain is None or gain > best_gain:
            best_gain = gain
            best_split = son_data_dict_list
    if best_split is None:
        raise ValueError(
            "get_parent_node needs at least one attribute column besides the label column"
        )
    return best_split
class Tree_Method:
    """Build an ID3 decision tree and record it as a flat list of edges.

    After construction ``node_link_list`` holds one dict per edge, in
    depth-first order, of the form
    ``{'node1': parent_attribute, 'weight': attribute_value, 'node2': child}``
    where ``child`` is either the attribute tested next or a class label.
    """

    def __init__(self, data_array):
        """Build the tree for ``data_array`` (last column is the class label)."""
        self.data = data_array
        self.node_link_list = []
        # Split of the full data set on the root attribute.
        self.parent_roof = get_parent_node(self.data)
        self.biult_tree(self.parent_roof)

    def biult_tree(self, son_data_dict_list):
        """Append the edges for every partition in ``son_data_dict_list``.

        Each element is a dict ``{'node', 'weight', 'data'}`` as produced by
        ``get_son_data``. A partition whose labels are all equal becomes a
        leaf edge; otherwise it is split again and the method recurses.
        """
        for item in son_data_dict_list:
            subset = item['data']
            labels = subset[subset.keys()[-1]]
            distinct = labels.unique()
            edge = {'node1': item['node'], 'weight': item['weight'], 'node2': None}
            if len(distinct) == 1:
                # Pure partition: the single class label is the leaf.
                edge['node2'] = distinct[0]
                self.node_link_list.append(edge)
            elif subset.shape[1] <= 1 and len(distinct) > 1:
                # Only the label column is left but the labels still
                # disagree (inconsistent samples): there is nothing left to
                # split on, so emit a majority-class leaf instead of
                # recursing into a split that cannot be computed.
                edge['node2'] = labels.value_counts().idxmax()
                self.node_link_list.append(edge)
            else:
                children = get_parent_node(subset)
                # All children share the same 'node': the attribute chosen
                # for the next split, which labels this edge's target.
                edge['node2'] = children[0]['node']
                self.node_link_list.append(edge)
                self.biult_tree(children)
if __name__ == '__main__':
    # Demo: 15 animals described by 7 attributes; the last column
    # ('类标号') is the class label the tree predicts.
    # NOTE(review): '右腿' ("right leg") is presumably a typo for '有腿'
    # ("has legs"), cf. 'Have_a_leg' in tree_to_dot.py -- left unchanged.
    data = {
        '体温': ['恒温', '冷血', '冷血', '恒温', '冷血', '冷血', '恒温', '恒温', '恒温', '冷血', '冷血', '恒温', '恒温', '冷血', '冷血'],
        '表皮覆盖': ['毛发', '鳞片', '鳞片', '毛发', '无', '鳞片', '毛发', '羽毛', '软毛', '鳞片', '鳞片', '羽毛', '刚毛', '鳞片', '无'],
        '胎生': ['是', '否', '否', '是', '否', '否', '是', '否', '是', '是', '否', '否', '是', '否', '否'],
        '水生动物': ['否', '否', '是', '是', '半', '否', '否', '否', '否', '是', '半', '半', '否', '是', '半'],
        '飞行动物': ['否', '否', '否', '否', '否', '否', '是', '是', '否', '否', '否', '否', '否', '否', '否'],
        '右腿': ['是', '否', '否', '否', '是', '是', '是', '是', '是', '否', '是', '是', '是', '否', '是'],
        '冬眠': ['否', '是', '否', '否', '是', '否', '是', '否', '否', '否', '否', '否', '是', '否', '是'],
        '类标号': ['哺乳类', '爬行类', '鱼类', '哺乳类', '两栖类', '爬行类', '哺乳类', '鸟类', '哺乳类', '鱼类', '爬行类', '鸟类', '哺乳类', '鱼类', '两栖类'],
    }
    sample_frame = pd.DataFrame(data=data)
    # Print one edge dict per line, in the order the tree was built.
    for edge in Tree_Method(sample_frame).node_link_list:
        print(edge)
tree_to_dot.py
from graphviz import Digraph
import built_tree
import pandas as pd
# Sample data set: 15 animals, 7 attributes, class label in the last
# column ('Lable'). Keyword order fixes the column order of the frame.
data = dict(
    Temperature=[
        'AC_tempreture', 'cold_blooded', 'cold_blooded', 'AC_tempreture', 'cold_blooded',
        'cold_blooded', 'AC_tempreture', 'AC_tempreture', 'AC_tempreture', 'cold_blooded',
        'cold_blooded', 'AC_tempreture', 'AC_tempreture', 'cold_blooded', 'cold_blooded',
    ],
    Covering=[
        'hair', 'scales', 'scales', 'hair', 'None',
        'scales', 'hair', 'feather', 'fur', 'scales',
        'scales', 'feather', 'bristles', 'scales', 'None',
    ],
    Viviparity=[
        'yes', 'no', 'no', 'yes', 'no',
        'no', 'yes', 'no', 'yes', 'yes',
        'no', 'no', 'yes', 'no', 'no',
    ],
    Aquatic_animals=[
        'no', 'no', 'yes', 'yes', 'half',
        'no', 'no', 'no', 'no', 'yes',
        'half', 'half', 'no', 'yes', 'half',
    ],
    Flying_animals=[
        'no', 'no', 'no', 'no', 'no',
        'no', 'yes', 'yes', 'no', 'no',
        'no', 'no', 'no', 'no', 'no',
    ],
    Have_a_leg=[
        'yes', 'no', 'no', 'no', 'yes',
        'yes', 'yes', 'yes', 'yes', 'no',
        'yes', 'yes', 'yes', 'no', 'yes',
    ],
    Hibernation=[
        'no', 'yes', 'no', 'no', 'yes',
        'no', 'yes', 'no', 'no', 'no',
        'no', 'no', 'yes', 'no', 'yes',
    ],
    Lable=[
        'mammals', 'reptiles', 'fish', 'mammals', 'amphibians',
        'reptiles', 'mammals', 'birds', 'mammals', 'fish',
        'reptiles', 'birds', 'mammals', 'fish', 'amphibians',
    ],
)
data_array = pd.DataFrame(data)
# Build the tree; `tree` is its flat list of edge dicts
# ({'node1', 'weight', 'node2'}).
tree_bt = built_tree.Tree_Method(data_array)
tree = tree_bt.node_link_list
# Pass 1: give every edge provisional Graphviz node identifiers. Edges that
# leave the root attribute share the identifier 'root'; every edge gets its
# own block of three numbers so the generated 'nameN' identifiers never clash.
for index, edge in enumerate(tree):
    base = 3 * index
    if edge['node1'] == tree[0]['node1']:
        edge['node1_name'] = 'root'
        edge['node2_name'] = 'name{}'.format(base)
    else:
        edge['node1_name'] = 'name{}'.format(base + 1)
        edge['node2_name'] = 'name{}'.format(base + 2)

# Pass 2: make each edge's parent identifier point at the node that an
# earlier edge already created, so the edges join up into one tree.
for pos in range(1, len(tree)):
    previous, current = tree[pos - 1], tree[pos]
    if previous['node1'] == current['node1']:
        # Sibling of the previous edge: same parent node.
        current['node1_name'] = previous['node1_name']
    elif previous['node2'] == current['node1']:
        # First child below the node the previous edge points to.
        current['node1_name'] = previous['node2_name']
    else:
        # Back from a finished subtree: scan upwards for the closest
        # earlier edge that leaves the same parent attribute.
        back = pos - 1
        while tree[back]['node1'] != current['node1']:
            back -= 1
        current['node1_name'] = tree[back]['node1_name']
# Create a directed graph. name: base name of the rendered picture,
# format: output picture format.
dot = Digraph(name="MyPicture", comment="the test", format="png")
for item in tree:
    # Declare both end points. name: Graphviz identifier of the node,
    # label: text shown in the node, color: outline colour.
    # Labels are passed through str() so non-string attribute values or
    # class labels (e.g. numbers) can be rendered as well.
    dot.node(name=item['node1_name'], label=str(item['node1']), color='green')
    dot.node(name=item['node2_name'], label=str(item['node2']), color='green')
    # Connect them. label: text shown on the edge, color: line colour.
    dot.edge(item['node1_name'], item['node2_name'], label=str(item['weight']), color='red')
# Save the generated DOT source. Graphviz reads its input as UTF-8 by
# default, so write UTF-8 explicitly instead of the platform default
# encoding (matters as soon as labels are not plain ASCII).
with open(file='MyPicture.dot', mode='w', encoding='utf-8') as f:
    f.write(dot.source)
终端输入:dot MyPicture.dot -Tpng -o mytree_png.png
最终生成决策树如下: