Beyond Sparsity: Tree Regularization of Deep Models for Interpretability

Beyond Sparsity: Tree Regularization of Deep Models for Interpretability

这篇文章是使用树正则方法对深度网络的可解释性的探索,论文的一作作者为Mike Wu, 时为Stanford University博士一年级学生,该论文发表在AAAI18, 该论文的源代码地址为https://github.com/dtak/tree-regularization-public。

可解释性相关概念

可解释性就是人类的模仿性(human simulatability)。如果人类可以在合适时间内采用输入数据和模型参数,经过每个计算步,作出预测,那么该模型就具备了模仿性。

在医疗中,该模仿性的应用为:给定一个模仿性模型,医生可以根据检查模型的每一步是否违背其专业知识,甚至推断数据中的公平性和系统偏差等,这可以帮助从业者利用正向反馈来对模型进行改进。

本文探讨的主要对象 – 决策树,就具有模仿性。通过决策树,按照相关的基本情况在决策树下一直走下去,就可以理解这些特征并进行预测。

此处树正则化与提升DNN的模仿性之间是否有直接必然联系?

源码执行与演示

首先从官网上下载对应的源码,接着按照README.md文件中的指导,构建conda环境(可选项),然后使用

pip install -r experiments.txt

下载所需要的所有python 包。

(值得注意的是,此处可以进行换源,参照http://blog.csdn.net/lambert310/article/details/52412059)

linux下,修改 ~/.pip/pip.conf (没有就创建一个), 修改 index-url至tuna,内容如下:

[global]
index-url = https://pypi.tuna.tsinghua.edu.cn/simple

接着运行

python datasets.py

然后运行 python train.py --strength 1000.0,该指令将会进行基于Tree Regularization 网络的训练。

最后的训练数据为:

training deep net... [12/12], learning rate: 0.0100
model: gru | iter: 0 | loss: 1113.41 | apl: 1.41
model: gru | iter: 10 | loss: 904.21 | apl: 1.41
model: gru | iter: 20 | loss: 716.69 | apl: 1.41
building surrogate dataset...
training surrogate net... [12/12]
model: mlp | iter: 0 | loss: 23.58
model: mlp | iter: 250 | loss: 0.65
model: mlp | iter: 500 | loss: 0.55
model: mlp | iter: 750 | loss: 0.55
model: mlp | iter: 1000 | loss: 0.55
model: mlp | iter: 1250 | loss: 0.55
model: mlp | iter: 1500 | loss: 0.55
model: mlp | iter: 1750 | loss: 0.55
model: mlp | iter: 2000 | loss: 0.55
model: mlp | iter: 2250 | loss: 0.55
model: mlp | iter: 2500 | loss: 0.55
model: mlp | iter: 2750 | loss: 0.55
saved trained model to ./trained_models
saved final decision tree to ./trained_models

然后对其精度进行测试 python test.py. (相关细节将在之后进行说明)

/home/kyi/anaconda2/envs/interpret/lib/python2.7/site-packages/autograd/core.py:120: UserWarning:
------------------------------
  defgrad is deprecated!
------------------------------
Use defvjp instead ("define vector-Jacobian product").
The interface is a little different - look at
autograd/numpy/numpy_grads.py for examples.

  warnings.warn(defgrad_deprecated)
Test AUC: 0.82

由输出数据可知,Test AUC为0.82.

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值