基于注意力机制和多任务学习的阿尔茨海默症分类
文章目录
一、背景介绍
阿尔茨海默症(AD)是一种随时间不断恶化的神经退化性疾病,一般症状为语言障碍、情绪不稳、记忆丧失等,又称早老性痴呆、老年痴呆。AD的临床前期被称为轻微认知障碍(MCI),是一种介于正常和失忆症的过渡状态,表现为轻度认知功能减退,研究表明遗忘型MCI更容易发展为AD。AD的病因迄今未明,目前对其的研究也未有较大突破,所以早期诊断是避免进入AD晚期、抑制病情发展的最有效方式之一。磁共振成像(MRI)是目前最常见的医学影像技术,它对人脑部组织无创伤且成像效果良好,可以很好地展示脑组织萎缩情况,已广泛用于AD诊断中。传统诊断是AD通过专业医师对患者脑组织的医学图像进行分析,然而实际情况中,医生需要对MRI图像各个方位的切片以及患者的表现进行检查,此过程会消耗大量时间且诊断结果具有主观性。对AD、MCI患者及健康对照(HC)的差异研究能减轻医师负担,有助于AD的早期预防。
二、主要内容
首先,利用改进的基础C3D网络,生成较粗糙的低级特征图;然后,将其分别输入至引入注意力机制的卷积块与普通卷积块中,注意力机制的卷积块关注MRI图像的结构特性,能获取特征图中不同像素位置特有的注意力权重,与普通卷积块输出的特征图对应相乘;最后,利用不同的全连接层来实现多任务学习,获得包含主分类任务在内的3种输出,另2种输出在训练过程中通过反向传播优化主分类任务,得到优化后的阿尔茨海默症分类结果。
三、模型介绍
本模型主要包含3个部分:基础网络模块、注意力模块、多任务学习模块。基础网络模块用于特征提取,共包含5个卷积块;注意力模块在基础网络中间嵌入,加强网络对图像有效信息的获取;多任务学习模块在基础网络末端引入,补充两个辅助任务的信息。
基于注意力机制和多任务学习的3D卷积网络
1 基础网络
搭建一个三维卷积神经网络用于AD分类。三维卷积网络用于人体动作分析与识别,相较于二维卷积,三维卷积可以更好地捕捉序列信息,保留更多的数据特征信息。
本文选用C3D网络作为基础网络,用于提取图像的特征。由于三维MRI图像计算量大,并且医学图像数据集与动作识别数据集差异明显,为了能更快速有效地将C3D应用到AD分类中,本文对其网络结构进行了改进。
改进C3D网络结构
基础网络保留了C3D的5个卷积,每个Block中含有一个池化层对图像进行下采样。实验证明,对于3D卷积网络,所有卷积层使用3×3×3的小卷积核会取得最佳效果,本文基础网络中的所有卷积层均使用3×3×3卷积核。除Block1中的池化层尺寸和步长设置为1×2×2,后续所有池化层的尺寸和步长均设置为2×2×2,因此若Block1的输入尺寸为(N,C,D,H,W),最终Block5输出大小为(N,C′,D/16,H/32,W/32),其中 N为批处理数量,C 和C′为通道数,D,H,W 为三维图像信息。由于医学图像数据集通常数量较少,因此批归一化层(BN)也适用于医学图像任务中,其作用是加速网络收敛,防止过拟合。本文于基础网络中添加BN层,在显著提升网络性能的同时能加快训练速度。Block5后利用全连接层将特征整合,输出不同类别的概率,然后将概率最大值对应的类别作为分类结果。
2 注意力模块
注意力模块
包括了卷积层、池化层、BN层以及限制输出范围的归一化层。
假设图1中的Block3输出为F3,将F3同时输入至注意力模块和Block4,注意力模块需要先经过一个卷积层。卷积层在训练时更新参数,会造成后续输入数据分布变化,因此需要BN层对网络中间数据做归一化处理。BN层通常置于网络激活函数前。引入BN层后的前向传导公式为:(1)
其中,f(·)表示线性整流函数,BN(·)表示BN层的归一化处理,FAtt_1表示上一个卷积的输出,FAtt_2表示经BN层和激活函数后的输出。下个卷积层的输出FAtt_3,将其输入至池化层。本文采用最大池化方式,取区域内的最大值作为输出,用以实现扩大感受野,降低分辨率。池化处理后经卷积层获得FAtt_4 ,然后将其传入归一化层。归一化层的作用是将上一层输出转为概率。注意力模块的输出与Block4输出的每个特征图对应相乘,从而实现注意力机制的引入。
3 多任务学习模块
本文采用基于参数硬共享机制的多任务学习来实现AD分类(深度学习中的多任务学习通常分为两类:参数软共享机制,参数硬共享机制。参数软共享机制每个任务都有自己的参数和模型;参数硬共享机制共享隐藏层的参数,只有特定输出层保留不同的参数。两者相比,参数硬共享机制参数量更少,能有效防止模型过拟合,因此更适用于医学图像数据量少的情况),其中主任务为AD分类,辅助子任务设置为临床痴呆评定量表(CDR)评分回归和和简易智力状态检查表(MMSE)评分回归。CDR是医生通过与患者及其家属交流获得信息,完成对患者认知受损程度的评估,最终整合成一个总分;MMSE是最具影响的认知缺损筛选工具之一。两者都与AD分类具有相关性,可以作为辅助子任务。
利用全连接层获得最终分类概率及CDR和MMSE评分,全连接层的计算公式为:
(2)
其中,xi表示输入的第i个特征图,wi表示全连接参数。共享的两个全连接层通道数依次为4096和2048,并额外添加dropout层,以一定概率舍弃神经元,防止过拟合,提高训练速度。主任务后续添加一个通道为2的全连接层,CDR和MMSE任务后续都添加两个全连接层,这3个任务同时训练,共享分支前的所有参数,根据反向传播更新模型参数,以主任务AD分类作为最终输出。
损失函数包含两部分:AD分类任务中真实标签与预测结果的分类损失、辅助任务中的CDR回归损失和MMSE回归损失。
(3)
其中,Lmain表示AD分类损失,LCDR表示CDR回归损失,LMMSE表示MMSE回归损失,α和λ为损失平衡系数。
四、总结
本文针对医学图像特征难以有效提取和AD分类的临床辅助信息资源浪费等难题,提出一种基于深度三维卷积神经网络的AD分类方法,其优势在于:1)该方法为端到端网络,不需要人工提取特征;2)在三维卷积神经网络的基础上引入注意力机制,保证重要特征信息绝不丢失的同时能获得AD分类任务中需要重点关注的目标区域;3)通过多任务学习引入临床CDR和MMSE评分,利用辅助信 息提升AD分类性能并提高其泛化能力。本文实验证明了所提模型能准确实现AD分类,相较于其他算法有更好的分类性能,促进了临床医学中AD和MCI患者的早诊早治,具有重要的现实意义。但是,本文HC实验表现较差,未来将针对这一问题,结合多尺度信息和引入对抗样本的方式进一步探索MCI患者图像的结构特征,同时优化分类损失函数,使其能在训练过程中适应不均衡样本分类任务。