《BrainNNExplainer》阅读笔记

《BrainNNExplainer: An Interpretable Graph Neural Network Framework for Brain Network based Disease Analysis》


前言

Individual-level的。这篇文章中作者做了一个假设:患有同一种疾病的患者具有较为相似的脑网络模式。而脑网络模式又可以间接的通过功能连接矩阵表示出来,因此作者单独设计了一个结构去找那些对疾病预测影响较大的连接(边)。

论文地址:https://arxiv.org/abs/2107.05097v1

代码地址:暂未开源…


一、模型

在这里插入图片描述

这个模型由两个结构组成:BrainNN以及Explanation Generator。BrainNN是一个消息传递网络,主要的任务就是用于生成全图的表示向量;Explanation Generator则被用于训练一个所有被试共享的掩膜,进一步找到那些与疾病相关性较大的连接,捕捉脑网络模式。

整个训练过程被分为三步:(1). 训练BrainNN的参数;(2).训练Explanation Generator的参数;(3).再次训练BrainNN的参数。


符号定义:

      G = (V, E, W)

      V:节点集合,{vi}, i = 1, … , n

      E: 邻接矩阵(0, 1), Rn x n

      W:连接强度矩阵(加权),描述ROI之间的连接强度,Rn x n

      M:掩膜,Rn x n


1.BrainNN

(1).节点属性

节点的初始属性对GNN的表现是有一定影响的。因此,作者在实验中尝试了以下几种方法生成节点属性:one-hot向量、LDP、degree、binning degree和node2vec。


在LDP中,节点的属性定义如下:

在这里插入图片描述
在这里插入图片描述
其中 deg(.) 为度,[. ; .] 为拼接操作。


(2).消息传递

  1. 节点 vi 与其每一个一跳邻居 vj 生成消息向量 mij ∈ RD[. ; .] 应该也是拼接。

    在这里插入图片描述

  2. 每个节点聚合其对应的消息向量,生成节点的表示向量。

    在这里插入图片描述

  3. 通过readout层获得全局表示向量z ∈ RD

在这里插入图片描述


然后用全局表示向量 z 去做分类并用交叉熵 Lp 训练这个模型。我觉得这个消息传递的原理和GAT挺像的…


2.Explanation Generator

考虑到患有同一种疾病的患者可能会有较为相似的脑网络模式,那么如果能捕获这种共同的模式,就可以比较准确的判断某个被试是否患有这种疾病了。为了达到这个目的,作者训练了一个全局共享的掩膜M ∈ Rn x n

1. 将掩膜 M 覆盖于原连接强度矩阵 W 上,得到相应的矩阵 W’

在这里插入图片描述
其中σ为sigmoid函数。

2.G’ = (V, E, W’) 输入到训练过后的BrainNN网络中,预测其标签 y ′ ^ \hat{y'} y^

然后用损失函数 L 去训练网络的参数。这个损失函数 L, 由交叉熵损失 Lp 、互信息损失 Lm (希望 G’G 之间的的互信息最大)、稀疏性损失 Ls (对 G’ 的边的数量进行一定限制)以及逐元素的交叉熵损失 Le (希望矩阵M的元素是离散的)构成。


   L p = − 1 N ∑ i = 1 N ( y i l o g ( y i ^ ) + ( 1 − y i ) l o g ( 1 − y i ^ ) ) L_p = -\frac{1}{N}\sum_{i=1}^N(y_ilog(\hat{y_i}) + (1 - y_i)log(1 - \hat{y_i})) Lp=N1i=1N(yilog(yi^)+(1yi)log(1yi^))

   L m = − ∑ c = 1 C 1 [ y = c ] l o g P ϕ ( y ′ = y ∣ G = W ′ ) L_m = -\sum_{c=1}^C1[y=c]logP_\phi(y'=y | G = W') Lm=c=1C1[y=c]logPϕ(y=yG=W)

   L s = ∑ i = 1 n ∑ j = 1 n a i j , M = [ a i j ] L_s = \sum_{i=1}^n\sum_{j=1}^na_{ij}, M = [a_{ij}] Ls=i=1nj=1naij,M=[aij]

   L e = − ( M l o g ( M ) + ( 1 − M ) l o g ( 1 − M ) ) L_e = -(M log(M) + (1 - M)log(1 - M)) Le=(Mlog(M)+(1M)log(1M))


3.训练过程

整个训练过程分为三步:

  1. G = (V, E, W) 以及对应的标签 c 输入BrainNN网络,训练网络的各种参数,损失函数为 Lp

  2. 经初次训练后的BrainNN模型BrainNN的预测 y ^ \hat{y} y^G’ = (V, E, W’) 及其实际标签 c 输入到Explanation Generator网络中, 冻结BrainNN的所有参数,仅训练掩膜 M 及其它相关参数, 损失函数为 L

  3. 将训练好的掩膜 M 覆盖在原数据集上,得到 G’ = (V, E, W’)再次将其输入到BrainNN网络中,进一步训练BrainNN的参数(冻结M)


二、实验

1.可解释性分析

以下图片仅展示了权值大于阈值的边。

在这里插入图片描述

以下表格展示的是功能网络重要程度的top3,分别用了三种指标。

在这里插入图片描述


2.对比实验

在这里插入图片描述
相比于前两栏中的模型,BrainNN已经在指标上有了不错的提升。这说明作者提出的消息传递方法还是比较有效的,有利于获得更具代表性的全局表示向量。而BrainNNExplainer更为突出的表现,验证了掩膜的合理性(捕获了共同的脑网络模式)。


  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值