Andrew Y. Ng式ResNet在MIT-BIH上的Inter-Patient分类实现(1)

       相信不少人已经早就看过了吴恩达(Andrew Y. Ng)大佬发表在Nature Medicine上的ECG诊断算法方案。深度学习网络一举打败了经验丰富的心血管疾病专家,在常见的12种ECG正常或异常判定中的F1达到了0.837,而参与实验的6位医生的平均F1却仅有0.780。对这样的结果,当然可以有所质疑,毕竟这几位医生可能并不能代表人类的诊断水平,不过这确实再一次反映了深度学习的强大。没看过的童鞋建议去看看:https://www.nature.com/articles/s41591-018-0268-3 ,说不定就有所启发呢。

       剥去吴大佬以及Nature子刊的光环,我们来看看这篇文章到底好在哪里。说实话,如果你真的认真看了一下原文以及对这个领域有所了解的话,你可能就会觉得这篇文章的核心并不在于他用了多新颖的网络结构(ResNet-34的小改版),而是在于那个包含91232条记录,并附带详细专家标注的数据库;另外,算法真的跟实际硬件产品挂钩了,就是iRhythm的Zio便携监护仪,并通过了FDA认证。以上两点再加上出众的效果,并且有人类专家的参与实验和背书,多方的合作造就了这么一个完整,扎实,有意义的研究。

iRhythm Zio, 图来自iRhythm 官网:https://www.irhythmtech.com/products-services/zio-xt  侵删

比较尴尬的是,虽然吴恩达课题组开源了代码,但可能是基于一些其他方面的考虑,对使用的大数据库并未开源。这就使得复现算法变得很困难。我们能做的,也就仅有利用手头的一些数据库去进行研究了。本人之前曾写过在MIT-BIH数据库上进行入门的系列博客,出于对小白友好的考虑,因此整体方案为intra-patient,而且样本整理地非常均衡。不过有很多童鞋想看一看inter-patient的实践,又恰好最近的科研中也涉及到了MIT-BIH数据库,因此就结合吴恩达的网络,对MIT-BIH数据库的inter-patient诊断做了一个小的尝试。目前代码已经开源:。主要内容如下:

  1. 缩减了网络。由于MIT-BIH数据库总样本量不大,因此对原有网络进行了缩减,由吴恩达改装版ResNet-34变为改装版ResNet-18。
  2. 调整了参数。网络结构和任务都发生了很大变化,因此对一些细节参数进行了调整。
  3. 根据AAMI标准划分5类,N,V,S,F,Q。训练集和测试集采用了最常用的DS1(22条记录)-DS2(22条记录)inter-patient划分。
  4. 不进行QRS波检测,对一个任意5s切片进行诊断,个人认为这样可以让网络脱离对QRS波检测算法精度的依赖,更贴近实际一些。
  5. 进行了与其他目前已发表工作的性能分析和对比。

当然也有几点声明。首先这不是一篇论文,只是一篇仅用于学习交流的博客,所以并不会有什么特别的创新或是深邃的思想;其次有些地方包括最终的效果还存在很多改进空间,因此大家也不要期望过高。如果有大佬可以提出具体的,有建设性的改进方案,非常欢迎。

下篇正式开始。

Github:https://github.com/Aiwiscal/MIT_Scheme

喜欢请点赞和github给star哦~~~~

好的,我可以为您提供代码实现。在这里,我将使用Keras中的ResNet50预训练模型,并使用Fashion-MNIST数据集对十种服装进行分类。首先,我们需要安装一些必要的库: ``` !pip install tensorflow !pip install keras !pip install matplotlib ``` 接下来,我们将加载数据集并进行预处理: ```python import numpy as np import keras from keras.datasets import fashion_mnist from keras.preprocessing.image import ImageDataGenerator # 数据集路径 (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data() # 将图像转换为RGB格 x_train = np.repeat(x_train[..., np.newaxis], 3, -1) x_test = np.repeat(x_test[..., np.newaxis], 3, -1) # 批量大小 batch_size = 32 # 数据增强 train_datagen = ImageDataGenerator( rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True) # 没有数据增强的验证数据生成器 val_datagen = ImageDataGenerator(rescale=1./255) # 训练集生成器 train_generator = train_datagen.flow( x_train, keras.utils.to_categorical(y_train), batch_size=batch_size) # 验证集生成器 val_generator = val_datagen.flow( x_test, keras.utils.to_categorical(y_test), batch_size=batch_size) ``` 接下来,我们将加载ResNet50模型,并对其进行微调,以适应我们的数据集: ```python from keras.applications.resnet50 import ResNet50 from keras.layers import Dense, GlobalAveragePooling2D from keras.models import Model # 加载ResNet50模型,不包括顶层(全连接层) base_model = ResNet50(weights='imagenet', include_top=False) # 添加全局平均池化层 x = base_model.output x = GlobalAveragePooling2D()(x) # 添加全连接层,输出为十个类别 predictions = Dense(10, activation='softmax')(x) # 构建我们需要训练的完整模型 model = Model(inputs=base_model.input, outputs=predictions) # 冻结ResNet50的所有层,以便在训练过程中不更新它们的权重 for layer in base_model.layers: layer.trainable = False ``` 现在,我们可以开始训练模型了: ```python from keras.optimizers import SGD # 编译模型,指定损失函数、优化器和评价指标 model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.001), metrics=['accuracy']) # 训练模型 history = model.fit_generator( train_generator, steps_per_epoch=x_train.shape[0] // batch_size, epochs=10, validation_data=val_generator, validation_steps=x_test.shape[0] // batch_size) ``` 最后,我们可以使用matplotlib库绘制训练和验证的准确率和损失曲线: ```python import matplotlib.pyplot as plt # 绘制训练和验证的准确率曲线 plt.plot(history.history['accuracy']) plt.plot(history.history['val_accuracy']) plt.title('Model accuracy') plt.ylabel('Accuracy') plt.xlabel('Epoch') plt.legend(['Train', 'Val'], loc='upper left') plt.show() # 绘制训练和验证的损失曲线 plt.plot(history.history['loss']) plt.plot(history.history['val_loss']) plt.title('Model loss') plt.ylabel('Loss') plt.xlabel('Epoch') plt.legend(['Train', 'Val'], loc='upper left') plt.show() ``` 现在您应该可以使用这些代码实现您的需求了。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值