便于分(坐)享(享)模(其)型(成)的好方法!
目录
三、利用predict_from_tree_list函数恢复树
记录本文的原因
本文是对以下这篇博客(后称为前文)的延伸: 决策树分类算法(一)(信息熵,信息增益,基尼指数计算)_阿维的博客日记的博客-CSDN博客
为什么要写这篇文章呢?我在学习大佬的决策树博客时,想到一个问题:能不能直接拿他训练好的模型来预测我的数据呢?这就涉及到模型的保存、传输和恢复。但查到的资料都是对sklearn库中的决策树进行操作,而不是这种大佬自己创建的class对象,所以写了点代码来恢复决策树模型。
注:本文所用数据集为 mushroom.data(本文最后附件),其特征维数为22,标签值‘1’代表有毒,‘2’代表无毒
一、利用print_tree函数保存树
前文的print_tree函数:
def print_tree(tree, level='0'):
'''简单打印一颗树的结构
para tree:biTree_node,树的根结点
para level='0':str, 节点在树中的位置,用一串字符串表示,0表示根节点,0L表示根节点的左孩子,0R表示根节点的右孩子
'''
if tree.leafLabel != None:
print('*' + level + '-' + str(tree.leafLabel)) # 叶子节点用*表示,并打印出标签
else:
print('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue)) # 中间节点用+表示,并打印出特征编号及其划分值
print_tree(tree.l, level + 'L')
print_tree(tree.r, level + 'R')
为了保存树,创建一个全局列表tree_list,将其打印的东西都加入列表中。修改后的print_tree函数:
tree_list=[]
def print_tree(tree, level='0'):
'''简单打印一颗树的结构
para tree:biTree_node,树的根结点
para level='0':str, 节点在树中的位置,用一串字符串表示,0表示根节点,0L表示根节点的左孩子,0R表示根节点的右孩子
'''
global tree_list
if tree.leafLabel != None:
tree_list.append('*' + level + '-' + str(tree.leafLabel))
print('*' + level + '-' + str(tree.leafLabel)) #叶子节点用*表示,并打印出标签
else:
tree_list.append('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue))
print('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue)) #中间节点用+表示,并打印出特征编号及其划分值
print_tree(tree.l, level+'L')
print_tree(tree.r, level+'R')
return tree_list
其结果形式为一棵完整的树,表示如下:
['+0-4-29', '+0L-4-26', '*0LL-1', '+0LR-19-103', '*0LRL-2', '+0LRR-0-4', '*0LRRL-2', '*0LRRR-1', '*0R-1']
其中各项的具体含义请参照前文。
二、利用tree_list列表传输树
由上一点可知,一个tree_list可以表示一棵完整的树,所以树的传输也就是该列表的传输。A可以用任何方式把这个列表发给B,比如在微信聊天框里发给B、保存到一个txt文件里发给B、保存到文件后附带数字签名和加密操作后发给B……只要B能还原出这个列表l_tree,就代表树传输成功。
三、利用predict_from_tree_list函数恢复树
B接收到树列表l_tree后,遍历自己数据集test_sample中的每一项x(x的最后一项为标签值)来判断A的模型准确率是否达标。
for x in test_sample:
#err+=(predict(x, tree)==x[-1])
err+=(predict_from_tree_list(l_tree,x)==x[-1])
print("accurency:",err/l_test)
函数predict_from_tree_list的思想是,先从l_tree中获取一个字典x_dict,其形式为:
{'0': '4-29', '0L': '4-26', '0LL': '*1', '0LR': '11-61', '0LRL': '2-18', '0LRLL': '*2', '0LRLR': '0-7', '0LRLRL': '*1', '0LRLRR': '*2', '0LRR': '7-39', '0LRRL': '*1', '0LRRR': '*2', '0R': '*1'}
先上代码:
def predict_from_tree_list(l_tree,x):
x_dict={}
for node in l_tree:
node_info=node.split('-')
if node_info[0][0]=='+':
x_dict[node_info[0][1:]]=node_info[1]+'-'+node_info[2]
if node_info[0][0]=='*':
x_dict[node_info[0][1:]]='*'+node_info[1]
#print(x_dict)
return func(x_dict,x,'0')
其中返回的func为递归函数,输入字典x_dict和这一条数据x以及当前的序列seq:
def func(x_dict,x,seq):
if x_dict[seq][0]=='*':
return int(x_dict[seq][1:])
f_list=x_dict[seq].split('-')
f=int(f_list[0])
fvalue=int(f_list[1])
if x[f]<fvalue:
seq+='L'
return func(x_dict,x,seq)
if x[f]>=fvalue:
seq+='R'
return func(x_dict,x,seq)
如何理解呢?我们举一个例子来说明。先赋值:
l_tree = ['+0-4-29', '+0L-4-26', '*0LL-1', '+0LR-2-18', '*0LRL-2', '+0LRR-2-19', '*0LRRL-1', '*0LRRR-2', '*0R-1']
x=[3, 11, 13, 24, 28, 34, 37, 39, 45, 53, 54, 59, 63, 67, 76, 85, 86, 90, 94, 98, 109, 114, 2]
可以得到:
x_dict = {'0': '4-29', '0L': '4-26', '0LL': '*1', '0LR': '11-61', '0LRL': '*2', '0LRR': '7-39', '0LRRL': '*1', '0LRRR': '*2', '0R': '*1'}
即存储的若是叶子节点,就将其路径和分类结果写入字典,否则写入在该节点进入下一节点的判断依据:a-b表示x[a]与b的比较,若x[a]<b,则进入左子树,若x[a]>=b,则进入右子树。获取x_dict后,开始迭代运行func。
因为第一次判断都是从根节点‘0’开始,所以输入func的第一个序列seq为‘0’。首先从字典中判断该序列是否为叶子结点,是则return该节点代表的分类,否则用 f 和 fvalue 表示 a 和 b ,然后进行判断。若x[f]<fvalue,则进入左子树,即序列seq+=‘L’并进入下一次迭代,反之同理。由此就可以根据树列表l_tree预测出x的分类结果。
信息增益二叉树:
+0-4-29
+0L-4-26
*0LL-1
+0LR-19-103
*0LRL-2
+0LRR-2-19
*0LRRL-2
*0LRRR-1
*0R-1
l_tree: ['+0-4-29', '+0L-4-26', '*0LL-1', '+0LR-19-103', '*0LRL-2', '+0LRR-2-19', '*0LRRL-2', '*0LRRR-1', '*0R-1']
信息增益二叉树对样本进行预测的结果:
accurency: 0.9893333333333333
四、完整代码
splitInfo.py可在前文中获得
rebuild_tree.py:
# coding:UTF-8
from splitInfo import info_entropy, gini_index, split_samples, sum_of_each_label
import pandas as pd
import random
class biTree_node:
'''
二叉树节点
'''
def __init__(self, f=-1, fvalue=None, leafLabel=None, l=None, r=None, splitInfo="gini"):
'''
类初始化函数
para f: int,切分的特征,用样本中的特征次序表示
para fvalue: float or int,切分特征的决策值
para leafLable: int,叶节点的标签
para l: biTree_node指针,内部节点的左子树
para r: biTree_node指针,内部节点的右子树
para splitInfo="gini": string, 切分的标准,可取值'infogain'和'gini',分别表示信息增益和基尼指数
'''
self.f = f
self.fvalue = fvalue
self.leafLabel = leafLabel
self.l = l
self.r = r
self.splitInfo = splitInfo
def build_biTree(samples, splitInfo="gini"):
'''构建树
para samples:list,样本的列表,每样本也是一个列表,样本的最后一项为label,其它项为特征
para splitInfo="gini": string, 切分的标准,可取值'infogain'和'gini',分别表示信息增益和基尼指数
return biTree_node:Class biTree_node,二叉决策树的根结点
'''
if len(samples) == 0:
return biTree_node()
if splitInfo != "gini" and splitInfo != "infogain":
return biTree_node()
bestInfo = 0.0
bestF = None
bestFvalue = None
bestlson = None
bestrson = None
if splitInfo == "gini":
curInfo = gini_index(samples) # 当前集合的基尼指数
else:
curInfo = info_entropy(samples) # 当前集合的信息熵
sumOfFeatures = len(samples[0]) - 1 # 样本中特征的个数
for f in range(0, sumOfFeatures): # 遍历每个特征
featureValues = [sample[f] for sample in samples]
for fvalue in featureValues: # 遍历当前特征的每个值
lson, rson = split_samples(samples, f, fvalue)
if splitInfo == "gini":
# 计算分裂后两个集合的基尼指数
info = (gini_index(lson)*len(lson) + gini_index(rson)*len(rson))/len(samples)
else:
# 计算分裂后两个集合的信息熵
info = (info_entropy(lson)*len(lson) + info_entropy(rson)*len(rson))/len(samples)
gain = curInfo - info # 计算基尼指数减少量或信息增益
# 能够找到最好的切分特征及其决策值,左、右子树为空说明是叶子节点
if gain > bestInfo and len(lson)>0 and len(rson)>0:
bestInfo = gain
bestF = f
bestFvalue = fvalue
bestlson = lson
bestrson = rson
if bestInfo > 0.0:
l = build_biTree(bestlson)
r = build_biTree(bestrson)
return biTree_node(f=bestF, fvalue=bestFvalue, l=l, r=r, splitInfo=splitInfo)
else: # 如果bestInfo==0.0,说明没有切分方法使集合的基尼指数或信息熵下降了
label_counts = sum_of_each_label(samples)
# 返回该集合中最多的类别作为叶子节点的标签
return biTree_node(leafLabel=max(label_counts, key=label_counts.get), splitInfo=splitInfo)
def predict(sample, tree):
'''
对样本sample进行预测
para sample:list,需要预测的样本
para tree:biTree_node,构建好的分类树
return: biTree_node.leafLabel,所属的类别
'''
# 1、只是树根
if tree.leafLabel != None:
return tree.leafLabel
else:
# 2、有左右子树
sampleValue = sample[tree.f]
branch = None
if sampleValue >= tree.fvalue:
branch = tree.r
else:
branch = tree.l
return predict(sample, branch)
tree_list=[]
def print_tree(tree, level='0'):
'''简单打印一颗树的结构
para tree:biTree_node,树的根结点
para level='0':str, 节点在树中的位置,用一串字符串表示,0表示根节点,0L表示根节点的左孩子,0R表示根节点的右孩子
'''
global tree_list
if tree.leafLabel != None:
tree_list.append('*' + level + '-' + str(tree.leafLabel))
print('*' + level + '-' + str(tree.leafLabel)) #叶子节点用*表示,并打印出标签
else:
tree_list.append('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue))
print('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue)) #中间节点用+表示,并打印出特征编号及其划分值
print_tree(tree.l, level+'L')
print_tree(tree.r, level+'R')
return tree_list
def func(x_dict,x,seq):
if x_dict[seq][0]=='*':
return int(x_dict[seq][1:])
f_list=x_dict[seq].split('-')
f=int(f_list[0])
fvalue=int(f_list[1])
if x[f]<fvalue:
seq+='L'
return func(x_dict,x,seq)
if x[f]>=fvalue:
seq+='R'
return func(x_dict,x,seq)
def predict_from_tree_list(l_tree,x):
x_dict={}
for node in l_tree:
node_info=node.split('-')
if node_info[0][0]=='+':
x_dict[node_info[0][1:]]=node_info[1]+'-'+node_info[2]
if node_info[0][0]=='*':
x_dict[node_info[0][1:]]='*'+node_info[1]
#print(x_dict)
return func(x_dict,x,'0')
if __name__ == "__main__":
mushrooms=pd.read_csv('mushroom.data')
X=mushrooms.iloc[:,0:23]
X=X.values
all_date=(X.tolist())
random.shuffle(all_date)
blind_date=all_date[500:700]
print("信息增益二叉树:")
tree = build_biTree(blind_date, splitInfo="infogain")
l_tree=print_tree(tree)
print("l_tree:",l_tree)
print('信息增益二叉树对样本进行预测的结果:')
#test_sample = all_date[1000:7500]
test_sample = all_date[3000:7500]
l_test=len(test_sample)
err=0
for x in test_sample:
#err+=(predict(x, tree)==x[-1])
err+=(predict_from_tree_list(l_tree,x)==x[-1])
print("accurency:",err/l_test)