[Paper close reading] BrainTGL: A dynamic graph representation learning model for brain network analysis

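Paper link: BrainTGL: A dynamic graph representation learning model for brain network analysis - ScienceDirect

The English is typed entirely by hand! It is a summary and paraphrase of the original paper. Spelling and grammar mistakes may be unavoidable; if you spot any, corrections in the comments are welcome! This post leans toward personal notes, so read with caution.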

Contents

1. Takeaways

2. Paragraph-by-paragraph close reading

2.1. Abstract

2.2. Introduction

2.3. Related work

2.3.1. The static brain network analysis methods

2.3.2. The dynamic brain network analysis methods

2.4. Method

2.4.1. Problem statement

2.4.2. Dynamic brain network series construction

2.4.3. Data augmentation

2.4.4. Attention based graph pooling

2.4.5. Dual temporal graph learning

2.4.6. Ensemble

2.4.7. Variation of braintgl for unsupervised clustering

2.4.8. Theoretical contribution

2.5. Experiments

2.5.1. Datasets and environment

2.5.2. Comparison with the prior works on the brain network classification

2.5.3. Ablation study

2.5.4. Discussion

2.5.5. Disease subtype clustering analysis

2.6. Limitations and future works

2.7. Conclusion

3. Reference


1. Takeaways

(1) Not a very difficult paper

(2) So many experiments~~ sigh

2. Paragraph-by-paragraph close reading

2.1. Abstract

        ①Existing works do not consider the relationships between spatial and temporal characteristics in the brain

        ②They proposed a temporal graph representation learning framework for brain networks (BrainTGL)

2.2. Introduction

        ①They proposed a dual temporal graph learning (DTGL) module

2.3. Related work

        ①Methods of brain network analysis:

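(The authors say little inside the figure itself and instead attach a long explanation beneath it... I find this somewhat inappropriate, because it makes the figure hard to read on its own. The authors regard (a) as the most traditional static FC analysis; (b) is also relatively static, and my guess is that it simply slides a window over the BOLD signal, which the authors feel does not account for "dynamic spatial dependency"; the two-stage learning in (c) separates space from time; (d) is the authors' method.)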

2.3.1. The static brain network analysis methods

        ①Lists several static methods and notes that they cannot learn advanced features

2.3.2. The dynamic brain network analysis methods

        ①Gives examples of some dynamic models and explains that they cannot model the two types of network at the same time

2.4. Method

2.4.1. Problem statement

        ①Time series data: X=\left\{x_{1}^{(i)},\quad x_{2}^{(i)},\ldots,x_{n}^{(i)}\right\}_{i=1}^{N}, where n denotes the number of ROIs and N the number of subjects

        ②Label: Y=\{y_{1},y_{2},\ldots,y_{N}\} and y_{i}\in\{-1,1\}

        ③Dataset: D=\left\{G_t^{(1)},G_t^{(2)},\ldots,G_t^{(N)}\right\}_{t=1}^T

        ④Graph: G_{t}=\{V,X_{t},A_{t}\}

        ⑤Atlas: CC200

        ⑥Adjacency matrix: calculated by Pearson correlation coefficient (PCC)

        ⑦Mapping function: f:\{G_{t}\}_{t=1}^{T}\to Y

        ⑧Overall framework:

2.4.2. Dynamic brain network series construction

        ①Brain graph construction:
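As a rough sketch of how a dynamic brain network series G_t=\{V,X_t,A_t\} could be built from one subject's BOLD signals: slide a window over the time axis and compute the Pearson correlation matrix inside each window. The function name, the window length and the stride below are hypothetical values chosen for illustration, not the paper's settings, and using the windowed signals as node features X_t is my assumption.

```python
import numpy as np

def dynamic_brain_networks(bold, window, stride):
    """Turn a BOLD array of shape (n_rois, n_timepoints) into a list of
    (X_t, A_t) pairs, one pair per sliding window."""
    n_rois, n_time = bold.shape
    graphs = []
    for start in range(0, n_time - window + 1, stride):
        x_t = bold[:, start:start + window]  # assumed node features: windowed signals
        a_t = np.corrcoef(x_t)               # Pearson correlation between ROIs
        graphs.append((x_t, a_t))
    return graphs

rng = np.random.default_rng(0)
bold = rng.standard_normal((200, 176))       # synthetic data: 200 ROIs, 176 time points
series = dynamic_brain_networks(bold, window=40, stride=20)
print(len(series), series[0][1].shape)
```

Each A_t is symmetric with ones on the diagonal; whether the paper thresholds or sparsifies it afterwards is not recorded in these notes.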

2.4.3. Data augmentation

        ①Each BOLD signal is cropped to the same length and then divided into shorter segments:
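A minimal sketch of the crop-then-divide idea, under my own assumptions: the crop keeps the first crop_len time points, and the division uses overlapping segments of fixed length. The function name and all numbers are hypothetical; the notes do not record the paper's actual segment length or overlap.

```python
import numpy as np

def crop_and_divide(bold, crop_len, seg_len, seg_stride):
    """Crop a (n_rois, n_timepoints) BOLD array to crop_len time points,
    then cut it into (possibly overlapping) segments of seg_len points."""
    if bold.shape[1] < crop_len:
        raise ValueError("signal is shorter than the crop length")
    cropped = bold[:, :crop_len]             # assumption: keep the first crop_len points
    starts = range(0, crop_len - seg_len + 1, seg_stride)
    return [cropped[:, s:s + seg_len] for s in starts]

rng = np.random.default_rng(0)
bold = rng.standard_normal((200, 230))       # synthetic subject with 230 time points
segments = crop_and_divide(bold, crop_len=176, seg_len=128, seg_stride=16)
print(len(segments), segments[0].shape)
```

Every segment would inherit the label of the subject it came from, which is what makes this an augmentation step.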

2.4.4. Attention based graph pooling

        ①The original graph is then coarsened to \widehat{G}=\left\{\widehat{V},\widehat{A}\right\} with super nodes \widehat{V}=\{SN_{1},SN_{2},\ldots,SN_{c}\} and adjacency matrix \widehat{A}=F^TAF\in\mathbb{R}^{c\times c}

        ②They define a learnable parameter matrix F\in\mathbb{R}^{n\times c}:

F_{ij}=\left\{\begin{array}{ll}s_i,&i\in SN_j\\0,&i\not\in SN_j\end{array}\right.

where s_i denotes the importance score of each node

        ③Schematic:

        ④Values of super edge e_{ij} in adjacency matrix: s_{i}*w_{ij}*s_{j}
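To make the coarsening concrete, here is a small numerical sketch of \widehat{A}=F^TAF. In the paper the scores s_i are learned; here the scores and the node-to-supernode assignment are fixed toy values that I made up, and the function name is hypothetical.

```python
import numpy as np

def coarsen(adj, scores, assignment, n_super):
    """Build F (n x c) with F[i, j] = s_i if node i belongs to SN_j, else 0,
    and return F together with the coarsened adjacency F^T A F."""
    n = adj.shape[0]
    F = np.zeros((n, n_super))
    F[np.arange(n), assignment] = scores
    return F, F.T @ adj @ F

adj = np.array([[0.0, 0.5, 0.2, 0.0],
                [0.5, 0.0, 0.0, 0.3],
                [0.2, 0.0, 0.0, 0.4],
                [0.0, 0.3, 0.4, 0.0]])
scores = np.array([0.9, 0.1, 0.6, 0.4])      # toy importance scores s_i
assignment = np.array([0, 0, 1, 1])          # nodes 0,1 -> SN_1; nodes 2,3 -> SN_2
F, adj_hat = coarsen(adj, scores, assignment, n_super=2)
print(adj_hat)
```

The off-diagonal entry of adj_hat is the sum of s_i*w_{ij}*s_j over all original edges that cross the two supernodes, which matches item ④: here 0.9*0.2*0.6 + 0.1*0.3*0.4 = 0.12.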

2.4.5. Dual temporal graph learning

        ①The sketch map:

        ②Signal representation learning (S-RL) module:

e^{(l+1)}\left(u\right)=\sum_{s=0}^{U-1}e^{(l)}\left(u-s\right)*\mathcal{K}^{(l)}\left(s\right)

where \mathcal{K}^{(l)} denotes the convolutional kernel in the l-th layer, U denotes the kernel size and u indexes the elements of the BOLD signal

        ③Features of a supernode: obtained by max pooling over the features of its original nodes

        ④Proposed a temporal graph representation learning (TG-RL) module

        ⑤They designed a multi-skip scheme to capture long skip and short skip information:

where the input, hidden state and cell memory are all graph-structured; U, H and C are the modulated input, the hidden state and the cell memory, respectively

        ⑥Output in the graph convolution:

E_t^{\mathcal{G}^{(l+1)}}=\mathbf{Gconv}\left(\widehat{G}_t\right)=\mathbf{relu}\left(\widehat{A}_tE_t^{\mathcal{G}^{(l)}}W_G^{(l)}\right)

        ⑦The final embedding:

\widehat{H}_T^C=\sum_{p=1}^P\sum_{i=T}^{T-p+1}W_i^{(p)}H_i^{(p)}+b
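Two of the formulas in this subsection are easy to check numerically: the S-RL convolution in ② and the graph convolution in ⑥. The sketch below uses plain NumPy with a single channel and random weights. Zero padding for u-s<0 is my assumption, the function names are hypothetical, and the multi-skip LSTM part is not reproduced here.

```python
import numpy as np

def temporal_conv(e, kernel):
    """e_next(u) = sum_{s=0}^{U-1} e(u - s) * K(s), treating e(u - s) as 0 when u - s < 0."""
    U = len(kernel)
    padded = np.concatenate([np.zeros(U - 1), e])
    return np.array([
        sum(padded[u + U - 1 - s] * kernel[s] for s in range(U))
        for u in range(len(e))
    ])

def gconv(adj_hat, emb, weight):
    """One graph convolution step: relu(A_hat @ E @ W)."""
    return np.maximum(adj_hat @ emb @ weight, 0.0)

signal = np.array([1.0, 2.0, 3.0, 4.0])
kernel = np.array([0.5, 0.25])
conv_out = temporal_conv(signal, kernel)

rng = np.random.default_rng(0)
adj_hat = rng.random((5, 5))                 # coarsened adjacency for 5 supernodes
emb = rng.standard_normal((5, 8))            # supernode embeddings E^{(l)}
weight = rng.standard_normal((8, 4))         # W_G^{(l)}
gconv_out = gconv(adj_hat, emb, weight)
print(conv_out, gconv_out.shape)
```

With this padding the convolution is causal: each output element depends only on the current and earlier elements of the signal.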

2.4.6. Ensemble

        ①Hyperparameter optimization: multi-time window ensemble strategy

        ②The ensembling process:
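These notes do not record how the outputs for different time windows are combined, so the following is only one plausible reading: train one model per window size and average their predicted class probabilities (soft voting). The averaging rule, the function name and the toy numbers are all my assumptions.

```python
import numpy as np

def ensemble_predict(prob_per_window):
    """Average class probabilities from models trained with different window
    sizes and return (predicted labels, averaged probabilities)."""
    stacked = np.stack(prob_per_window)      # (n_models, n_subjects, n_classes)
    mean_prob = stacked.mean(axis=0)
    return mean_prob.argmax(axis=1), mean_prob

# Toy outputs of three hypothetical models for two subjects and two classes.
probs_w1 = np.array([[0.6, 0.4], [0.3, 0.7]])
probs_w2 = np.array([[0.4, 0.6], [0.2, 0.8]])
probs_w3 = np.array([[0.7, 0.3], [0.6, 0.4]])
labels, mean_prob = ensemble_predict([probs_w1, probs_w2, probs_w3])
print(labels, mean_prob)
```

The appeal of such a scheme is that the window size no longer has to be tuned to a single value, which fits the note that the ensemble serves as hyperparameter optimization.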

2.4.7. Variation of braintgl for unsupervised clustering

        ①Framework of unsupervised BrainTGL (BrainTGL-C):

They initialize all the pseudo labels in every iteration

2.4.8. Theoretical contribution

(1)Feature learning for spatio-temporal data

(2)Graph structure learning for the graph data with complex structure

2.5. Experiments

2.5.1. Datasets and environment

(1)Datasets

        ①ABIDE: 871 subjects (403 ASD and 468 HC) remain after preprocessing. Subjects whose BOLD signal length falls outside [176,250] are further excluded, leaving 512 subjects

        ②HCP: subjects with fewer than 1200 frames are excluded, leaving 1091 subjects with 498 female and 593 male (22 regions of cortical surface)

        ③NMU: NMU MDD with 246 HC and 181 MDD, NMU BD with 246 HC and 146 BD, using the AAL 90 atlas

        ④Hardware environment:

2.5.2. Comparison with the prior works on the brain network classification

        ①Cross validation: 5 fold

        ②Comparison table on HCP and ABIDE:

        ③Comparison table on NMU:

        ④ROC:

2.5.3. Ablation study

        ①Module ablation on ABIDE:

        ②Module ablation on HCP:

2.5.4. Discussion

(1)The impact of attention graph pooling

        ①Result of t-SNE:

(2)The impact of the supernode number

        ①Number of supernodes:

(3)Comparison of pooling methods

        ①Comparison with other pooling methods on ABIDE:

(4)The impact of different skip lengths in TG-RL

        ①Skip ablation:

(5)The model complexity

        ①FLOPs and training time comparison:

2.5.5. Disease subtype clustering analysis

        ①Subtype:

        ②Subtype classification on MDD:

        ③Subtype on BD:

2.6. Limitations and future works

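        ①The authors think information is lost during coarsening. But that goes without saying, since coarsening is bound to lose information. One could keep the original information with a lower weight, or fuse in some kind of simultaneous refinement, and so on...

        ②The authors think long-range dependencies are not captured. Surely that is not the researchers' fault??? The datasets only contain a few minutes of BOLD signal to begin with, so nobody could capture long-range dependencies; better to just work with longitudinal data

        ③They think more HPO methods could be explored. Fair enough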

2.7. Conclusion

        ~

3. Reference

Liu, L. et al. (2023) 'BrainTGL: A dynamic graph representation learning model for brain network analysis', Computers in Biology and Medicine, 153.