文章目录
摘要
在脑电图数据中对癫痫发作类型进行自动分类,可以更精确地诊断和有效地管理该病。这项任务具有挑战性,原因包括低信噪比、信号伪影、癫痫患者癫痫符号学的高差异以及临床数据的有限。为了克服这些挑战,在本文中,我们提出了一个名为SeizureNet的深度学习框架,它使用一个集成架构学习多谱特征的嵌入,用于患者之间发作类型的分类。我们使用最近发布的TUH脑电图发作语料库(V1.4.0和V1.5.2)来评估SeizureNet的性能。实验表明,对于基于头皮脑电图的多类别癫痫发作类型分类,癫痫发作交叉验证加权F1 Score可达0.94,患者交叉验证加权F1 Score可达0.59。我们还表明,通过对低内存约束应用的知识蒸馏,由SeizureNet学习的高级特征嵌入显著提高了较小网络的准确性。
1 - 引言
癫痫是一种神经系统疾病,影响世界1%的人口。它会导致突然和不可预见的癫痫发作,从而导致病人严重受伤,甚至死亡。三分之一的癫痫患者得不到适当的治疗。对于剩下的三分之二的患者,治疗的选择和质量各不相同,因为癫痫发作的症状学对每个癫痫患者都是不同的。医生通过对脑电图(EEG)记录的目视检查来分析大脑活动异常是诊断癫痫的一项重要技术。这项任务是耗时的,并受观察者之间差异的影响。随着物联网数据收集的进展,基于机器学习的系统已被开发出来,以捕获癫痫发作期间脑电图数据中的异常模式[10,21,4]。在此背景下,目前的系统主要集中在发作检测和发作预测等任务上[23,3,11],而发作类型分类的任务由于任务的复杂性以及临床数据集缺乏对发作类型进行标注等因素,在很大程度上尚不完善。尽管如此,识别不同发作类型(如局灶性或全面性发作)的能力具有改善长期患者护理的潜力,使临床试验中能够及时进行药物调整和远程监测[6]。近日,美国天普大学发布了用于癫痫研究的TUH脑电图癫痫发作语料(TUH - EEGSC)[18],使其成为世界上最大的癫痫发作类型分类公开数据集。[16]的工作通过对各种标准机器学习算法的搜索空间探索,提出了TUH-EEGSC[18]癫痫类型分类的基线结果。其他方法,如[20,17],使用所选癫痫发作类型数据的子样本进行癫痫发作分析。在本文中,我们提出了一种集成学习方法,并使用癫痫发作和患者交叉验证提出了癫痫发作类型分类的新基准。本文的主要贡献如下:
-
我们提出了一个名为“SeizureNet”的深度学习框架,通过学习不同空间和频率分辨率的脑电图数据的特征嵌入来实现集成个体分类器的多样化。实验表明,我们的多谱特征学习在集成中增加了多样性,并减少了最终癫痫类型分类预测的方差。
-
我们提出了显著性编码频谱图,这是一种视觉表征,它捕捉了时间序列脑电图数据的频率变换中包含的显著性信息。实验表明,我们的显著性编码频谱图提高了TUH-EEGSC[18]癫痫发作分类的准确性。
-
我们评估了我们的框架通过知识蒸馏将知识转移到更小网络的能力,并给出了TUH-EEGSC[18]上癫痫发作类型分类的基准结果。
2 - 提出的框架(SeizureNet)
图1-A显示了我们框架的总体架构,该框架将原始时间序列脑电图信号转换为提出的显著性编码频谱图,并使用深度CNN模型的集合产生癫痫类型分类的预测。下面,我们将详细描述框架的各个组成部分。
2.1 显著编码频谱图
显著性编码频谱图的灵感来自于视觉显著性检测[7],在该方法中,我们将时序脑电图信号转换为视觉表征,从数据中捕获多尺度显著性信息。显著性编码频谱图包括三个特征图,如图1-D所示。
- 一个傅立叶变换图( F T FT FT),它对脑电图信号的对数振幅傅立叶变换进行编码;
- 一个谱显著性图( S 1 S1 S1),通过计算FT特征图的谱残差来提取显著性;
- 一个多尺度显著性图(S2),它利用FT特征图的特征的中心-环绕差异在多个尺度捕获谱显著性[13,9]。
从数学上讲,给定一个来自以时间 t t t为参数的通道 c c c的时间序列EEG序列 X ( c , t ) X(c, t) X(c,t),我们计算该序列的快速傅里叶变换( F \mathcal{F} F)为: F ( X ) = ∫ − ∞ ∞ X ( c , t ) e − 2 π i t d t \mathcal{F}(X) = ∫_{−∞}^{∞}X(c, t)e^{−2πit}dt F(X)=∫−∞∞X(c,t)e−2πitdt。我们对选定的20个通道的数据计算 F \mathcal{F} F并对傅里叶变换的幅度取对数。输出被重构为 R p × 20 \mathbb{R}^{p×20} Rp×20-维特征映射( F T FT FT),其中 p p p表示脑电图序列的数据点数。
在数学上, F T FT FT可以写成: F T = l o g ( A m p l i t u d e e ( F ( X ) ) ) FT = log(Amplitudee(\mathcal{F}(X))) FT=log(Amplitudee(F(X)))。
在数学上, S 1 S_1 S1可以写成: S 1 = G ∗ F − 1 ( e x p ( F T − H ∗ F T ) + P ) 2 S_1 = \mathcal{G} * F^{−1}(exp(FT−H * FT)+P)^2 S1=G∗F−1(exp(FT−H∗FT)+P)2,其中 F − 1 F^{−1} F−1表示傅里叶反变换, H H H代表 F T FT FT的平均谱,通过一个3 × 3的局部平均滤波器对 F T FT FT的特征图进行卷积逼近。 G \mathcal{G} G是一个高斯核,平滑特征值。术语 P P P表示特征图 F T FT FT的相位谱。
显著性图
S
2
S_2
S2通过计算多个尺度上的中心-环绕差异来捕捉特征图
F
T
FT
FT中与其周围数据点的显著性。设
F
T
i
FT_i
FTi表示位置
i
i
i的一个特征值,
Ω
Ω
Ω表示位置
i
i
i周围的一个尺度为
ρ
ρ
ρ的圆形邻域。在数学上,位置
i
i
i的显著性计算可以写成:
S
2
(
i
)
=
∑
ρ
∈
[
2
,
3
,
4
]
(
F
T
i
−
m
i
n
(
[
F
T
k
,
ρ
]
)
)
∀
k
∈
Ω
S_2(i)=\sum_{ρ∈[2,3,4]}(FT_i-min([FT_{k,ρ}])) \quad \forall k ∈\Omega
S2(i)=∑ρ∈[2,3,4](FTi−min([FTk,ρ]))∀k∈Ω。
[
F
T
k
,
ρ
]
[FT_{k,ρ}]
[FTk,ρ]表示本地邻域的特征值Ω。最后,我们将三个特征
F
T
、
S
1
FT、S1
FT、S1和
S
2
S2
S2连接到一个类似RGB的数据结构(
D
\mathcal{D}
D),该数据结构在0和255范围内标准化,如图1-D所示。
2.2 多谱特征学习
深度神经网络往往是过度参数化的,需要足够的训练数据来有效地学习特征,从而可以推广到测试数据。当面对有限的训练数据时(这是健康信息学[2]中的一个常见问题),深度架构往往会遭遇收敛性差或着过拟合等得问题。为了克服这些挑战,我们提出了多谱特征采样(MSFS),一种新的方法,通过使用不同频率和时间分辨率的数据采样来训练子网络,以鼓励集成学习的多样性。
图1-E显示了我们的MSFS方法的概述。假设有一个 M M M维训练数据集 D \mathcal{D} D = { ( D i , y i ) ∣ 0 ≤ i ≤ N d (\mathcal{D_i}, y_i)|0≤i≤N_d (Di,yi)∣0≤i≤Nd},由 N d N_d Nd个样本组成,其中 D i \mathcal{D}_i Di是一个训练样本,对应的类标签为 y i ∈ y y_i∈y yi∈y,在训练过程中,MSFS生成一个特征子空间 D m \mathcal{D^m} Dm = { ( D i m , y i ) ∣ 0 ≤ i ≤ N d (\mathcal{D_i^m}, y_i)|0≤i≤N_d (Dim,yi)∣0≤i≤Nd}包含随机选取采样频率 f ∈ F f∈F f∈F(Hz)、窗长参数 w ∈ W w∈W w∈W (seconds)、窗步长参数 o ∈ O † o∈\mathcal{O^†} o∈O†生成的谱图。这个过程重复 N e = 3 N_e = 3 Ne=3次,得到随机子空间{ D 1 m , … , D N e m D^m_1,…, D^m_{N_e} D1m,…,DNem},其中 N e N_e Ne为集合的大小。
2.3 提出的集合结构(SeizureNet)
SeizureNet由 N e N_e Ne个深度卷积神经网络(DCNs)组成。图1-A展示了三个子网络的SeizureNet架构。DCN的基本构建块是一个Dense块,它由多个bottleneck(瓶颈)卷积组成,通过Dense连接[8]互连。
具体来说,每个DCN模型以7 × 7卷积开始,然后是批归一化(BN)、修正线性单元(ReLU)和3 × 3平均池化操作。接下来,有四个Dense块,其中每个密集块由
N
l
N_l
Nl个称为Dense层的层组成,这些层共享来自通过fuse连接到当前层的所有前面层的信息。图1-B显示了
N
l
=
6
N_l = 6
Nl=6个Dense层的Dense块的结构。每个Dense层由1 × 1和3 × 3卷积组成,然后是BN、ReLU和dropout block,如图1- c所示。在数学上,Dense块中第
l
l
l个Dense层的输出可以写成:
X
l
=
[
X
0
,
…
,
X
l
−
1
]
\mathcal{X}_l = [\mathcal{X}_0,…, \mathcal{X}_{l−1}]
Xl=[X0,…,Xl−1],其中[···]表示第
0
,
.
.
.
,
l
−
1
0,...,l-1
0,...,l−1产生的特征的拼接。最终的Dense块产生
Y
d
e
n
s
e
∈
R
k
×
R
×
7
×
7
Y_{dense}∈\mathbb{R}^{k×R×7×7}
Ydense∈Rk×R×7×7−维特征,这些特征通过平均操作被压缩到
k
×
R
k×R
k×R维,然后被馈送到一个线性层
f
c
∈
R
K
f_c∈\mathbb{R}^K
fc∈RK,它学习关于
k
k
k个目标类的输入数据的概率分布。为了增加集成子网络之间的多样性,我们改变子网络的Dense block 3和Dense block 4的Dense层数。
2.4 训练和实施
考虑一个谱图和标签 ( D , y ) ∈ ( D , y ) (D, y)∈(\mathcal{D}, \mathcal{y}) (D,y)∈(D,y)的训练数据集,其中每个样本属于 K K K类中的一个 ( y = 1 , 2 , … , K ) (y = 1,2,…,K) (y=1,2,…,K)目标是确定一个函数 f s ( D ) : D → y f_s(D): D→y fs(D):D→y。为了学习这个映射,我们用 f ( D , θ ∗ ) f(D,θ^*) f(D,θ∗)参数化训练SeizureNet,其中 θ ∗ θ^* θ∗是通过最小化一个训练目标函数得到的学习参数: θ ∗ = a r g m i n θ L C E ( y , f ( D , θ ) ) θ^* = argmin_θ L_{CE}(y, f(D, θ)) θ∗=argminθLCE(y,f(D,θ)),其中 L C E L_{CE} LCE表示交叉熵损失(Cross Entropy Loss),该损失相对于基础真值标签应用于集合的输出。数学上, L C E L_{CE} LCE可以写成: L C E = ∑ k = 1 K Ⅱ ( k = y i ) l o g σ ( O e , y i ) L_{CE}=\sum_{k=1}^{K}Ⅱ(k=y_i)logσ(O_e,y_i) LCE=∑k=1KⅡ(k=yi)logσ(Oe,yi),其中 O e = 1 / N e ∑ e = 1 N e O k O_e=1/N_e \sum_{e=1}^{N_e} O_k Oe=1/Ne∑e=1NeOk表示集成后的logits, O k O_k Ok表示分对数由一个单独的子网络, Ⅱ Ⅱ Ⅱ是指标函数, σ σ σ是SoftMax操作。 σ ( z i ) = σ(zi) = σ(zi)= exp z i / ∑ k = 1 K z_i/ \sum_{k=1}^{K} zi/∑k=1K exp z k z_k zk。为了训练网络,我们从零均值高斯分布初始化网络的权值。标准偏差设置为0.01,偏差设置为0。我们为400个epoch训练网络,初始学习率为0.001(在总epoch数量的50%和75%时除以10),参数衰减为0.0005(在权重和偏差上)。我们的实现是基于Torch库[15]的自动梯度计算框架。训练由ADAM优化器进行,批大小为50。
3 实验和结论
我们使用TUH脑电图癫痫发作语料库(TUH- eegsc)[18],这是世界上最大的公开的带类型注释的癫痫发作记录数据集。2018年10月发布的TUH-EEGSC v1.4.0包含2012年,而2020年5月发布的TUH-EEGSC v1.5.2包含3050次癫痫发作。表1显示了不同发作类型和患者数量的TUH-EEGSC的统计。在实验中,我们将肌阵挛(MC)发作排除在研究之外,因为发作次数太少,无法进行有统计学意义的分析(只有3次发作),如表1所示。
为了进行评估,我们在患者层面和癫痫发作层面进行了交叉验证。
具体来说,对于TUH-EEGSC v1.4.0,我们考虑了癫痫发作的交叉验证。由于TUH-EEGSC v1.4.0中的TN和SP发作类型仅包含2例患者的数据,因此患者交叉验证不会产生有统计学意义的结果。因此,我们考虑了5 folds癫痫发作交叉验证,即不同发作类型的癫痫发作被平均和随机地分配到5 folds。
对于TUH-EEGSC v1.5.2,我们考虑了患者交叉验证。表1显示TUH-EEGSC v1.5.2中所选择的7种癫痫发作类型包含至少3例患者的数据,允许进行具有统计学意义的3倍患者交叉验证。在这种情况下,将数据分为训练子集和测试子集,因此训练子集和测试子集中的癫痫发作总是来自不同的患者。这种方法使得改善模型性能更具挑战性,但由于它支持在患者中推广模型,因此具有更高的临床相关性。
表2和表4分别显示了TUH-EEGSC在患者和癫痫发作方面的交叉验证结果。结果表明,与现有方法相比,SeizureNet使患者加权F1分数提高了3分左右,发作加权F1分数提高了4分左右。这些改进主要归功于所提出的多谱特征学习,它从不同的频率和空间分辨率捕获信息,使SeizureNet能够比其他方法学习更多的鉴别性特征。
3.1 SeizureNet用于知识蒸馏
在这里,我们评估了SeizureNet将知识迁移到更小的网络进行癫痫发作分类的能力。为此,我们使用基于知识蒸馏的训练函数,将3个残差层与作为教师网络的SeizureNet结合起来,训练一个学生ResNet模型。我们的训练函数
L
K
D
L_{KD}
LKD是交叉熵损失项
L
C
E
L_{CE}
LCE和蒸馏损失项
L
K
L
L_{KL}
LKL的加权组合。LKD在数学上可以表示为:
L
K
D
=
α
⋅
L
C
E
(
P
t
,
y
)
+
β
⋅
L
C
E
(
P
s
,
y
)
+
γ
⋅
L
K
L
L_{KD} = α·L_{CE}(P_t, y) +β·L_{CE}(P_s, y) +γ·L_{KL}
LKD=α⋅LCE(Pt,y)+β⋅LCE(Ps,y)+γ⋅LKL,其中
P
t
P_t
Pt和
P
s
P_s
Ps分别表示SeizureNet和学生模型的logits(输入到SoftMax)。
α
∈
[
0
,
0.5
,
1
]
,
β
∈
[
0
,
0.5
,
1
]
,
γ
∈
[
0
,
0.5
,
1
]
α∈[0,0.5,1],β∈[0,0.5,1],γ∈[0,0.5,1]
α∈[0,0.5,1],β∈[0,0.5,1],γ∈[0,0.5,1]是平衡个体损失项的超参数。蒸馏损失项
L
K
L
L_{KL}
LKL由SeizureNet输出计算得到的对数概率和学生模型之间定义的Kullback-Leibler (KL)散度函数组成。LKL在数学上可表示为:
L
K
L
(
P
s
,
P
t
)
=
σ
(
P
s
)
⋅
(
l
o
g
(
σ
(
P
s
)
)
−
σ
(
P
t
)
/
T
)
L_{KL}(P_s, P_t) = σ(P_s)·(log(σ(P_s))−σ(P_t)/T)
LKL(Ps,Pt)=σ(Ps)⋅(log(σ(Ps))−σ(Pt)/T),其中σ表示SoftMax操作,T为温度超参数,控制输出的软化?。表5显示,与没有知识蒸馏的ResNet模型相比,ResNet-KD在癫痫发作平均F1评分上提高了约2%。这些结果表明,由SeizureNet学习的特征嵌入可以有效地用于提高较小网络(例如,3层ResNet具有45×更少的训练参数,1100×更少的FLOPS数,45×更快的推理速度,如表5所示)在内存受限的系统中部署的准确性。
3.2 显著编码谱的意义
从图2可以看出,本文提出的显著性编码光谱图中傅立叶变换的谱残差和多尺度中心-环绕差信息的组合对癫痫发作分类具有高度的鉴别性,特别是对小的训练数据。例如,当只有10%的数据用于训练时,使用显著编码频谱图训练的模型对所有目标类的混淆产生了相当大的减少。
3.3 多谱特征学习的意义
表3显示,与特定频率的模型相比,使用MSFS训练的模型产生了更高的F1分数。例如,当只使用50%的训练数据时,使用MSFS训练的模型与没有使用MSFS训练的模型相比,癫痫发作相关F1分数提高了约9分。这些改进表明,来自不同频带的信息相互补充,可以更好地识别发作类别,尤其是在数据规模较小的情况下,与在独立频带学习的特征相比。图3显示了使用和不使用MSFS训练的模型产生的TSNE映射的比较。结果表明,相对于不使用MSFS产生的癫痫发作manifolds(如图3-B),使用MSFS产生的癫痫发作manifolds在高维特征空间中得到了更好的分离(如图3-A所示)。这表明,结合不同空间和频带的数据,增加训练信息的变化,有利于学习癫痫发作分类的鉴别特征。图3还显示,使用MSFS训练的模型产生的混淆更少(如图3-C所示),显示了组合来自不同频率和空间分辨率的数据的重要性。
4 - 结论和未来工作
本文提出了一个名为SeizureNet的深度学习框架,用于在患者交叉验证场景中基于EEG的发作类型分类。以患者为中心的验证方法的最大挑战是从有限的训练数据中学习健壮的特征,这些特征可以有效地归纳到不可见的测试患者数据中。这是通过两个新颖的贡献实现的:i)显著性编码频谱图,对脑电图信号的频率变换中包含的多尺度显著性信息进行编码:ii)集成体系结构中的多谱特征学习,在该体系结构中,不同频率和空间分辨率生成的频谱图鼓励通过网络的信息流的多样性。而集成降低了最终预测的方差。在世界上最大的公开可用癫痫数据集上的实验表明,与现有方法相比,我们的缉获网在癫痫类型分类方面产生了具有竞争力的F1分数。实验还表明,通过知识提取,学习到的特征嵌入大大提高了较小网络的准确率。未来,我们计划研究可穿戴传感器和视频数据的融合,用于真实癫痫监测单元的多模态癫痫类型分类。