Post-hoc Concept Bottleneck Models (PCBM)

ICLR 2023 spotlight

文章链接:https://arxiv.org/abs/2205.15480

代码链接:https://github.com/mertyg/post-hoc-cbm

一、概述

        Post-hoc CBM(PCBM)也是CBM大家族中的一员,因此它的基本逻辑与CBM一致,就是在输入和输出之间构造一个bottleneck用于预测concepts。和其它很多文章类似,作者同样指出了CBM模型的缺点:

        (i) dense annotation,即需要大量精细的标注;

        (ii)  accuracy-interpretability trade-off,即准确性与可解释性之间的取舍与权衡(尤其是在concepts not enough的情况下);

        (iii) local intervention,即CBM只是针对个例进行干预,而不是提升模型本身的效果。

        因此,本文提出PCBM,可以将任何网络转化为PCBM,且在不牺牲模型精度的同时保证可解释性;此外,当训练集中缺失annotation时,PCBM可以从其它数据集或使用多模态模型产生概念:transfer concepts from other datasets or from natural descriptions of concepts via multimodal models”——在介绍CBM那篇文章的时候提到过——或者,引入一个residual modeling step来recover the original blackbox model's performance。此外,PCBM允许global model edits即全局的intervention,这种方法会比针对specific prediction的local intervention更加有效。


二、方法

        We let f:\mathcal{X}\rightarrow\mathbb{R}^{d} be any pretrained backbone model, where d is the size of the corres-ponding embedding space and \mathcal{X} is the input space.  f 可以是CLIP中的image encoder或者ResNet的倒数第二层(总之是一个编码器)。

        建立PCBM需要以下几个步骤:

(i) Learning the Concept Subspace

        为了学习concept representations,作者使用了CAVs的做法,首先定义了一个概念集合concept library I=\left \{ i_1,i_2,...,i_{N_c} \right \},其中 N_c 代表concepts的总数;concept library可以由domain expert定义或者从数据中自动学习(参考NeurIPS 2019, Towards automatic concept-based explanations.https://arxiv.org/abs/1902.03129)。

        For each concept i, we collect embeddings for the positive examples, denoted by the set P_i, and negative examples N_i.

P_i=\left \{ f(x_{p_1}),...,f(x_{p_{N_p}}) \right \}

N_i=\left \{ f(x_{n_1}),...,f(x_{n_{N_n}}) \right \}

        作者训练了一个SVM对 P_i 与 N_i 分类,并计算对应的CAV(分类边界的法向量),并且与TCAV相同,CAV的学习并不局限于the data used to train the backbone model;将第 i 个concept对应的CAV记为 \boldsymbol{c}_i,let \boldsymbol{C}\in \mathbb{R}^{N_c\times d} denote the matrix of concept vectors. \boldsymbol{C} 的每一行就代表第 i 个concept对应的CAV \boldsymbol{c}_i

        现在,我们有一个backbone model f 作为encoder,一个由一系列CAVs组成的concept matrix \boldsymbol{C}。此时给定输入 x,我们可以通过 f_{\boldsymbol{C}}(x)=\mathrm{proj}_{\boldsymbol{C}}f(x)\in\mathbb{R}^{N_c}将 f(x) 投影到由 \boldsymbol{C} 张成的向量空间,i.e., f_{\boldsymbol{C}}^{(i)}(x)=\frac{\left \langle f(x),\boldsymbol{c}_i \right \rangle}{\left \| \boldsymbol{c_i} \right \|_{2}^{2}}\in\mathbb{R},即 f_{\boldsymbol{C}}^{(i)}(x) 代表当前输入在第 i 个concept vector \boldsymbol{c}_i 方向上的长度(是一个scalar),直观来说就是当前输入 x 中包含概念 \boldsymbol{c}_i 的程度(图中红色方框)👇

(ii) Leveraging multimodal models to learn concepts

        前面提到CBM需要dense annotation,限制了实际应用。作者提出可以使用多模态模型比如CLIP来生成concept vector,具体来说,由于CLIP (Radford et al., 2021)具有一个image encoder和一个text encoder可以将二者编码到shared embedding space中,因此我们可以通过mapping the prompt using the text encoder to obtain the concept vectors;举例来说,如果我们想得到“strpes”这一concept对应的CAV但是又缺少标注好的数据,我们可以通过将“stripes”输入到CLIP的text encoder中,使用其编码后得到的向量作为CAV(其实就不叫CAV了,但是得到的这个向量也是类似CAV的一种用来表示概念的向量;为方便理解,此处索性就统一叫作CAV,但不要混淆),i.e. \boldsymbol{c}_{\textrm{stripes}}^{\textrm{text}}=f_{\textrm{text}}(\textrm{"stripes"});这样,对于每一个concept我们都有对应的语言表述,也都能相应地得到CAV,由此得到我们的multimodal concept bank \boldsymbol{C}^\textrm{text}.

Note:CAVs与Multimodal Models两种方法二选一,而不是将两种方法得到的CAV求并。

        对于classification task,可以使用ConceptNet (Speer et al., 2017)来自动获取与类别相关的concepts,从而构建concept bank。

(iii) Learning the Interpretable Predictor

        Let g:\mathbb{R}^{N_c}\rightarrow \mathcal{Y} be an interpretable predictor. g 可以选择线性模型或者决策树这种具有较强可解释性的模型,将预测得到的评分 f_{\boldsymbol{C}}(x) 映射为最终的类别 \mathcal{Y}。通过优化以下式子来学习模型:

\min\limits_g \mathbb{E}_{(x,y)\sim \mathcal{D}}[\mathcal{L}(g(f_{\boldsymbol{C}}(x)),y)]+\frac{\lambda }{N_cK}\Omega (g)

\Omega (g)=\alpha \left \| \omega \right \|_1(1-\alpha )\left \| \omega \right \|_{2}^{2}

        前面一项对应分类损失(如交叉熵),后面一项为正则项,用来限制predictor g 的复杂度,并由类别和概念的数量进行归一化。在这项工作中作者使用的是sparse linear models。

(iv) Recovering the original model performance with residual modeling

        即使我们拥有了一个相对丰富的概念子空间,概念很可能仍然不足以解决我们感兴趣的下游任务。对于这种情况,即PCBM与原始模型性能不匹配时,作者引入了从original embedding连接到最终决策层的残差部分,以保持原有模型的准确度,对应的模型为PCBM-h。此时,作者使用sequential的训练方式,首先训练 interpretable predictor g ,然后固定concept bottleneck and the interpretable predictor并优化残差部分:

\min\limits_r \mathbb{E}_{(x,y)\sim \mathcal{D}}[\mathcal{L}(g(f_{\boldsymbol{C}}(x))+r(f(x)),y)]

        其中 r 是residual predictor,其输入是原始的不具有解释性的embedding,而最后的输出结果是综合了interpretable predictor的输出 g(f_{\boldsymbol{C}}(x))以及residual predictor的输出 r(f(x))。可以将r(f(x)) 视为原来interpretable predictor的一种补充;g的输入是interpretable concept embeddings,r 的输入是uninterpretable的original embeddings from backbone encoder. 模型的决策由 g 尽量解释,解释不了的由 r 来恢复原始精度。很显然,PCBM-h的精度一定是高于PCBM的。

Note:如果想观察interpretable predictor g 的表现,那么就把residual predictor r 网络中的参数全部置零从而drop掉这一支路,如果我们想得到一个黑盒模型,就把 g 网络中的参数全部置零。


三、实验及结果

(i) PCBMs achieve comparable performance to the original model

        PCBMs获得了与黑盒模型comparable的性能,尤其是PCBM-h。

(ii) PCBMs achieve comparable performance to the original model

        当提供的concepts not available or insufficient的时候,可以使用借助CLIP的text ecncoder产生的concept bank,发现CLIP自动生成的concept要比人为提供的概念标注更好。

(iii) Explaining Post-hoc CBMs

        展示了针对于一个类别线性层中权重最大的三个concepts,在皮肤癌的例子中,模型考虑的concept与人类判断时考虑的因素一致。

(iv) Model editing

        与基本的CBM对单个样本做干预(local intervention)不同,PCBM的一个优势就是允许global intervention从而直接提升模型整体的表现。当我们知道某些概念是错误的时候,可以通过剪枝(Prune)等操作优化模型。举个例子,如果训练集和测试集存在域偏差,比如,训练集中有很多“狗”的图片,但是在测试集中没有“狗”的图片,那么在训练阶段学习到的所有关于狗的概念都将无效,或者说对于测试集是“错误的概念”;此时我们可以采用以下三种strategies对模型进行修改:

        (1) Prune: 在决策层将错误概念对应的权重置0,i.e., for a concept indexed by i, we let \tilde{\boldsymbol{\omega }}_i=0

        (2) PruneNormalize:在prune后rescale the concept weights,归一化可以缓解剪枝后较大权重造成的权值不平衡问题;

        (3) Fine-tune (Oracle):在测试集上对整个模型进行微调,作为oracle。

        可以发现PCBM进行PruneNormalize之后的增益较高,最接近oracle;而PCBM-h的增益很低。一个原因是PCBM可以通过Prune直接剪掉干扰预测的错误概念,但是由于PCBM-h的残差连接中仍包含来自错误概念的信息无法被去除,因此预测精度的提升不明显。

(v) User study

        作者还进行了user-study,即测试集与训练集存在偏差时,让user自行选择一定数量的concepts进行prune,观察模型性能是否有提高,以验证模型能够良好的与人类进行交互;作者使用了三个实验设置作为对比:

        (1) Random Pruning:随机对weights置零;

        (2) Greedy pruning(Oracle):即prune掉与人类同样数量的concepts使得模型得到最佳增益;

        (3) Fine-tune (Oracle):在测试集上微调。

        Random prune发生了性能降低,而user prune可以明显改善模型性能,大概相当于80%的greedy prune增益与50%的fine-tune增益。

        另一个现象是即使有残差连接但是仍然可以通过剪枝提高PCBM-h的性能,具体原因不知道。


        最后是简单的discussion:

        (1) 人类构建的concept bottleneck是否可以解决更大规模的任务是一个悬而未决的问题(例如ImageNet级别),因为会有information bottleneck的存在,精度concept定义insufficient,也是导致accuracy-interpretability之间有trade-off的原因所在。

        (2) 以无监督的方式为模型寻找概念子空间是一个活跃的研究领域,它将有助于构建更加有用的、丰富的概念瓶颈。

  • 16
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

exploreandconquer

谢谢老板,老板大气!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值