简单的决策树模型传输和预测本地数据的一种方法

便于分(坐)享(享)模(其)型(成)的好方法!


目录

记录本文的原因

一、利用print_tree函数保存树

二、利用tree_list列表传输树

三、利用predict_from_tree_list函数恢复树

四、完整代码

附件


记录本文的原因

       本文是对以下这篇博客(后称为前文)的延伸: 决策树分类算法(一)(信息熵,信息增益,基尼指数计算)_阿维的博客日记的博客-CSDN博客

        为什么要写这篇文章呢?我在学习大佬的决策树博客时,想到一个问题:能不能直接拿他训练好的模型来预测我的数据呢?这就涉及到模型的保存、传输和恢复。但查到的资料都是对sklearn库中的决策树进行操作,而不是这种大佬自己创建的class对象,所以写了点代码来恢复决策树模型。

注:本文所用数据集为 mushroom.data(本文最后附件),其特征维数为22,标签值‘1’代表有毒,‘2’代表无毒


一、利用print_tree函数保存树

前文的print_tree函数:

def print_tree(tree, level='0'):
    '''简单打印一颗树的结构
    para tree:biTree_node,树的根结点
    para level='0':str, 节点在树中的位置,用一串字符串表示,0表示根节点,0L表示根节点的左孩子,0R表示根节点的右孩子  
    '''
    if tree.leafLabel != None:
        print('*' + level + '-' + str(tree.leafLabel))  # 叶子节点用*表示,并打印出标签
    else:
        print('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue))  # 中间节点用+表示,并打印出特征编号及其划分值
        print_tree(tree.l, level + 'L')
        print_tree(tree.r, level + 'R')

为了保存树,创建一个全局列表tree_list,将其打印的东西都加入列表中。修改后的print_tree函数:

tree_list=[]
def print_tree(tree, level='0'):
    '''简单打印一颗树的结构
    para tree:biTree_node,树的根结点
    para level='0':str, 节点在树中的位置,用一串字符串表示,0表示根节点,0L表示根节点的左孩子,0R表示根节点的右孩子  
    '''
    global tree_list
    if tree.leafLabel != None:
        tree_list.append('*' + level + '-' + str(tree.leafLabel))
        print('*' + level + '-' + str(tree.leafLabel)) #叶子节点用*表示,并打印出标签
    else:
        tree_list.append('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue))
        print('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue)) #中间节点用+表示,并打印出特征编号及其划分值
        print_tree(tree.l, level+'L')
        print_tree(tree.r, level+'R')
    return tree_list

其结果形式为一棵完整的树,表示如下:

['+0-4-29', '+0L-4-26', '*0LL-1', '+0LR-19-103', '*0LRL-2', '+0LRR-0-4', '*0LRRL-2', '*0LRRR-1', '*0R-1']

 其中各项的具体含义请参照前文
 

二、利用tree_list列表传输树

        由上一点可知,一个tree_list可以表示一棵完整的树,所以树的传输也就是该列表的传输。A可以用任何方式把这个列表发给B,比如在微信聊天框里发给B、保存到一个txt文件里发给B、保存到文件后附带数字签名和加密操作后发给B……只要B能还原出这个列表l_tree,就代表树传输成功。


三、利用predict_from_tree_list函数恢复树

B接收到树列表l_tree后,遍历自己数据集test_sample中的每一项x(x的最后一项为标签值)来判断A的模型准确率是否达标。

    for x in test_sample:
        #err+=(predict(x, tree)==x[-1])
        err+=(predict_from_tree_list(l_tree,x)==x[-1])
    print("accurency:",err/l_test)

函数predict_from_tree_list的思想是,先从l_tree中获取一个字典x_dict,其形式为:

{'0': '4-29', '0L': '4-26', '0LL': '*1', '0LR': '11-61', '0LRL': '2-18', '0LRLL': '*2', '0LRLR': '0-7', '0LRLRL': '*1', '0LRLRR': '*2', '0LRR': '7-39', '0LRRL': '*1', '0LRRR': '*2', '0R': '*1'}

先上代码:

def predict_from_tree_list(l_tree,x):
    x_dict={}
    for node in l_tree:
        node_info=node.split('-')
        if node_info[0][0]=='+':
            x_dict[node_info[0][1:]]=node_info[1]+'-'+node_info[2]
        if node_info[0][0]=='*':
            x_dict[node_info[0][1:]]='*'+node_info[1]
    #print(x_dict)
    return func(x_dict,x,'0')

其中返回的func为递归函数,输入字典x_dict和这一条数据x以及当前的序列seq:

def func(x_dict,x,seq):
    if x_dict[seq][0]=='*':
        return int(x_dict[seq][1:])
    f_list=x_dict[seq].split('-')
    f=int(f_list[0])
    fvalue=int(f_list[1])
    if x[f]<fvalue:
        seq+='L'
        return func(x_dict,x,seq)
    if x[f]>=fvalue:
        seq+='R'
        return func(x_dict,x,seq)

 如何理解呢?我们举一个例子来说明。先赋值:

l_tree = ['+0-4-29', '+0L-4-26', '*0LL-1', '+0LR-2-18', '*0LRL-2', '+0LRR-2-19', '*0LRRL-1', '*0LRRR-2', '*0R-1']
x=[3, 11, 13, 24, 28, 34, 37, 39, 45, 53, 54, 59, 63, 67, 76, 85, 86, 90, 94, 98, 109, 114, 2]

可以得到:

x_dict = {'0': '4-29', '0L': '4-26', '0LL': '*1', '0LR': '11-61', '0LRL': '*2', '0LRR': '7-39', '0LRRL': '*1', '0LRRR': '*2', '0R': '*1'}

        即存储的若是叶子节点,就将其路径和分类结果写入字典,否则写入在该节点进入下一节点的判断依据:a-b表示x[a]与b的比较,若x[a]<b,则进入左子树,若x[a]>=b,则进入右子树。获取x_dict后,开始迭代运行func。

        因为第一次判断都是从根节点‘0’开始,所以输入func的第一个序列seq为‘0’。首先从字典中判断该序列是否为叶子结点,是则return该节点代表的分类,否则用 f fvalue 表示 a b ,然后进行判断。若x[f]<fvalue,则进入左子树,即序列seq+=‘L’并进入下一次迭代,反之同理。由此就可以根据树列表l_tree预测出x的分类结果。

信息增益二叉树:
+0-4-29
+0L-4-26
*0LL-1
+0LR-19-103
*0LRL-2
+0LRR-2-19
*0LRRL-2
*0LRRR-1
*0R-1
l_tree: ['+0-4-29', '+0L-4-26', '*0LL-1', '+0LR-19-103', '*0LRL-2', '+0LRR-2-19', '*0LRRL-2', '*0LRRR-1', '*0R-1']
信息增益二叉树对样本进行预测的结果:
accurency: 0.9893333333333333

四、完整代码

splitInfo.py可在前文中获得

rebuild_tree.py:

# coding:UTF-8
from splitInfo import info_entropy, gini_index, split_samples, sum_of_each_label
import pandas as pd
import random

class biTree_node:
    '''
    二叉树节点
    '''
    def __init__(self, f=-1, fvalue=None, leafLabel=None, l=None, r=None, splitInfo="gini"):
        '''
        类初始化函数
        para f: int,切分的特征,用样本中的特征次序表示
        para fvalue: float or int,切分特征的决策值
        para leafLable: int,叶节点的标签
        para l: biTree_node指针,内部节点的左子树
        para r: biTree_node指针,内部节点的右子树
        para splitInfo="gini": string, 切分的标准,可取值'infogain'和'gini',分别表示信息增益和基尼指数
        '''
        self.f = f
        self.fvalue = fvalue
        self.leafLabel = leafLabel
        self.l = l
        self.r = r
        self.splitInfo = splitInfo
        
def build_biTree(samples, splitInfo="gini"):
    '''构建树
    para samples:list,样本的列表,每样本也是一个列表,样本的最后一项为label,其它项为特征
    para splitInfo="gini": string, 切分的标准,可取值'infogain'和'gini',分别表示信息增益和基尼指数
    return biTree_node:Class biTree_node,二叉决策树的根结点
    '''
    if len(samples) == 0:
        return biTree_node()
    if splitInfo != "gini" and splitInfo != "infogain":
        return biTree_node()
    
    bestInfo = 0.0
    bestF = None
    bestFvalue = None
    bestlson = None
    bestrson = None

    if splitInfo == "gini":
        curInfo = gini_index(samples) # 当前集合的基尼指数
    else:
        curInfo = info_entropy(samples) # 当前集合的信息熵
        
    sumOfFeatures = len(samples[0]) - 1 # 样本中特征的个数
    for f in range(0, sumOfFeatures): # 遍历每个特征
        featureValues = [sample[f] for sample in samples]
        for fvalue in featureValues:  # 遍历当前特征的每个值
            lson, rson = split_samples(samples, f, fvalue)
            if splitInfo == "gini":
                # 计算分裂后两个集合的基尼指数
                info = (gini_index(lson)*len(lson) + gini_index(rson)*len(rson))/len(samples)
            else:
                # 计算分裂后两个集合的信息熵
                info = (info_entropy(lson)*len(lson) + info_entropy(rson)*len(rson))/len(samples)
            gain = curInfo - info # 计算基尼指数减少量或信息增益
            # 能够找到最好的切分特征及其决策值,左、右子树为空说明是叶子节点
            if gain > bestInfo and len(lson)>0 and len(rson)>0:
                bestInfo = gain
                bestF = f
                bestFvalue = fvalue
                bestlson = lson
                bestrson = rson
    
    if bestInfo > 0.0: 
        l = build_biTree(bestlson)
        r = build_biTree(bestrson)
        return biTree_node(f=bestF, fvalue=bestFvalue, l=l, r=r, splitInfo=splitInfo)
    else: # 如果bestInfo==0.0,说明没有切分方法使集合的基尼指数或信息熵下降了
        label_counts = sum_of_each_label(samples)
        # 返回该集合中最多的类别作为叶子节点的标签
        return biTree_node(leafLabel=max(label_counts, key=label_counts.get), splitInfo=splitInfo)

def predict(sample, tree):
    '''
    对样本sample进行预测
    para sample:list,需要预测的样本
    para tree:biTree_node,构建好的分类树
    return: biTree_node.leafLabel,所属的类别
    '''
    # 1、只是树根
    if tree.leafLabel != None:
        return tree.leafLabel
    else:
    # 2、有左右子树
        sampleValue = sample[tree.f]
        branch = None
        if sampleValue >= tree.fvalue:
            branch = tree.r
        else:
            branch = tree.l
        return predict(sample, branch)
    
tree_list=[]
def print_tree(tree, level='0'):
    '''简单打印一颗树的结构
    para tree:biTree_node,树的根结点
    para level='0':str, 节点在树中的位置,用一串字符串表示,0表示根节点,0L表示根节点的左孩子,0R表示根节点的右孩子  
    '''
    global tree_list
    if tree.leafLabel != None:
        tree_list.append('*' + level + '-' + str(tree.leafLabel))
        print('*' + level + '-' + str(tree.leafLabel)) #叶子节点用*表示,并打印出标签
    else:
        tree_list.append('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue))
        print('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue)) #中间节点用+表示,并打印出特征编号及其划分值
        print_tree(tree.l, level+'L')
        print_tree(tree.r, level+'R')
    return tree_list

def func(x_dict,x,seq):
    if x_dict[seq][0]=='*':
        return int(x_dict[seq][1:])
    f_list=x_dict[seq].split('-')
    f=int(f_list[0])
    fvalue=int(f_list[1])
    if x[f]<fvalue:
        seq+='L'
        return func(x_dict,x,seq)
    if x[f]>=fvalue:
        seq+='R'
        return func(x_dict,x,seq)
    
def predict_from_tree_list(l_tree,x):
    x_dict={}
    for node in l_tree:
        node_info=node.split('-')
        if node_info[0][0]=='+':
            x_dict[node_info[0][1:]]=node_info[1]+'-'+node_info[2]
        if node_info[0][0]=='*':
            x_dict[node_info[0][1:]]='*'+node_info[1]
    #print(x_dict)
    return func(x_dict,x,'0')

if __name__ == "__main__":
    
    mushrooms=pd.read_csv('mushroom.data')
    X=mushrooms.iloc[:,0:23]
    X=X.values
    all_date=(X.tolist())
    random.shuffle(all_date)
    blind_date=all_date[500:700]

    print("信息增益二叉树:")
    tree = build_biTree(blind_date, splitInfo="infogain")
    l_tree=print_tree(tree)
    print("l_tree:",l_tree)
    print('信息增益二叉树对样本进行预测的结果:')
    #test_sample = all_date[1000:7500]
    test_sample = all_date[3000:7500]

    l_test=len(test_sample)
    err=0
    for x in test_sample:
        #err+=(predict(x, tree)==x[-1])
        err+=(predict_from_tree_list(l_tree,x)==x[-1])
    print("accurency:",err/l_test)


附件

UCI Machine Learning Repository: Mushroom Data Set

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Roymasterpiece

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值