在本例中,将使用三组人的心电图数据:心律失常(ARR)患者、充血性心力衰竭(CHF)患者和窦性心律正常(NSR)患者。
一共有162条信号。其中,96条信号来自心律失常患者,30条信号来自充血性心力衰竭患者,36条信号来自正常窦性心律患者。本例的目标是训练一个分类器来区分ARR、CHF和NSR。
ECGData is a structure array with two fields: Data and Labels. The Data field is a 162-by-65536 matrix where each row is an ECG recording sampled at 128 hertz. Labels is a 162-by-1 cell array of diagnostic labels, one for each row of Data. The three diagnostic categories are: 'ARR', 'CHF', and 'NSR'.
时频表示也叫做scalograms,它是一个信号的CWT系数的绝对值。
要创建scalograms,要预计算一个CWT滤波器组。预计算滤波器组是使用相同参数来获取信号的连续小波变换的首选方法,并将每个scalograms转换成224✖224✖3的RGB图片以符合GoogLeNet网络的输入。
将80%划分成训练集,剩余的20%作为验证集。因此,训练集的样本个数为130,验证集的样本个数为32。
4)加载并重新训练GoogLeNetGoogLeNet的输入尺寸是224✖224✖3。
GoogLeNet中的每一层都可以看作是一个滤波器。靠前的层识别输入图像中较为一般的特征,如斑点、边缘和颜色。靠后的层更关注具体的特征,以便区分不同的类别。经过预训练,GoogLeNet可以将图像分成1000个对象类别。对于心电图分类问题,必须重新训练GoogLeNet。
由于本实例采用了迁移学习的思路,所以只需要更新最后几层的权重。训练一个神经网络是一个迭代过程,包括最小化一个损失函数。为了使损失函数最小化,使用了梯度下降算法。在每次迭代中,计算损失函数的梯度并更新下降算法的权值。
训练完成后,使用验证集来评估网络的性能。
CNN的每一层可以看作是对输入图片的响应或者激活。靠前的层会捕捉基本的图片特征,像边缘和斑点;靠后的层则更关注具体的特征,以便区分不同的类别。通过检查各层的输出,可以看出卷积神经网络究竟学到了什么。