《BrainNN》阅读笔记

《Joint Embedding of Structural and Functional Brain Networks with Graph Neural Networks for Mental Illness Diagnosis》


前言

之前先看的《BrainNNExplainer》,里面用到的消息传递网络BrainNN就源于本篇文章。这篇文章也是顶会文章,Individual level的,创新点在于同时用了结构像和功能像的MRI去预测疾病。

论文地址:https://arxiv.org/abs/2107.03220
代码地址:寻找中…


一、模型

在这里插入图片描述
结构像DTI以及功能像fMRI分别构成一个graph,通过参数共享的网络学习出各自的全局向量表示。根据聚合函数,将两个模态的全局向量表示融合得到被试的表示向量,用于下游任务。

1.定义

数据集: M = { ( { G i s , G i f } ) , y i } i S M = \{(\{G_{i}^s, G_{i}^f \}), y_{i}\}_{i}^S M={({Gis,Gif}),yi}iS 。其中S代表被试数量, G i s G_{i}^s Gis G i f G_{i}^f Gif分别代表第 i i i个被试的结构脑网络和功能脑网络。

脑网络: G i ∗ = ( V , E i ∗ , W i ∗ ) G_{i}^* = (V, E_{i}^*, W_{i}^*) Gi=(V,Ei,Wi) V = { v i } i = 1 N V = \{v_{i}\}_{i=1}^N V={vi}i=1N为节点集, E i ∗ = V × V E_{i}^* = V × V Ei=V×V是边集, W i ∗ ∈ R N × N W_{i}^* ∈ R^{N × N} WiRN×N则为加权邻接矩阵。

属性:  x n = [ d e g ( n ) ; m i n ( D n ) ; m a x ( D n ) ; m e a n ( D n ) ; s t d ( D n ) ] x_{n} = [deg(n); min(D_{n}); max(D_{n}); mean(D_{n}); std(D_{n})] xn=[deg(n);min(Dn);max(Dn);mean(Dn);std(Dn)] D n = { d e g ( m ) ∣ ( n , m ) ∈ e i j } D_{n} = \{deg(m) | (n,m)∈ e_{ij}\} Dn={deg(m)(n,m)eij}
     其中 d e g ( . ) deg(.) deg(.)为度, [ . , . ] [. , .] [.,.]为拼接操作。

标签:  y i y_{i} yi(患者或正常人)。


2.BrainNN

  1. 根据节点 i i i的嵌入表示、节点 j j j的嵌入表示以及它们之间的连接强度 w i j w_{ij} wij生成 i , j i, j i,j之间的消息向量 m i j m_{ij} mij

    m i j ( l ) = t Θ ( [ h i ( l ) ; h j ( l ) ; w i j ] ) m_{ij}^{(l)} = t_{\Theta}([h_{i}^{(l)}; h_{j}^{(l)}; w_{ij}]) mij(l)=tΘ([hi(l);hj(l);wij])

  1. 对于节点 i i i,聚合其所有邻居的消息向量并经非线性激活后得到其表示向量。

    h i ( l ) = σ ( ∑ j ∈ N i ∪ { i } m i j ( l − 1 ) ) h_{i}^{(l)} = \sigma(\sum_{j ∈ N_{i} ∪ \{i\}}m_{ij}^{(l-1)}) hi(l)=σ(jNi{i}mij(l1))

  2. Readout层聚合所有节点的嵌入表示,得到全局的表示向量 z ∈ R D z ∈ R^{D} zRD

    z ′ = ∑ i ∈ V h i ( k ) , z = t Φ ( z ′ ) + z ′ z'=\sum_{i∈V}h_{i}^{(k)}, z=t_{\Phi}(z') + z' z=iVhi(k),z=tΦ(z)+z

    这里的公式原文中应该是写错了。还有一点就是,这个全局表示向量 z z z仅仅是一个模态的(结构像或功能像)。文中取两个模态的全局向量 z i s , z i f z_{i}^s,z_{i}^f zis,zif的均值, z i = ( z i s + z i f ) / 2 z_{i}=(z_{i}^s+z_{i}^f)/2 zi=(zis+zif)/2,作为被试最后用于下游任务的表示向量。

3.损失函数

由于这个模型是有用到对比学习的方法的,所以在损失函数中除了有交叉熵损失之外,还引入了对比损失 J c o n J_{con} Jcon

     J c o n = 1 2 S ∑ G i ∈ M [ 1 N ∑ v i ∈ V ( I ( h i f ; z i s ) + I ( h i s ; z i f ) ) ] J_{con} = \frac{1}{2S}\sum_{G_{i}∈M}[\frac{1}{N}\sum_{v_{i}∈V}(I(h_{i}^f;z_{i}^s)+I(h_{i}^s;z_{i}^f))] Jcon=2S1GiM[N1viV(I(hif;zis)+I(his;zif))]

     I ( h i ; z i ) = − s p ( − d ( h i , z i ) ) − 1 N − 1 ∑ v j ∈ V ∖ { v i } s p ( d ( h i , z j ) ) I(h_{i};z_{i})=-sp(-d(h_{i},z_{i})) - \frac{1}{N-1}\sum_{v_{j}∈V\setminus\{v_{i}\}}sp(d(h_{i},z_{j})) I(hi;zi)=sp(d(hi,zi))N11vjV{vi}sp(d(hi,zj))

     s p ( x ) = l o g ( 1 + e x ) sp(x)=log(1+e^x) sp(x)=log(1+ex)

     d ( a , b ) = s i g m o i d ( < a , b > ) d(a,b)=sigmoid(<a,b>) d(a,b)=sigmoid(<a,b>)


二、实验

1.对比实验

在这里插入图片描述
这里的准确率和《BrainNNExplainer》中的准确率让我迷惑了。


2.消融实验

在这里插入图片描述
对比V-GCN是为了体现本文提出的消息传递网络的性能。

对比CONCAT是为了体现本文提出的多模态融合策略(取均值)的好处。


3.可视化

在这里插入图片描述



  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值