文章目录
Neural Prototype Trees for Interpretable Fine-grained Image Recognition
文章来源:CVPR 2021
Motivation
如何设计一种可以解决深度学习方法的天然黑箱问题,以及拥有不错分类准确率的可解释网络?
Main idea
目前基于原型的方法利用可解释的表征去解决深度学习的天然箱问题,以及树是一种结构透明、且易于理解和可解释的层级结构,作者结合这两者的优势设计可解释模型来解决所提出的问题。
Main Contributions
- 本文提出了一种用于细粒度图像识别的本质上可解释的神经原型树架构;
- 所提出的ProtoTree方法使用集成方法、修剪和二值化来调整准确性-可解释性的权衡;
- ProtoTree 的性能优于 ProtoPNet,而原型数量仅占原型数量的 10%。
Method
该模型分为两个部分:(1). 用于原型提取的卷积神经网络,(2). 用于将原型与图像块进行比较分类的决策树。
- CNN 层是完善的图像分类网络的卷积层,所有这些都在相关的数据集上进行了预训练。 一个额外的1*1卷积层的输出维度为 7 x 7 x D(D 为 128、256 或 512)被添加到现有层。 此输出表示用作原型的潜在空间中的patch。
- 在原型层,将测试图像块与学习的原型块进行比较,并使用全局最小池化(最大池化取反)计算路由每个节点(即原型)的概率,找出概率最大的那条路径,得到的类别分布就是最终的分类情况。
Architecture
找到距离原型最近的隐层表征patch,即将每个原型作为一个滑窗进行滑动并计算两者之间的距离。下面的公式就是表示寻找每个原型最近的隐层表征的patch:
在决策树部分,作者定义了利用每个原型的最近隐层表征的patch和该原型计算出路由通过不同节点(即左孩子和右孩子)的概率,公式如下所示:
Training
ProtoTree为了使得可解释行在利用原型进行分类预测的决策树阶段是一棵hard tree,即在最后测试是只选择了一条路径进行分类预测,但是为了在训练时可微和方便反向传播,在训练时训练了soft tree,最后在训练结束后将其转换为Hard Tree进行预测。
该方法在训练的时候全程只有交叉熵作为损失函数。
在训练soft tree时,由于soft tree的特性可得,所有的节点都参与预测分类,具体计算如(4)所示:
其中,
表示从样本z到该叶子节点的概率。
具体的soft tree的训练算法为:
其中,
是叶子节点处类别的分布。
Pruning
在ProtoTree中,原型的个数与解释的尺寸相关。为了减少解释的尺寸,我们分析了叶子中学习到的类概率分布,并移除了分布几乎均匀的叶子,即分辨力低的叶子,具体操作如下图所示:
在剪枝的时候会提前设定阈值,大于类别数分之一的叶子节点保留,其余的剪枝。
###Prototype Visualization
寻找到隐层表征patch中与原型进行替换,在模型训练结束后。为了可视化到样本的一个patch,利用显著图的方式,计算隐层表征的每个patch与该原型的相似得分,最后映射到原图上反映出来。相似分数的计算公式如下:
Deterministic reasoning
由于soft tree是所有节点都参与预测,而hard tree 只是一条路径中的所有节点参与预测。此外,hard tree比soft tree 更容易解释,所以在训练结束要进行soft转换为hard,有两种方式:
- 选择到达叶子节点概率最大的路径;
- 利用贪心算法遍历所有的路径,最终选择一条路径
##Experiments
CUB-200数据集和CARS-196数据集上与SOTA对比的结果。
主要结论是,所提出的方法在提供更多可解释性的同时,实现了与大多数最先进方法相当的准确性。此外,作者使用卷积网络的集成来找到越来越多有意义的原型,而且这种集成方法不会影响可解释性,可是它只会进一步提升模型分类的准确率。
下面两个表和一张图给我们展示了该方法的效果,表2说明了,剪枝对于该模型的分类准确率几乎没什么影响,但是降低了要解释的尺寸,以及用样本的中的隐层表征patch替换了训练出的原型猴也是对于分类准确率没什么影响;表3说明了两种由soft到hard的方式对于分类准确率的影响不大,都具有较好的保真度,而且两种方法的路径长度一致。图7说明了,该方法对于初始化的树的高度有较大的影响,至少保证叶子节点数大于类别数(soft tree)。
下图是一个利用ProtoTree用于预测的示例:
可以从这张图中看到尽管分类正确但是也会学到一些偏差。
Conclusion
本文所提出的方法ProtoTree,在ProtoPNet的基础上有了更进一步的提升,使得可解释性更强,使用的原型更少。对于该方法进行集成也是可以达到与非可解释的方法相近的准确率。
Think
-
ProtoTree的二叉树设计是否只最优的设计?
-
除了这种结合树的方式来设计可解释模型,还有什么方法?
-
现在ProtoTree是在细粒度分类问题上提出的,如果在每个类别差异很大的情况下效果会怎么样呐,效果会更好还是会更差?
-
文章中用的backbone都是resnet50,如果用更深的网络是不是会出现网络表现不好的情况呐?
-
文章中用的backbone都是resnet50,如果用更深的网络是不是会出现网络表现不好的情况呐?