《Pooling Regularized Graph Neural Network for fMRI Biomarker Analysis》阅读笔记

前言

2020年MICCAI的论文,Individual-level的。

论文地址:https://arxiv.org/abs/2007.14589
代码地址:https://github.com/xxlya/PRGNN_fMRI


一、模型

在这里插入图片描述
模型由convolutional layerpooling layer以及readout layer组成。


1.定义

样本: G = ( V , ε ) G=(V, \varepsilon) G=(V,ε)

节点集: V = { v 1 , . . . , v N } V=\{v_1,...,v_N\} V={v1,...,vN} N N N为ROI的个数。

邻接矩阵: E = [ e i j ] ∈ R N × N E=[e_{ij}]∈R^{N×N} E=[eij]RN×N e i j = 0 , i f ( v i , v j ) ∉ ε e_{ij}=0, if (v_i,v_j)\notin\varepsilon eij=0,if(vi,vj)/ε

属性矩阵: H = [ h i , . . . , h N ] T H=[h_i,...,h_N]^T H=[hi,...,hN]T h i h_i hi表示节点 i i i的属性。

2.Convolutional Layer

在GAT的基础上,加入了边的属性。

  1. 计算节点对其邻居节点的注意力系数。注意与GAT的区别,GAT模型是没有 e i , j e_{i,j} ei,j的。

    在这里插入图片描述

  2. 对注意力系数归一化。

    在这里插入图片描述

  3. 更新节点 i i i的嵌入表示。

    在这里插入图片描述


3.Pooling Layer

(1).现存方法及其弊端

在graph上做池化的方法可以分为两种:

第一种,基于聚类的池化方法,根据图的拓扑结构把几个节点聚合成一个超级节点。在脑网络中,这种池化方法的可解释性不强。

第二种,基于重要性的池化方法,计算每个节点的重要程度并保存排名靠前的几个节点。然而,把gPool或SAGPool等现有的池化模型照搬到脑网络中,有以下几点缺陷:1. 被保留节点与被丢弃节点之间的区分度不大,不利于寻找有代表性的区域生物标志;2.同一类被试(患者或正常人)中,各节点的重要程度可能完全不同,不利于找到组水平的生物标志。

(2).本文池化方法及其创新

虽然基于节点重要程度的池化方法有上述这些缺点,但只要对其加以限制,还是能利用这类方法去找到生物标志的。因此,本文还是采用的基于重要性的方法:

在这里插入图片描述
SAGE pooling中函数与卷积层的计算方法相似,只不过 ϕ i θ ∈ R 1 × d ( l ) \phi_{i}^{\theta}\in R^{1×d^{(l)}} ϕiθR1×d(l),得到的结果是一个标量。

在得到各个节点的重要程度后,只保留前k个(或一定比例)节点
在这里插入图片描述
论文中只描述了池化后的邻接矩阵如何表示,并没有描述如何设计属性矩阵,等我看完代码再把这个部分补充进来。

上述两个池化公式,其实就是gPool和SAGPool中的公式。如果只是这样的话,是没有创新可言的。本文的创新点在于为池化部分设计了两个损失函数:Distance Loss用来加强被保留/丢弃的节点之间的差异;Group-level Consistency Loss则迫使同类graph的节点重要性排序相似。具体的函数表达式放到后面部分讲。


4.Readout Layer

尽管用了池化的方法,但最后并不是只剩下一个节点。因此,仍然需要从剩余的节点中“总结”出全局表示向量。

在这里插入图片描述

class NNGAT_Net(torch.nn.Module):
    def __init__(self, ratio, indim, poolmethod = 'topk'):
        super(NNGAT_Net, self).__init__()
        self.dim1 = 32
        self.dim2 = 32
        self.dim3 = 8
        self.indim = indim
        self.poolmethod = poolmethod
		# 这里直接调用torch_geometric中的GAT,应该是没用到边的属性的。
		# 即没有用Convoluton Layer中的第一个公式,而是用了GAT原本的更新公式
        self.conv1 = GATConv( self.indim, self.dim1)
        self.bn1 = torch.nn.BatchNorm1d(self.dim1)
        if self.poolmethod == 'topk':
            self.pool1 = TopKPooling(self.dim1, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)
        elif self.poolmethod == 'sag':
            self.pool1 = SAGPooling(self.dim1, ratio=ratio, GNN=GATConv,nonlinearity=torch.sigmoid) #0.4 data1 10 fold

        self.conv2 = GATConv(self.dim1, self.dim2)
        self.bn2 = torch.nn.BatchNorm1d(self.dim2)
        if self.poolmethod == 'topk':
            self.pool2 = TopKPooling(self.dim2, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)
        elif self.poolmethod == 'sag':
            self.pool2 = SAGPooling(self.dim2, ratio=ratio, GNN=GATConv,nonlinearity=torch.sigmoid)

		# 原文中是在最后一层池化后用mean来获得全局向量
		# 这里用了mean||max,并且用skip-connection拼接了所有池化层后的向量表示,与原文有出入。
        self.fc1 = torch.nn.Linear((self.dim1+self.dim2)*2, self.dim2)
        self.bn4 = torch.nn.BatchNorm1d(self.dim2)
        self.fc2 = torch.nn.Linear(self.dim2, self.dim3)
        self.bn5 = torch.nn.BatchNorm1d(self.dim3)
        self.fc3 = torch.nn.Linear(self.dim3, 2)

    def forward(self, x, edge_index, batch, edge_attr):
        # edge_attr = edge_attr.squeeze()  
        # edge_index, edge_attr = self.augment_adj(edge_index, edge_attr, x.size(0))
        x = self.conv1(x, edge_index)
        if x.norm(p=2, dim=-1).min() == 0:
            print('x is zeros')
        x, edge_index, edge_attr, batch, perm, score1 = self.pool1(x, edge_index, edge_attr, batch)
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        edge_attr = edge_attr.squeeze()
        edge_index, edge_attr = self.augment_adj(edge_index, edge_attr, x.size(0))

        x = self.conv2(x, edge_index)
        x, edge_index, edge_attr, batch, perm, score2  = self.pool2(x, edge_index, edge_attr, batch)
        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        x = torch.cat([x1,x2], dim=1) #concate

        x = self.bn4(F.relu(self.fc1(x)))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.bn5(F.relu(self.fc2(x)))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.log_softmax(self.fc3(x), dim=-1)

        return x, score1, score2

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)

        return edge_index, edge_weight

5.Loss Function

(1).Distance Loss

引入这个损失的目的是为了增强被保留节点与被丢弃节点间的区分度

s ^ m ( l ) = [ s ^ m , 1 ( l ) , . . . , s ^ m , N ( l ) ] \hat{s}_{m}^{(l)}=[\hat{s}_{m,1}^{(l)}, ... , \hat{s}_{m,N}^{(l)}] s^m(l)=[s^m,1(l),...,s^m,N(l)] 代表第 m m m个被试的各节点在第 l l l层的重要程度(降序排列)。

记前 k ( l ) k^{(l)} k(l)个节点为 a m , i ( l ) = s ^ m , i ( l ) , i = 1 , . . . , k ( l ) a_{m,i}^{(l)}=\hat{s}_{m,i}^{(l)}, i=1, ... , k^{(l)} am,i(l)=s^m,i(l),i=1,...,k(l)

其余节点则记为 b m , j ( l ) = s ^ m , j + k ( l ) ( l ) , j = 1 , . . . , N ( l ) − k ( l ) b_{m,j}^{(l)}=\hat{s}_{m,j+k^{(l)}}^{(l)}, j =1, ... , N^{(l)}-k^{(l)} bm,j(l)=s^m,j+k(l)(l),j=1,...,N(l)k(l)

1). MDD Loss

引用了GAN中的损失函数。

在这里插入图片描述
在这里插入图片描述

if opt.distL == 'mmd':
	mmd = MDD_loss()
	def dist_loss(s, ratio):
		s = s.sort(dim=1).values
		# 这里和原文有出入。如果ratio > 0.5,那么source和target就会有重复的节点。
		# 按论文的意思应该是没有的重复的节点,而且source和target第二维的维度可以不相同。
		source = s[:,-int(s.size(1)*ratio):]
		target = s[:,:int(s.size(1)*ratio)]
		res = mmd(source, target)
		return -res

class MDD_loss(nn.Module):
   def __init__(self, kernel_mul=2.0, kernel_num=5):
       super(MMD_loss, self).__init__()
       self.kernel_num = kernel_num
       self.kernel_mul = kernel_mul
       self.fix_sigma = None

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=1, fix_sigma=None):
        n_samples = int(source.size()[0]) + int(target.size()[0])
		# 假设batch_size=1,s=[s_1, s_2, ... ,s_n]
		# source = [s_1, s_2, ..., s_k]; target = [s_n-k, s_n-k+1, ..., s_n]
		# 记source = [a_1, a_2, ..., a_k]; target = [b_1, b_2, ..., b_k]

		# total = [[a_1, a_2, ..., a_k], [b_1, b_2, ..., b_k]]
        total = torch.cat([source, target], dim=0)
        # toal0 = [[[a_1, a_2, ..., a_k], [b_1, b_2, ..., b_k]], [[a_1, a_2, ..., a_k], [b_1, b_2, ..., b_k]]]
        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        # total1 = [[[a_1, a_2, ..., a_k],[a_1, a_2, ..., a_k]],[[b_1, b_2, ..., b_k],[b_1, b_2, ..., b_k]]]
        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        # L2_distance = [[aa, ab], [ba, bb]]
        L2_distance = ((total0 - total1) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        # 论文里公式写错了
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

	def forward(self, source, target):
		batch_size = int(source.size()[0])
		kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel.mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
		
		XX = kernels[:batch_size, :batch_size]   # 被保留节点之间的MDD
		YY = kernels[batch_size:, batch_size:]   # 被丢弃节点之间的MDD
		XY = kernels[:batch_size, batch_size:]   # 被保留/丢弃节点之间的MDD
		YX = kernels[batch_size:, :batch_size]   # 被丢弃/保留节点之间的MDD
		loss = torch.mean(XX + YY - XY - YX)
		# 和原文的公式有出入
		return loss

2).BCE Loss

理想状态下,被保留节点的得分应该尽可能接近于1,而被舍弃节点的得分接近于0。因此也可以用二值交叉熵作为损失函数。

在这里插入图片描述

if opt.distL == 'bce':
	def dist_loss(s, ratio):
		# s.shape = batch_size x num_nodes
		if ratio > 0.5:
			ratio = 1 - ratio
		s = s.sort(dim=1).values
		# 这里的计算与原文有出入。这里是取相同数量的排名靠前及靠后的节点计算损失。
		res = -torch.log(s[:, -int(s.size(1)*ratio):]+EPS).mean() - torch.log(1-s[:,:int(s.size(1)*ratio)]+EPS).mean()
		return res

(2).Group-level Consistency Loss

前面提到了,不同被试的节点重要性排序可能截然不同,这是不利于寻找组水平的生物标记的。该损失函数能迫使不同被试的 s ( 1 ) s^{(1)} s(1)尽可能相似(注意,这是只要求第一层的排序相似。因为经过一次池化后,不同被试所保留的节点不一定相同)。

在这里插入图片描述
上述公式代表每一类的损失。其中 L c = D c − W c L_c=D_c-W_c Lc=DcWc D c D_c Dc为对角阵,对角线元素的值均为 M c M_c Mc W c W_c Wc则为全1矩阵。

def consist_loss(s):
	if len(s) == 0:
		ruturn 0
	else:
		s = torch.sigmoid(s)
		# W: 全1矩阵
		W = torch.ones(s.shape[0], s.shape[0])
		# D: 对角阵,对角元素为一个batch中属于该类的被试个数
		D = torch.eye(s.shape[0]) * torch.sum(W, dim=1)
		L = D - W
		L = L.to(device)
		res = torch.trace(torch.transpose(s,0,1) @ L @ s) / (s.shape[0] * s.shape[0])
		return res


再加上分类损失 L c e L_{ce} Lce,最终的损失函数为:
在这里插入图片描述



二、实验结果

1.数据预处理

这篇论文的数据预处理方式和其他论文的处理方法有点不同。通常,会将ROI内所有体素的时间序列的均值作为该区域的时间序列。但这篇文章仅取了1/3的体素,并且用这种方法做了10倍的数据扩充(即对每个ROI随机取10次1/3的体素)。

同时,本篇文章构建graph的方法也有所不同。节点间的连接强度用的是偏相关系数(一般用的是皮尔逊系数),并且只保留前10%的边。如果graph不连通,则按照偏相关系数的降序,继续添加边直到连通

节点的属性则是用节点间的皮尔逊相关系数表示

def read_sigle_data(data_dir, filename):
	temp = h5py.File(osp.join(data_dir, filename), 'r')
	
	# read edge and edge attribute
	pcorr = np.abs(temp['pcorr'].value)
	# 这里应该是2019年MICCAI文章的代码,和这篇文章保留前10%不符。
	th = np.percentile(pcorr.reshape(-1), 95)
	pcorr[pcorr < th] = 0
	num_nodes = pcorr.shape[0]
	
	G = from_numpy_matrix(pcorr)
	A = nx.to_scipy_sparse_matrix(G)
	adj = A.tocoo()
	# len(adj.row): 边的数量
	edge_att = np.zeros((len(adj.row)))
	for i in range(len(adj.row)):     
		edge_att[i] = pcorr[adj.row[i], adj.col[i]]
	edge_index = np.stack([adj.row, adj.col])
	edge_index, edge_att = remove_self_loops(torch.from_numpy(edge_index).long(), torch.from_numpy(edge_att).float())
	edge_index, edge_att = coalesce(edge_index, edge_att, num_nodes, num_nodes)

	# 节点属性,皮尔逊系数矩阵
	att = temp['corr'].value

	return edge_att.data.numpy(), edge_index.data.numpy(), att, temp['indicator'].value, num_nodes

2.消融实验

在这里插入图片描述

λ 2 = 0 \lambda_2=0 λ2=0时,模型会在训练集上过拟合;而 λ 2 \lambda_2 λ2太大时(> 0.1),则会在欠拟合。

在这里插入图片描述
随着迭代的进行,节点的重要程度有了较为明显的分化(除了用SAGE pooling的)。被保留节点的得分逐渐趋于1,而被舍弃节点的得分则趋近于0。


3.对比实验

在这里插入图片描述
机器学习方法的输入为皮尔逊相关系数矩阵的上三角阵;BrainNetCNN的输入则是皮尔逊相关系数矩阵;GNN方法的输入则与PR-GNN相同。


4.可解释性

在这里插入图片描述
λ 2 \lambda_2 λ2设置的较小时(图a),模型更倾向于寻找个人的生物标志,被试间重叠区域不多;而 λ 2 \lambda_2 λ2较大时(图b-c),则更倾向于寻找组水平的生物标志,并且随着 λ 2 \lambda_2 λ2的增大,重叠区域增多。


  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: 卷积神经网络 (Convolutional Neural Network, CNN) 是一种深度学习网络架构, 它通过在网络中使用卷积层来提取图像的特征, 常用于图像分类, 对象检测, 图像生成和自然语言处理等应用中. 卷积层通过使用不同大小的滤波器 (Filter) 在图像上做卷积运算, 从而能够提取出图像中不同维度的特征. 卷积神经网络通常由多个卷积层和池化层 (Pooling Layer) 组成, 并在最后几层使用全连接层 (Fully Connected Layer) 来进行分类或回归. ### 回答2: 卷积神经网络(Convolutional Neural Network,简称CNN)是一类专门用于图像处理和模式识别的深度学习模型。它模拟人脑视觉处理方式,通过层层的卷积操作和池化操作来提取图像的特征。 CNN的主要特点是具有共享权重和局部感知野。共享权重指的是网络中权重参数在卷积过程中是共享的,这使得CNN对于图像的位置变化具备一定的鲁棒性。局部感知野指的是网络仅对局部区域进行感知和处理,而不是整张图像,以此降低网络的复杂度。 CNN由多个卷积层、池化层和全连接层组成。卷积层通过卷积操作将输入图像与一系列卷积核(权重矩阵)进行卷积操作,得到不同的特征图。池化层则通过对特征图进行下采样,减少特征图的维度,同时保留重要的特征。全连接层则将池化后的特征图与分类器相连,完成最终的分类任务。 CNN训练过程中通常使用反向传播算法和梯度下降方法来更新网络参数。在大规模神经网络训练中,还可以使用随机梯度下降法或者自适应学习率的方法以提高训练速度和收敛性。 CNN在计算机视觉领域取得了许多重要的突破,例如图像分类、物体检测、人脸识别等任务上的优良表现。它的成功主要归功于其对于图像特征的自动学习能力以及对于局部结构和空间关系的有效建模能力。通过深度学习的训练和迁移学习,CNN在不同领域中都有着广泛的应用和研究价值。 ### 回答3: 卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,以其在图像分类、目标检测和语义分割等任务上的出色表现而备受瞩目。CNN的核心思想是通过多个卷积层、池化层和全连接层构建网络,以学习从原始输入数据中提取有用特征,并输出对输入的分类或回归结果。 CNN的主要特点有以下几个方面。首先,它采用卷积层与池化层的交替组合,使网络具备对输入数据的位置信息具有不变性。卷积层通过滑动不同大小的卷积核在输入数据上提取特征,而池化层则通过降低特征图的尺寸以减少计算量,并保留关键特征。其次,CNN通过权值共享使得网络在处理不同位置的输入时具有参数数量共享,从而减少了模型的复杂性,提高了计算效率。此外,通过使用非线性激活函数(例如ReLU)和批量归一化等技术,CNN可以克服非线性问题,提高网络的非线性拟合能力。最后,CNN具有自学习的能力,通过反向传播算法可以有效地调整网络参数,使得模型能够从数据中学习到更具有判别性的特征表示。 CNN在计算机视觉领域的广泛应用表明了其强大的特征提取和模式识别能力。通过利用卷积层的滤波器,CNN可以学习到不同大小的特征,从而在图像分类和目标检测等任务中取得出色的性能。此外,CNN还可以用于语义分割,即将图像中的每个像素分类到不同的语义类别中,从而实现更细粒度的图像分析。由于CNN在特征学习和模式识别方面的卓越能力,它已成为许多计算机视觉任务的首选模型,并在自然语言处理和推荐系统等领域也有应用潜力。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值