[代码实战]手把手带你训练一个COVID检测网络,准确率高达90%

本次实战的的概况如下:

  • 代码来源:https://github.com/junaidiqbalsyed/Covid_detection_CNN
  • 目的:使用 CNN(vgg and resnet)检测 COVID 并使用GRAD-CAM进行可视化
  • 方法: 二分类(normal or covid)
  • 框架: keras (不会没关系,很简单)
  • 结果:源代码在VGG-16的准确率为 85%, 我使用resnet跑的结果为88%,甚至90%。
  • 难易程度: ⭐️⭐️

结果展示

  • resnet准确度

  • 可视化结果

感兴趣的话一起进入代码环节吧~~


1 准备工作

  • 在github上下载源代码
  • 下载数据集(约800M)左右。
    如果你网络够好(可随意打开GitHub,Google网站),下载数据集这一步可省略,后面通过代码下载。如果网络不好,建议先下载数据集并解压。
  • 配置环境
# 查看版本
import tensorflow as tf
import keras
print(tf.__version__) # 2.4.1
print(keras.__version__)  # 2.4.3

注意:这里有个坑,当我在 tf版本 2.0.0 keras 版本 2.3.1运行时,准确度一直在0.5左右徘徊,没有提升。如果你遇到了相同的问题,请重新建一个环境。

2 使用jupyter运行代码

文件:Covid_detection_using_chest_X_Ray(using_ResNet_50)%20(2).ipynb

这部分主要解读代码。

导入包的部分自动略过~~

2.1 数据集是否下载

如果刚开始你没有下载数据集,可运行 cell2 and cell3 下载数据集及解压

!wget https://www.dropbox.com/s/e1r2laj50nh4tez/COVID-19_Radiography_Dataset.zip?dl=0
!unzip "/content/COVID-19_Radiography_Dataset.zip?dl=0"

如果已经下载好,则这两步省略~~~~

2.2 探索数据

直接运行代码即可,不过多介绍。

可以简单探索下:数据的名字,格式,大小,以及下载来源。比如这里的covid数据来自sirm网站。

2.3 把所有类别的数据放在同一个文件夹并可视化

下载好的数据集是按照疾病分类的。这里我们把所有图像放在同一个文件夹。

ROOT_DIR = "/content/COVID-19_Radiography_Dataset/"
imgs = ['COVID','Lung_Opacity','Normal','Viral Pneumonia']

NEW_DIR = "content/all_images/"

需要注意这部分代码的地址。作者在地址前加了”/“,如/content,表示content目录在根目录中,如果你的数据不在根目录,就不要加前面的”/“。

同理,如果后文出现了找不到文件夹,很有可能是你这里加了”/“, 去掉即可。

可视化看看数据集的分布

下载的数据集有4类,但是我们代码中实际上只用到了2类(COVID 和 Normal)

2.4 把数据集分成训练集、验证集和测试集

if not os.path.exists(NEW_DIR+"train_test_split/"):

  os.makedirs(NEW_DIR+"train_test_split/")
  .....
  os.makedirs(NEW_DIR+"train_test_split/validation/Covid")


  # Train Data
  for i in np.random.choice(replace= False , size= 3000 , a = glob.glob(NEW_DIR+imgs[0]+"*") ):
    shutil.copy(i , NEW_DIR+"train_test_split/train/Covid" )
    os.remove(i)

  for i in np.random.choice(replace= False , size= 3900 , a = glob.glob(NEW_DIR+imgs[2]+"*") ):
    shutil.copy(i , NEW_DIR+"train_test_split/train/Normal" )
    os.remove(i)
    ....   

这部分代码里有很多可以学习的地方:

比如,我们要在所有covid图像中,随机选取3000个作为训练集。怎么做到?

答案: np.random.choice()

当这部分选取作为训练集后,如何保证这部分数据在验证集和测试集中选不到它。

答案:os.remove(i) 当被选取后,删除它,那么再选择的时候就选择不到它了。

现在,数据集有了。

2.5 为keras生成数据流

train_path  = "content/all_images/train_test_split/train"
valid_path  = "content/all_images/train_test_split/validation"
test_path   = "content/all_images/train_test_split/test"

这是我们数据集存放的地址。再强调一次,content前面的’/‘我已经去掉。

为各数据集生成keras可识别的数据

train_data_gen = ImageDataGenerator(preprocessing_function= preprocess_input, 
                                    zoom_range= 0.2, 
                                    horizontal_flip= True, 
                                    shear_range= 0.2,
                                    
                                    )

train = train_data_gen.flow_from_directory(directory= train_path, 
                                           target_size=(224,224))

训练集中: Found 7800 images belonging to 2 classes.

2.6 构建模型

res = ResNet50( input_shape=(224,224,3), include_top= False) 
# include_top will consider the new weights

include_top= False表示不要全连接层,只加载特征部分。

这里加载的预训练权重是在:https://storage.googleapis.com

注意
我用其他版本keras加载的权重是在:https://github.com/
不同的版本会有区别,也会影响到结果。我也不知道为啥。

你如果得不到一个较好的结果,程序也没有报错,可能就是这步出现了问题,注意检查。

2.6 冻结特征层,这里我们只训练网络最后一层

for layer in res.layers:           # Dont Train the parameters again 
  layer.trainable = False

2.7 添加全连接层

x = Flatten()(res.output)
x = Dense(units=2 , activation='sigmoid', name = 'predictions' )(x)

# creating our model.
model = Model(res.input, x)

可以通过 model.summary()查看每一层的信息

2.8 训练模型

model.compile( optimizer= 'adam' , loss = 'categorical_crossentropy', metrics=['accuracy'])

es = EarlyStopping(monitor= "val_accuracy" , min_delta= 0.01, patience= 3, verbose=1)
mc = ModelCheckpoint(filepath="bestmodel.h5", monitor="val_accuracy", verbose=1, save_best_only= True)

设置优化器,loss, 评估指标
通过监控val_accuracy来保存网络,使用EarlyStopping来结束训练。

然后,开始训练👇

hist = model.fit_generator(train, steps_per_epoch= 10, epochs= 30, validation_data= valid , validation_steps= 16, callbacks=[es,mc])

model.fit_generator可以了解一下keras的这个函数,可参数的意思,不想了解直接运行就对了。

你的运行准确度应该在80%以上才是正确的,如果结果异常,检查哪一步除了问题。并解决它。

2.9 加载模型查看训练历史结果

## load only the best model 
from keras.models import load_model
model = load_model("bestmodel.h5")

查看保存了哪些历史信息

图片中可以看到,分别保存了训练集和验证集的loss和acc.用matplotlib画出来

plt.plot(h['accuracy'])
plt.plot(h['val_accuracy'] , c = "red")
plt.title("acc vs v-acc")
plt.show()

在测试集上评估模型

acc = model.evaluate_generator(generator= test)[1] 
print(f"The accuracy of your model is = {acc*100} %")

The accuracy of your model is = 88 %

2.10 如何对测试集单张测试

这里需要注意的是,即便是单张测试,也需要对图像进行预处理。

预处理方法同测试集的方法一样。

这里使用from keras.preprocessing import image方法进行预处理。需要统一大小(224,224,3), 转化成nunpy数据,并添加一个batch维度。

def get_img_array(img_path):
  """
  Input : Takes in image path as input 
  Output : Gives out Pre-Processed image
  """
  path = img_path
  img = image.load_img(path, target_size=(224,224,3))
  img = image.img_to_array(img)
  img = np.expand_dims(img , axis= 0 )
  
  return img

处理好的图像就可以通过model.predict(img)进行预测。

2.11 可视化

keras的可视化代码我没研究过,感兴趣的自行研究。

可视化结果如下:
这是最后一个卷积层获得的热力图 尺寸为7*7

将它放大到原始图像一样大,并叠加在原始图像上的效果如下:

我们在可视化一个normal样本


可以发现,健康样本的热力图为全白图像,叠加在原始图像使得整个图像偏蓝。

3 可能遇到的问题总结

  • 环境问题
    如果你的环境可能存在问题,建议尝试重新创建一个虚拟环境,安装tensorflow and keras
conda install tensorflow-gpu
conda install keras

我用的版本为:tf: 2.4.1 keras: 2.4.3

  • 找不到文件问题
    如果你的数据地址前面加了”/“,表示根目录,通常我们不会把数据放在根目录,删掉”/“。

  • 得到的结果跟我的差很远,甚至网络train不动,acc一直在0.5左右 第一可能是环境问题,重新安装环境后,还存在此问题,那么可能是网络问题,预训练权重没下载下来,可以多尝试几次。

如果不出问题,把数据集下载好,训练只要几分钟的时间。

tip:我这里只用resnet进行了实验,你还可以尝试train另一个用vgg16训练的文件.

希望您能享受这次实验,并从中获取知识~~

文章持续更新,可以关注微信公众号【医学图像人工智能实战营】获取最新动态,一个关注于医学图像处理领域前沿科技的公众号。坚持已实践为主,手把手带你做项目,打比赛,写论文。凡原创文章皆提供理论讲解,实验代码,实验数据。只有实践才能成长的更快,关注我们,一起学习进步~

我是Tina, 我们下篇博客见~

白天工作晚上写文,呕心沥血

觉得写的不错的话最后,求点赞,评论,收藏。或者一键三连
在这里插入图片描述

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 10
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Tina姐

我就看看有没有会打赏我

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值