An Attention-Based Deep Learning Approach for Sleep Stage Classification With Single-Channel EEG
文献来源
https://ieeexplore.ieee.org/abstract/document/9417097
DOI: 10.1109/TNSRE.2021.3076234
IEEE TRANSACTIONS ON NEURAL SYSTEMS AND REHABILITATION ENGINEERING, VOL. 29, 2021
作者:Emadeldeen Eldele , Zhenghua Chen , SeniorMember,IEEE, Chengyu Liu , SeniorMember,IEEE, Min Wu , SeniorMember,IEEE, Chee-Keong Kwoh , Xiaoli Li , SeniorMember,IEEE, and Cuntai Guan ,Fellow,IEEE
一、摘要
方法: AttenSleep 基于注意力的深度学习架构
目的: 从单通道EEG信号中进行睡眠阶段分类
架构: 从基于多分辨率卷积神经网络( MRCNN )和自适应特征重标定( AFR )的特征提取模块入手。MRCNN可以提取低频和高频特征,而AFR可以通过建模特征之间的相互依赖关系来提高提取特征的质量。第二个模块是时间上下文编码器( TCE ),它利用多头注意力机制来捕获提取特征之间的时间依赖关系。特别地,多头注意力利用因果卷积对输入特征中的时间关系进行建模。
验证: 使用三个公共数据集来评估提出的AttnSleep模型的性能。结果表明,AttnSleep性能优于当前最先进的技术。
二、方法
三个模块:特征提取、时间上下文编码器、分类
- 特征提取: 多分辨率卷积神经网络MRCNN:小核卷积提取低频信号,宽核卷积提取高频信号;
自适应特征重标定AFR:选择突出最重要的特征,增强特征分类性能。 - 时间上下文编码器TCE: 捕获输入特征间的时间依赖关系,核心是因果卷积的多头注意力。
- 分类: 具有softmax激活函数的全连接层,利用类别敏感损失函数处理数据不平衡。
特征提取
信号输入:30sEEG信号中提取
信号处理:
- 多分辨率卷积神经网络MRCNN: 不同睡眠阶段有不同的频率范围。不同的核大小,捕获不同范围的时间步长,从不同的睡眠相关频段中寻址特征。小核卷积提取高频信号,宽核卷积提取低频信号;
该模块包含2个分支,每个分支都包含3个卷积层、2个最大池化层,卷积层包含了归一化层、高斯误差线性单元GELU作为激活函数。dropout用于减少过拟合。 - 自适应特征重标定AFR: 重新校准MRCNN学习到的特征。建模了特征之间 的相互依赖关系,通过残差压缩和激励模块自适应选择最具判别力的特征。
时态上下文编码器TCE
作用: 捕获输入特征中的时间依赖关系
组成: TCE由多头注意力MHA层、归一化层和2个FC层组成。将两个完全相同的TCE结构进行堆叠,生成最终的特征。
自注意力机制: 量化输入特征之间的相互依赖关系。根据输入中的每个位置为感兴趣的区域分配较高的权重。
- MHA利用因果卷积编码输入特征的位置信息并捕获它们的时间关系。因果卷积具有快速并行处理的优势,与RNNs相比显著减少了模型训练时间。接下来我们说明了MHA是如何在AttnSleep模型中工作的。
多头注意力MHA的输入:3个重复的自适应特征重标定AFR的输出
因果卷积,通过三个矩阵计算注意力,将注意力扩展到H头上的每一个,每个矩阵被拆分为H个子空间;在每个子空间计算注意力,将H个表示拼在一起。 - 添加归一层Add and Normalize Layer:将上一层的输出通过残差丽娜姐添加到该层的输入,对和进行归一化处理;残差连接有助于模型利用低层特征,归一化有助于加快训练速度。
- 前馈层Feed-Forward Layer:两个FC层的组合。使用ReLU激活函数来打破模型中的非线性并考虑潜在维度之间的相互作用。
类感知损失函数及优化
利用加权交叉熵损失函数
三、实验结果
数据集: Sleep-EDF-20, Sleep-EDF-78, Sleep Heart Health Study (SHHS) 。对每个数据集只使用单通道EEG数据。
- Sleep - EDF - 20包含20名受试者,而Sleep - EDF - 78是78名受试者的扩展版本。被试共参与了两个研究。第一种是睡眠卡片,研究年龄对睡眠的影响,研究对象为25 ~ 101岁的健康被试。第二个是睡眠遥测,该文件研究了22名高加索男性和女性在没有任何其他药物的情况下,使用替西泮对睡眠的影响。本文使用第一种研究睡眠卡片中的数据,使用单个Fpz - Cz通道作为各种模型的输入。
- SHHS是一项关于睡眠呼吸障碍的心血管和其他后果的多中心队列研究。受试者患有多种疾病,为尽量减少这些疾病的影响,我们沿用了[ 40 ]中的研究来选择受试者,他们被认为有规律的睡眠(例如,呼吸暂停低通气指数( Apnea Hypopnea Index , AHI )小于5)。最终,从6441名被试中选取了329名被试进行实验。
预处理:
- 排除了任何不属于睡眠阶段的UNKNOWN阶段。
- 根据AASM标准,我们将N3和N4合并为一个阶段( N3 )。
- 只包括睡眠前后30分钟的觉醒期,以增加对睡眠阶段的更多关注。
评估参数:
准确率( ACC )、宏观平均F1值( MF1 )、Cohen Kappa ( κ ) 、宏观平均G - mean ( MGm )
MF1和MGm都是评价模型在不平衡数据集上性能的常用指标[ 42 ]。给定第i类的真阳性( TPi )、假阳性( FPi )、真阴性( TNi )和假阴性( FNi ),总体准确率ACC、MF1和MGm定义如下。
也用每类准确率( PR )、每类召回率( RE )、每类F1值( F1 )和每类G - mean ( GM )来评估每个模型。
模型评分性能
表II、III、IV分别给出了本文模型在Sleep - EDF数据集的Fpz - Cz通道和SHHS数据集的C4 - A1通道上的混淆矩阵。混淆矩阵的计算是将测试数据的所有评分值通过20折的方式相加得到的。每一行代表专家分类的样本数,每一列代表模型预测的历元数。
值得注意的是,阶段N1达到了最低的性能,F1低于50 %,在这里它经常被误分为W,REM和N2类。相比之下,N3类在Sleep - EDF - 20和SHHS数据集上的表现最好,但在Sleep - EDF - 78数据集上的表现有所下降,因为它是该数据集上的少数类。在不同的数据集中,大多数的错误分类都是N2类,因为它是大多数类。
基线模型和实验配置
DeepSleepNet:采用自定义的CNN结构,然后使用带有残差连接的LSTM进行睡眠阶段分类。· SleepEEGNet:采用与DeepSleepNet [ 20 ]相同的CNN结构,然后使用带有注意力机制的编码器-解码器。
ResnetLSTM:实现了用于特征提取的ResNet结构,然后使用LSTM将EEG信号分类到不同的睡眠阶段。
MultitaskCNN:首先将原始EEG信号转换为功率谱图像,然后使用多任务CNN结构的联合分类和预测技术来识别睡眠阶段。
SeqSleepNet:将原始EEG信号转换为功率谱图像,然后使用分层RNN结构一次性分类多个时期。
使用PyTorch 1.4构建模型,并在Tesla K40 GPU上进行训练。使用了128批的Adam优化器,学习率从1e - 3开始,经过10次迭代后降低到1e - 4。Adam的质量衰减设定为1e - 3,betas ( b1、b2)分别作为( 0.9、0.999),epsilon值设定为1e - 08,AMSGrad算法设定为true。所有卷积层均采用均值为0,方差为0.02的高斯分布进行初始化。对于TCE模块,在MHA中使用了5个头部,每个特征的维度d在Sleep - EDF数据集为80,在SHHS数据集为100,因为SHHS数据集具有更高的采样率,因此其信号长度更长。对于两个全连接层,输入维度为d,输出维度设置为120,对于第二个全连接层,输入维度为d,输出维度设置为120。
与最新方法比较
AttnSleep由于其强大的特征提取模块以及带有注意力机制的TCE,取得了比其他4种方法更好的分类性能。AttnSleep在Sleep - EDF78和SHHS上取得了更好的MF1和MGm,表明设计的代价敏感损失函数有助于处理不平衡数据。此外,我们可以观察到我们的AttnSleep对N1类的性能低,因为W、REM和N1具有相似的特征。因此AttnSleep倾向于将N1错误分类为包括W和REM在内的其他类别。
注意到表5中5种方法均采用单历元(即30秒EEG信号)作为模型输入。SeqSleepNet 以3个历元作为输入,然后预测中间历元的标签。为了公平比较,使用3个历元作为输入,将AttnSleep与表VI中的SeqSleepNet进行比较。
如表6所示,AttnSleep在所有4个指标( ACC、MF1、κ和MGm)上都优于SeqSleepNet。本文方法的训练时间远小于其他方法。
首先,DeepSleepNet、SleepEEGNet和SeqSleepNet都利用了LSTM,由于LSTM中的循环处理导致训练速度变慢。其次,Multi taskCNN和SeqSleepNet在训练主模型之前需要额外的计算来预训练一个基于DNN的滤波器组。但AttnSleep模型使用TCE代替LSTM来捕捉EEG数据之间的时间依赖关系,因此可以从并行计算中获益,从而降低训练复杂度。
消融实验
在Sleep - EDF20数据集上进行了消融研究
首先,AFR可以提高分类性能,这证明了对特征间依赖关系进行建模的必要性。通过比较第三种和第四种变体(即, MRCNN + TCE vs . MRCNN + AFR + TCE),进一步证明了这一点。
其次,通过对比MRCNN和MRCNN + TCE (类似地, MRCNN + AFR vs . MRCNN + AFR + TCE),我们得出结论,利用TCE捕获时间依赖关系对睡眠阶段分类具有重要意义。此外,TCE比AFR更重要,因为MRCNN + TCE优于MRCNN + AFR。
第三,AttnSleep取得了显著优于其他四种变体的MF1和MGm,表明提出的对类别敏感代价敏感损失函数可以有效地解决数据不平衡问题,并且没有任何额外的计算开销。
MHA头数敏感性分析
头的数量应该被特征的长度d分割。由于Sleep - EDF - 20数据集的d为80,我们使用1,2,4,5,8和10个头部来运行我们的模型。图7展示了模型在Sleep - EDF - 20数据集上的准确率和MF1分数。总体而言,当使用不同数量的头时,模型性能相当稳定。随着头数从1,2增加到4和5,我们可以观察到性能上的轻微改善,因为使用更多的头可以使模型发现更有意义的特征和特征交互。同时,当头部数量进一步增加( H = 8和10)时,即每个头部中的特征长度变小,导致性能略有下降。在我们的实验中,最终在Sleep - EDF - 20数据集上将H设置为5。对于另外两个数据集,我们也将H设为5。
四、结论
本文提出了一种新颖的从单通道原始EEG信号中进行睡眠阶段分类的架构,称为AttnSleep。AttnSleep依赖于使用两个模块从EEG信号中提取特征:多分辨率卷积神经网络( MRCNN )和自适应特征重标定( AFR )。这两个模块之后是时间上下文编码器( Temporal Context Encoder,TCE )模块,该模块通过使用多头注意力( MHA )机制捕获提取特征之间的时间依赖关系。本文还提出了一个对类别敏感代价敏感损失函数来处理数据不平衡问题。在3个公开数据集上的实验结果表明,在不同的评价矩阵下,我们的模型均优于现有方法。此外,还进行了消融研究。最后,我们进行了敏感性分析,以证明MHA中头部数量的影响。结果表明,我们的方法在不同的头数下是相当稳定的。对于未来的方向,我们将考虑迁移学习和领域自适应技术,这些技术将在标记数据集上训练的模型自适应地分类其他数据集中的未标记睡眠数据。