《BrainNNExplainer: An Interpretable Graph Neural Network Framework for Brain Network based Disease Analysis》
前言
Individual-level的。这篇文章中作者做了一个假设:患有同一种疾病的患者具有较为相似的脑网络模式。而脑网络模式又可以间接的通过功能连接矩阵表示出来,因此作者单独设计了一个结构去找那些对疾病预测影响较大的连接(边)。
论文地址:https://arxiv.org/abs/2107.05097v1
代码地址:暂未开源…
一、模型
这个模型由两个结构组成:BrainNN以及Explanation Generator。BrainNN是一个消息传递网络,主要的任务就是用于生成全图的表示向量;Explanation Generator则被用于训练一个所有被试共享的掩膜,进一步找到那些与疾病相关性较大的连接,捕捉脑网络模式。
整个训练过程被分为三步:(1). 训练BrainNN的参数;(2).训练Explanation Generator的参数;(3).再次训练BrainNN的参数。
符号定义:
G = (V, E, W)
V:节点集合,{vi}, i = 1, … , n。
E: 邻接矩阵(0, 1), Rn x n。
W:连接强度矩阵(加权),描述ROI之间的连接强度,Rn x n。
M:掩膜,Rn x n。
1.BrainNN
(1).节点属性
节点的初始属性对GNN的表现是有一定影响的。因此,作者在实验中尝试了以下几种方法生成节点属性:one-hot向量、LDP、degree、binning degree和node2vec。
在LDP中,节点的属性定义如下:
其中 deg(.) 为度,[. ; .] 为拼接操作。
(2).消息传递
- 节点 vi 与其每一个一跳邻居 vj 生成消息向量 mij ∈ RD。[. ; .] 应该也是拼接。
- 每个节点聚合其对应的消息向量,生成节点的表示向量。
- 通过readout层获得全局表示向量z ∈ RD。
然后用全局表示向量 z 去做分类并用交叉熵 Lp 训练这个模型。我觉得这个消息传递的原理和GAT挺像的…
2.Explanation Generator
考虑到患有同一种疾病的患者可能会有较为相似的脑网络模式,那么如果能捕获这种共同的模式,就可以比较准确的判断某个被试是否患有这种疾病了。为了达到这个目的,作者训练了一个全局共享的掩膜M ∈ Rn x n。
1. 将掩膜 M 覆盖于原连接强度矩阵 W 上,得到相应的矩阵 W’。
其中σ为sigmoid函数。
2. 将 G’ = (V, E, W’) 输入到训练过后的BrainNN网络中,预测其标签 y ′ ^ \hat{y'} y′^。
然后用损失函数 L 去训练网络的参数。这个损失函数 L, 由交叉熵损失 Lp 、互信息损失 Lm (希望 G’ 与 G 之间的的互信息最大)、稀疏性损失 Ls (对 G’ 的边的数量进行一定限制)以及逐元素的交叉熵损失 Le (希望矩阵M的元素是离散的)构成。
L
p
=
−
1
N
∑
i
=
1
N
(
y
i
l
o
g
(
y
i
^
)
+
(
1
−
y
i
)
l
o
g
(
1
−
y
i
^
)
)
L_p = -\frac{1}{N}\sum_{i=1}^N(y_ilog(\hat{y_i}) + (1 - y_i)log(1 - \hat{y_i}))
Lp=−N1∑i=1N(yilog(yi^)+(1−yi)log(1−yi^))
L
m
=
−
∑
c
=
1
C
1
[
y
=
c
]
l
o
g
P
ϕ
(
y
′
=
y
∣
G
=
W
′
)
L_m = -\sum_{c=1}^C1[y=c]logP_\phi(y'=y | G = W')
Lm=−∑c=1C1[y=c]logPϕ(y′=y∣G=W′)
L
s
=
∑
i
=
1
n
∑
j
=
1
n
a
i
j
,
M
=
[
a
i
j
]
L_s = \sum_{i=1}^n\sum_{j=1}^na_{ij}, M = [a_{ij}]
Ls=∑i=1n∑j=1naij,M=[aij]
L e = − ( M l o g ( M ) + ( 1 − M ) l o g ( 1 − M ) ) L_e = -(M log(M) + (1 - M)log(1 - M)) Le=−(Mlog(M)+(1−M)log(1−M))
3.训练过程
整个训练过程分为三步:
-
将 G = (V, E, W) 以及对应的标签 c 输入BrainNN网络,训练网络的各种参数,损失函数为 Lp 。
-
将经初次训练后的BrainNN模型、BrainNN的预测 y ^ \hat{y} y^、G’ = (V, E, W’) 及其实际标签 c 输入到Explanation Generator网络中, 冻结BrainNN的所有参数,仅训练掩膜 M 及其它相关参数, 损失函数为 L。
-
将训练好的掩膜 M 覆盖在原数据集上,得到 G’ = (V, E, W’)。再次将其输入到BrainNN网络中,进一步训练BrainNN的参数(冻结M)。
二、实验
1.可解释性分析
以下图片仅展示了权值大于阈值的边。
以下表格展示的是功能网络重要程度的top3,分别用了三种指标。
2.对比实验
相比于前两栏中的模型,BrainNN已经在指标上有了不错的提升。这说明作者提出的消息传递方法还是比较有效的,有利于获得更具代表性的全局表示向量。而BrainNNExplainer更为突出的表现,验证了掩膜的合理性(捕获了共同的脑网络模式)。