Scalable Rule-Based Representation Learning for Interpretable Classification

文章目录

Wang Z., Zhang W., Liu N. and Wang J. Scalable rule-based representation learning for interpretable classification. In Advances in Neural Information Processing Systems (NIPS), 2021.

传统的诸如决策树之类的机器学习方法具有很强的结构性, 也因此具有很好的可解释性. 和深度学习方法相比, 这类方法比较难以推广到大规模的问题上, 很重要的一个原因便是, 其离散的参数和结构导致无法利用梯度进行优化. 本文是对利用梯度来优化这些模型的一个尝试.

主要内容

本文考虑的是上图(a)中的离散模型, 其接受连续变量 C i C_i Ci和离散变量 B i B_i Bi:

  1. 通过Binarization Layer 将连续变量 C i C_i Ci离散化并与 B i B_i Bi拼接得到输入 u ( 0 ) \bm{u}^{(0)} u(0);
  2. 对于Logical Layer, 其以 u l − 1 \bm{u}^{l-1} ul1为输入, 输出 u l \bm{u}^l ul, 其包含且 r \bm{r} r和或 s \bm{s} s两个部分:
    r i ( l ) = ⋀ W i j ( l , 0 ) = 1 u j ( l − 1 ) , s i ( l ) = ⋁ W i j ( l , 1 ) = 1 u j ( l − 1 ) . r_i^{(l)} = \bigwedge_{W_{ij}^{(l, 0)} = 1} u_j^{(l-1)}, \\ s_i^{(l)} = \bigvee_{W_{ij}^{(l, 1)} = 1} u_j^{(l-1)}. \\ ri(l)=Wij(l,0)=1uj(l1),si(l)=Wij(l,1)=1uj(l1).
    其中 W ( l , 0 ) W^{(l, 0)} W(l,0)表示 r \bm{r} r u \bm{u} u的邻接矩阵, 而 W ( l , 1 ) W^{(l, 1)} W(l,1)表示 s \bm{s} s u \bm{u} u的邻接矩阵. 可以发现, Logical Layer中的输入输出和权重都是二元的.
  3. 最后通过一个线性层进行分类, 需要说明的是, 线性层的权重是连续的.

显然由于logical layer是离散的, 直接通过梯度更新是办不到的. 一个自然的想法是用一个连续的版本 F ^ ( X ; θ ) \hat{\mathcal{F}}(X; \theta) F^(X;θ)进行替换, 更新连续的参数 θ \theta θ然后获得下列的离散的版本:
F ( X ; q ( θ ) ) , q ( x ) = I x > 0.5 . \mathcal{F}(X; q(\theta)), \quad q(x) = \mathbb{I}_{x > 0.5}. F(X;q(θ)),q(x)=Ix>0.5.
显然直接套用这个方法是低效的, 因为训练过程和离散没有任何关系, 我们没法保证离散后的模型依旧是有效的, 此外还有一个问题, 上述离散模型如何匹配到一个连续的版本.

下面是一个有趣的解决方案, 假设 W ^ i , j ∈ [ 0 , 1 ] \hat{W}_{i,j} \in [0, 1] W^i,j[0,1], 则
C o n j ( u , W i ) = ∏ j = 1 n { 1 − W i , j ( 1 − u j ) } , D i s j ( u , W i ) = 1 − ∏ j = 1 n { 1 − W i , j u j } , Conj (\bm{u}, W_i) = \prod_{j=1}^n \bigg\{1 - W_{i,j}(1 - u_j) \bigg\}, \\ Disj (\bm{u}, W_i) = 1 - \prod_{j=1}^n \bigg\{1 - W_{i,j}u_j \bigg\}, \\ Conj(u,Wi)=j=1n{1Wi,j(1uj)},Disj(u,Wi)=1j=1n{1Wi,juj},
便为且和或操作的连续版本.
试想:
r i = 1 ⇔ ⋀ j [ u j ( l − 1 ) ∨ ( 1 − W i j ) ] = 1 ⇔ ∏ j { 1 − W i , j ( 1 − u j ) } = 1. \begin{array}{ll} & r_i = 1 \\ \Leftrightarrow & \bigwedge_j [u_j^{(l-1)} \vee (1 - W_{ij})] = 1\\ \Leftrightarrow & \prod_j \bigg\{1 - W_{i,j}(1 - u_j) \bigg\} = 1.\\ \end{array} ri=1j[uj(l1)(1Wij)]=1j{1Wi,j(1uj)}=1.
其它情况可以类似推导, 实在是有趣.

但是上述式子在实际中会有一些梯度消失的问题(因为连乘号, 且内部是[0, 1]之间的), 所示在实际使用中, 作者加了一个投影算子
C o n j + = P ( C o n j ( u , W i ) ) , Conj_+ = \mathbb{P}(Conj (\bm{u}, W_i)), Conj+=P(Conj(u,Wi)),
其中(这设计都是为了避免梯度消失, 怎么想到的? 怎么会往这个方向去想的?)
P ( v ) = − 1 − 1 + log ⁡ ( v ) . \mathbb{P}(v) = \frac{-1}{-1 + \log (v)}. P(v)=1+log(v)1.

解决了连续版本的问题, 现在剩下的难啃的地方是如何更新 θ \theta θ以保证 q ( θ ) q(\theta) q(θ)也是有意义的.
作者采用如下的梯度更新公式:
θ t + 1 = θ t − η ∂ L ( Y ˉ ) ∂ Y ˉ ⋅ ∂ Y ^ ∂ θ t , \theta^{t+1} = \theta^t - \eta \frac{\partial \mathcal{L}(\bar{Y})}{\partial \bar{Y}} \cdot \frac{\partial \hat{Y}}{\partial \theta^t}, θt+1=θtηYˉL(Yˉ)θtY^,
其中 Y ^ = F ^ ( X ; θ ) \hat{Y} = \hat{\mathcal{F}}(X; \theta) Y^=F^(X;θ), Y ˉ = F ( X ; θ ˉ ) \bar{Y} = \mathcal{F}(X; \bar{\theta}) Yˉ=F(X;θˉ).
作者用了一个嫁接的例子来说明该思想, 即损失关于预测的导数用离散的, 内部的导数用连续的.

我惊讶的是, 这些改动居然work? 太不可思议了.

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值