机器学习可解释性【随机森林规则提取】

cover

引言

目前,机器学习模型应用于各行各业,数据量够多,那就用深度学习吧,数据量少了,传统机器学习算法也能行。
然而机器学习模型作为“黑盒模型”,人们越来越担心其安全性,因而希望模型具有可解释性。

本文主要讲:

  1. 模型可解释性方案有哪些
  2. 随机森林规则提取的方法有哪些
  3. 随机森林规则提取,如何实现

相关工作

模型可解释性方案可分为:

  1. 事前可解释性建模:
    有些模型自带可解释性,如:朴素贝叶斯、线性回归、决策树、基于规则的
    模型,针对这些模型,在训练之前,从头设计满足可解释性的模型。

  2. 事后可解释性分析:
    模型已经训练好了,然后再进行解释。

自解释模型本身内置可解释性,如决策树模型,自上而下每条路径代表一条决策,模型可解释性很直观。然而,人类认知有限,自解释模型的内置可解释性受模型复杂度的限制,如果树的深度过深或模型过于复杂,人类也难以理解。但结构太简单,其模型拟合能力必然受限。

在训练后,再解释模型,相对能解决此问题。
即先通过选择最优参数来训练模型,此时得到的结果较好,此时再对“黑盒模型”实施拆箱操作,分析其可解释性,即:事后可解释性分析。
可是,往往可解释性最好的模型并非结果最好。

因而,两种方案都需要权衡取舍。
本文主要介绍随机森林规则提取。

随机森林规则提取

随机森林是基于 Bagging 的集成学习模型,通过集成多棵决策树来提升模型决策能力。随机森林由决策树构成,从决策树的根结点到其叶子节点的一条路径,可以认为是一条由多条 if-then 条件构成的规则。

随机森林规则提取,事前、事后都可以做。主要的算法有:RF+HC 以及 RF+HC_CMPR
这两种算法,重点在于规则筛选方面,区别主要在于 RF+HC_CMPR 在规则打分公式中加入了规则的长度。

本文主要针对已训练好的随机森林模型进行事后可解释性分析,其方法简单易用,赶紧点赞收藏(hhhh,kaiwanxiaola)。
本文的规则提取思路比较简单,步骤如下:

  1. 训练好随机森林模型
  2. 遍历随机森林模型中所有子决策树,并提取出所有规则集
  3. 去除重复规则集
  4. 通过规则的长度、误差、频率筛选出简化规则集

代码实现

不想看代码(我不想看代码~)

1. 代码解析

save_decision_rules(self,rf, csv_path) :
遍历所有决策树的规则集,并保存。
举个例子,一棵决策树如下图所示:
在这里插入图片描述
可见,由圆形表示为规则,左边为满足规则,右边为不满足规则,
存储的时候,满足规则,存储为1,不满足规则存储为0,上图中,保存的规则集为:

 TREE:0
 NODE:0,是否房产价值>100w,4,1
 NODE:1,是否有其他值钱的抵押物,4,2
 NODE:2,月收入>10k,3,5
 NODE:3,是否结婚,4,5
 LEAF:4,1
 LEAF:5,0

TREE:0 , 表示第0棵决策树
NODE:0, 表示非叶子节点0
LEAF:4, 表示叶子节点4

从上至下为决策树判断过程,如:
NODE:0,是否房产价值>100w,4,1,表示:房产价值>100w,是:跳到编号4,否则:跳到编号1,
编号4,即:LEAF:4,1,即:给予贷款;编号1,即:NODE:1,是否有其他值钱的抵押物,4,2

这样,所有决策树的规则全保存好了。

read_decision_rules(self,path):
从保存文件中,读取所有规则集,即:先遍历左子树,再遍历右子树,
其中,left_tree(self,tree, left,top_feature) 为遍历左子树,
right_tree(self,tree, right, top_feature) 为遍历右子树。
最终得到规则集如下所示:

是否房产价值>100w:1,1
是否房产价值>100w:0,是否有其他值钱的抵押物:1,1
是否房产价值>100w:0,是否有其他值钱的抵押物:0,月收入>10k:1,是否结婚:1,1
是否房产价值>100w:0,是否有其他值钱的抵押物:0,月收入>10k:0,0
是否房产价值>100w:0,是否有其他值钱的抵押物:0,月收入>10k:1,是否结婚:0,0

这样,得到了5条规则集。

filter_rules(self,rules_path):
去除重复规则集

save_rules(self, path):
保存规则集

2. 全部代码实现
import numpy
import config
import constants
import pandas as pd

def getFeatures(_path):
	""" 获取特征集 """
    df = pd.read_csv(_path)
    cols = df.columns.values.tolist()

    X = df[cols]
    return X.columns

class RFAnalysis():

    def __init__(self):
        self.l_one_rule,self.r_one_rule = [], []
        self.tree_results = []
        self.results = []  # 所有树的规则

    def save_decision_rules(self,rf, csv_path):
        features = getFeatures(csv_path)

        txt_path = constants.OS_PATH + '/output/模型解释/随机森林.txt' # 保存路径
        with open(txt_path, 'w') as f:
            for tree_idx, est in enumerate(rf.estimators_):
                tree = est.tree_
                assert tree.value.shape[1] == 1  # no support for multi-output

                f.write('TREE: {}'.format(tree_idx) + '\n')
                print('TREE: {}'.format(tree_idx))
                iterator = enumerate(
                    zip(tree.children_left, tree.children_right, tree.feature, tree.threshold, tree.value))
                for node_idx, data in iterator:
                    left, right, feature, th, value = data

                    class_idx = numpy.argmax(value[0])

                    # 写入文件
                    if left == -1 and right == -1:
                        print('{} LEAF: return class={}'.format(node_idx, class_idx))
                        f.write('LEAF:' + str(node_idx) + ',' + str(class_idx) + '\n')
                    else:
                        print(
                            '{} NODE: if feature[{}] < {} then next={} else next={}'.format(node_idx, features[feature],
                                                                                            th,
                                                                                            left, right))
                        f.write('NODE:' + str(node_idx) + ',' + str(features[feature]) + ',' + str(left) + ',' + str(
                            right) + '\n')
                f.write("#\n") # 每棵树以"#"结束

    def left_tree(self,tree, left,top_feature):  # 左边:规则
        self.r_one_rule.append(top_feature+':0')
        line = tree[int(left)]

        if line.find("LEAF") != -1:  # 叶子节点
            l = line.split(",")
            value = l[-1]
            if len(self.r_one_rule) > 0: # 没有右边的值,就不加
                self.r_one_rule.append(value)
                _rule = self.r_one_rule.copy()
                self.tree_results.append(_rule)
                del self.r_one_rule[-1]
                del self.r_one_rule[-1]


        if line.find('NODE') != -1:  # 继续遍历
            l = line.split(",")
            feature = l[1]
            _left = l[2]
            _right = l[3]
            # 遍历左子树
            self.left_tree(tree, _left,feature)
            # 遍历右子树
            self.right_tree(tree, _right, feature)

    def right_tree(self,tree, right, top_feature):  # 右边:规则

        if top_feature+':0' in self.r_one_rule:
            self.r_one_rule.remove(top_feature+':0')

        self.r_one_rule.append(top_feature+':1')
        line = tree[int(right)]

        if line.find("LEAF") != -1:  # 叶子节点
            l = line.split(",")
            value = l[-1]
            self.r_one_rule.append(value)
            _rule = self.r_one_rule.copy()
            self.tree_results.append(_rule)
            # del self.r_one_rule[-1]
            del self.r_one_rule[-1]
            del self.r_one_rule[-1]

        if line.find('NODE') != -1:  # 继续遍历
            l = line.split(",")
            feature = l[1]
            _left = l[2]
            _right = l[3]
            # 遍历左子树
            self.left_tree(tree, _left,feature)
            # 遍历右子树
            self.right_tree(tree, _right, feature)

    def read_decision_rules(self,path):
        trees = []
        rules = []
        with open(path, 'r') as f:
            for line in f:
                if line.find('#') != -1:
                    trees.append(rules)
                    rules = []
                else:
                    if line.find('TREE:') != -1:
                        continue
                    rules.append(line)


        for i, tree in enumerate(trees):  # 遍历每棵树
            self.tree_results = []  # 一棵树的所有规则

            root = tree[0]
            print(root)
            l = root.split(",")
            feature = l[1]
            left = l[2]
            right = l[3]

            self.left_tree(tree, left,feature)
            self.r_one_rule = []
            self.right_tree(tree, right, feature)

            self.results.append(self.tree_results)
            # print(self.tree_results)
        # print(self.results)

    def save_rules(self, path):
        l = []
        with open(path, 'w') as f:
            for i, tree in enumerate(self.results):
                for j, value in enumerate(tree):
                    if (len(value) <= 2):
                        continue
                    l.append(value)
                    print(value)
                    for w,k in enumerate(value):
                        if w != 0:
                            f.write(',')
                        f.write(k)
        print(len(l))

    def filter_rules(self,rules_path,save_path=""):
        """ 规则去重 """
        rules = []
        with open(rules_path, 'r') as f:
            for line in f:
                rules.append(line)

        rules_copy = rules.copy()
        for k,v in enumerate(rules):
            r = [i for i,x in enumerate(rules) if x is v]
            print(r)

    def get_rule_frequency_error(self,csv_path,rules_path,save_path):
        """ 计算每条规则频率和误差,并保存在:save_path 中 """
        rules = [] # rules:字典:{'尿黄':0}
        _id = 0
        with open(rules_path, 'r') as f:
            for line in f:
                rule = {}
                l = line.split(",")
                label = l[-1].replace('\n', '')
                rule['id'] = _id
                for i in l[:-1]:
                    block = i.split(":")
                    key = block[0]
                    value = block[1]
                    rule[key] = value
                rule['label'] = label
                rules.append(rule)
                _id += 1
        # print(rules)

        df = pd.read_csv(csv_path)
        df_len = len(df)
        for i, rule in enumerate(rules):
            rule['frequency1'] = 0
            rule['error1'] = 0
            for row in df.itertuples():
                is_true = True # 是否有满足规则的样本
                for k, value in enumerate(rule):
                    if value == 'frequency1' or value == 'id' or value == 'error1':
                        continue

                    if value == 'label':
                        row_value = int(getattr(row, constants.ZHENGHOU1))
                        r = int(rule[value])
                        if row_value != r:
                            rule['error1'] = rule['error1'] + 1
                        continue

                    row_value = int(getattr(row, value))
                    r = int(rule[value])
                    if row_value != r:
                        is_true = False
                        break
                if is_true:
                    rule['frequency1'] = rule['frequency1'] + 1 # 满足规则样本数加一
            rule['frequency2'] = rule['frequency1'] / df_len

            if rule['frequency1'] > 0:
                rule['error2'] = rule['error1'] / rule['frequency1']
                print(rule['id'],', ',rule['frequency1'])

        print(len(rules))

        # 存储频率不为0的规则
                with open(save_path, 'w') as f:
            for i, rule in enumerate(rules):
                if rule['frequency1'] == 0:
                    continue
                for k, value in enumerate(rule):
                    block = value+":"+str(rule[value])
                    f.write(block)
                    if value != 'error2':
                        f.write(',')
                f.write('\n')

    def get_rank_rules(self,rules_path):
        """ 获取规则排序,频率高,误差小 """
        rules = []
        with open(rules_path, 'r') as f:
            for line in f:
                rule = {}
                l = line.split(",")
                last = l[-1].replace('\n', '')
                l[-1] = last
                is_true = False
                is_true_true = False
                for i in l:
                    block = i.split(":")
                    key = block[0]
                    value = block[1]
                    # 筛选频率大于 0。01的
                    rule[key] = value
                    if key == 'frequency2' and float(value) > 0.03:
                        is_true = True
                    if key == 'error2' and is_true and float(value) < 0.05:
                        is_true_true = True
                if is_true_true:
                    rules.append(rule)
        # print(rules)
        ranked_rules = sorted(rules, key=lambda i: i['frequency2'],reverse=True)
        for i in ranked_rules:
            print(i)
        # print(ranked_rules[0:20])

if __name__ == '__main__':
    rf_analysis = RFAnalysis()

    csv_path = config.PATH
    # X_train,X_test,y_train,y_test = data_utils.split(csv_path)
    # estimator = models.randomForestClassifier()
    # estimator.fit(X_train, y_train)

    # 提取并存储规则集
    # rf_analysis.save_decision_rules(estimator,csv_path)

    # 整理规则集
    # txt_path = constants.OS_PATH + '/output/模型解释/随机森林.txt'
    # rf_analysis.read_decision_rules(txt_path)
    #
    # 保存规则集
    # save_path = constants.OS_PATH + '/output/模型解释/结果.txt'
    # rf_analysis.save_rules(save_path)

    # rf_analysis.filter_rules(rules_path=save_path)

    # csv_path = constants.OS_PATH + '/output/模型解释/smote.csv'
    # 获取规则集
    rules_path = constants.OS_PATH + '/output/模型解释/结果.txt'
    save_path = constants.OS_PATH + '/output/模型解释/结果_频率_误差.txt'
    rf_analysis.get_rule_frequency_error(csv_path,rules_path,save_path)
    # rf_analysis.get_rank_rules(rules_path=save_path)



总结

本文首先介绍了机器学习模型可解释性分为:

  1. 事前可解释性建模
  2. 事后可解释性分析

随机森林规则提取,既可做事前也可做事后分析。
本文主要针对事后可解释性分析,提出了先通过参数优化建立随机森林模型,然后提取规则集,再将规则集去重,通过误差、频率、长度来筛选规则集。

本文的方法也存在不足,主要在于其筛选方法过于简单,可能筛选不到最佳规则集,同时在算法上,未经优化,循环过多,数据量太大时,较为耗时。
在以后研究中,将加入其他可解释性分析,包括:深度学习可解释性问题。

谢谢

  • 6
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
随机森林是一种常用的机器学习算法,用于解决分类和回归问题。在随机森林中,特征重要性是评估每个特征对模型预测能力的贡献程度的一种指标。R语言中的randomForestExplainer包提供了解释随机森林模型的功能。 在使用randomForestExplainer包解释随机森林模型时,可以使用以下方法来解读特征重要性结果: 1. 使用randomForestExplainer包中的函数来计算特征重要性。这些函数可以从随机森林对象中提取特征重要性的度量值。常用的度量包括: - 变量扰动后的预测精度降低(度量a) - 变分裂后节点纯度的变化(度量b) 变量扰动后的预测精度降低的平均(度量c) - 变量分裂后节点纯度变化的平均值(度量d) - 基于森林结构的度量(度量e-i) 2. 根据具体的度量值,可以判断特征的重要性。例如,如果度量a和c的值较大,则表示该特征对模型的预测能力有较大的贡献;如果度量b和d的值较大,则表示该特征对节点纯度的变化有较大的影响;如果度量e-i的值较大,则表示该特征在森林结构中起到了重要的作用。 3. 可以使用randomForestExplainer包中的其他函数来可视化特征重要性结果例如,可以使用plot_min_depth_distribution函数来绘制最小深度的分布图,使用多元重要性绘制函数来比较不同特征的重要性,使用交互图像绘制函数来展示特征之间的交互关系等。 总之,通过使用randomForestExplainer包提供的函数和方法,可以对随机森林模型的特征重要性进行解读和可视化,从而更好地理解模型的预测能力和特征之间的关系。
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值