1、引言
无论是本科生还是硕士生,关于心电信号的课题主要就是围绕着这几个大方向,今天主要介绍心电信号的分类。
2、概述
对于心电信号分类,目前主流的做法是直接将心电信号输入到深度学习模型,使用模型来自动提取特征,之后再输入到全连接层进行分类。
使用深度学习模型的优势在于:1、模型可以自己根据心电信号的特点来学习到相关的特征,而无需进行复杂的人工特征提取;2、可以将计算机视觉(Computer Vision,CV)领域的一些方法和创新点应用到心电信号中,这样更好发文章(懂的都懂)。
3、使用深度学习模型来进行心电信号分类的基本流程
概括地说来,心电信号分类的基本流程包括数据集的读取,模型的搭建、在训练集上进行模型的训练和在测试集上进行模型的测试。
3.1、心电信号的采集/心电信号数据集的选择
对于心电信号来说,使用到的数据可以是私有数据集,也可以使用公开数据集。
对于私有数据集,这个不过多介绍,可以通过自己采集或者与医院的心内科合作。自己采集的通常是单导联或者2导联信号,医院的通常是标准12导联信号。
对于公开数据集,常用的数据集可以从PhysioNet 网站进行获取。包括MITDB(MIT-BIH Arrhythmia Database)、AFDB(MIT-BIH Atrial Fibrillation Database)、PTB-XL(PTB-XL, a large publicly available electrocardiography dataset)等等,具体需要使用什么数据集需要根据自己的课题来选择。此外,PhysioNet每年都会进行有关生理信号的比赛,比赛使用到的数据集也是经常用到的。包括CINC 2011(用于12导联心电信号的质量分类)、CINC 2017(用于四个类别的分类)、CINC 2020、CINC 2021和CINC 2024等。除去PhysioNet,国内的东南大学也会举办比赛,涉及到的数据集包括CPSC 2018、CPSC 2019、CPSC2020和CPSC 2021等。
还是一句话,具体需要使用哪个数据集,取决于要做的任务。心电信号的分类可以分为很多,包括常见的心电信号质量分类、心律失常分类、情绪识别和睡眠呼吸暂停等,需要根据自己的任务来选择适当的数据集。
3.2 心电信号的预处理
常见的心电信号预处理包括滤波(去噪)、分割(划分为长度均为5秒/10秒的片段)、重采样和标准化,这些都比较简单。
- 对于滤波,也是根据具体任务来进行,常见的包括使用带通滤波和小波变换。
- 分割也是为了使得输入到模型的信号长度一致,这样有利于进行模型的训练。此外,除去分割,也可以选择使用零填充将较短的信号填充到更长的长度。
- 重采样是为了统一来自不同数据集中心电信号的采样率,通常使用SciPy中的resample函数进行操作。
- 标准化同样是为了加速模型的训练,通常可以使用Z-Score或者Min-Max Normalization将心电信号进行标准化。
3.3 构建数据集
通过torch.utils.data.Dataset()实现读取数据集的心电信号及其对应的标签,然后通过torch.utils.data.DataLoader()来迭代读取Batch个Dataset对象。对应代码如下:
class ECGDataset(Dataset):
def __init__(self, data_dir):
super(ECGDataset, self).__init__()
self.data_dir = data_dir
self.npy_files = os.listdir(self.data_dir)
def __getitem__(self, item):
data, label = load_data(os.path.join(self.data_dir, self.npy_files[item]))
return torch.from_numpy(data).float(), torch.as_tensor(label).type(torch.LongTensor)
def __len__(self):
return len(self.npy_files)
实例化一个dataset对象,并实例化一个dataloader:
train_dataset = ECGDataset(train_path)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
3.4 模型的搭建
与图像不一样,心电信号是一维时序信号,因此通常使用一维卷积(Conv1d)来进行特征提取。但正如前面所说,可以将CV领域的模型应用到心电信号中,一个技巧就是将其中的二维卷积(Conv2d)替换为Conv1d,同时其他如二维池化(MaxPool2d)等操作也需要替换为一维池化(MaxPool1d)。
以经典的ResNet为例,对于ResNet中的BasicBlock,将其中的卷积层、BN层替换为1维后,即可用于心电信号。
def conv3x3(in_planes, out_planes, kernel_size, stride):
return nn.Conv1d(in_planes, out_planes, kernel_size, stride=stride, padding=3, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride=stride)
self.bn1 = nn.BatchNorm1d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm1d(planes)
self.downsample = downsample
self.stride = stride
self.dropout = nn.Dropout(.2)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
3.5 训练及预测
接下来,就可以定义一些训练和验证函数来进行模型的训练。
def train(model, train_loader, criterion, optimizer, device):
model.train()
train_loss, train_acc = 0.0, 0.0
for data, labels in tqdm(train_loader):
optimizer.zero_grad()
outputs = model(data.to(device))
loss = criterion(outputs, labels.to(device))
loss.backward()
optimizer.step()
train_loss += loss.item()
train_acc += (outputs.argmax(dim=1) == labels.to(device)).sum().item() / labels.size(0)
train_loss /= len(train_loader)
train_acc /= len(train_loader)
return train_loss, train_acc
def evaluate(model, valid_loader, criterion, device):
model.eval()
valid_loss, valid_acc = 0.0, 0.0
with torch.no_grad():
for data, labels in tqdm(valid_loader):
outputs = model(data.to(device))
loss = criterion(outputs, labels.to(device))
valid_loss += loss.item()
valid_acc += (outputs.argmax(dim=1) == labels.to(device)).sum().item() / labels.size(0)
valid_loss /= len(valid_loader)
valid_acc /= len(valid_loader)
return valid_loss, valid_acc
加载好训练集、验证集和测试集,实例化模型、损失函数、优化器后,便可以进行模型的训练。训练过程中的学习率、batch_size的大小,都需要进行参数调整。对于学习率,可以使用学习率调度程序实现动态调整;对于batch_size,可以使用网格搜索算法来遍历最合适的大小。此外,最好使用早停(Early Stopping)来及时终止模型的训练。
保存训练好的模型权重后,便可以加载模型权重,在测试集上进行测试,得到一些评价指标。常用的分类指标包括准确率、召回率、精确度和F1值,这些都是通过测试集上的结果来反映。此外,还可以绘制混淆矩阵或者使用一些类似主成分分析技术,以实现更直观的展示。
4、结语
以上是一些基本的介绍,可以作为一个通用的流程来进行学习。目前这一版本基本都是文字和代码介绍,后期时间充裕的话,会以图文+代码的形式进行讲解。
如果对于使用PyTorch进行深度学习实验有问题的同学,可以评论或者私信,我有时间都会解答。