数据描述
来自卡塔尔多哈卡塔尔大学和孟加拉国达卡大学的一组研究人员,以及来自巴基斯坦和马来西亚的合作者与医生合作,建立了一个针对COVID-19阳性病例的胸部X射线图像数据库,以及正常和病毒性肺炎图像。
数据来源
https://www.heywhale.com/mw/dataset/6027caee891f960015c863d7
数据说明
COVID-19阳性病例的胸部X射线图像以及正常和病毒性肺炎图像的数据库。 数据包含有1200个COVID-19阳性图像,1341正常图像和1345病毒性肺炎图像。
#导入需要的包
import os
import math
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Image
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear,Conv2D,Pool2D
import matplotlib.pyplot as plt
paddle.enable_static() #转换为静态图
定义了一个参数字典configs,其中包含了各种参数的设置,输入图片的大小、分类数、数据路径、模型保存路径、学习率、批次大小和学习次数等。这些参数的设置将在后面的程序中被调用和使用。
#一些参数的设置
configs = {
"input_size": [3, 1024,1024], #输入图片的shape
"class_dim":3, #分类数
'src_path':'data/data82373/input_data.rar', #数据的路径
'train_path':'input_data', #解压路径
'model_save_dir':'save_model', # 模型保存路径
'learning_rate':0.001, #学习率
'batch_size':32, #批次大小
'epoch':10 #学习次数
}
进行图片的预处理,获取训练集和测试集的数据。定义了三个空列表COVID、NORMAL和Viral_Pneumonia,分别用于存放新冠肺炎患者的胸透图片、正常人的胸透图片和病毒性肺炎患者的胸透图片。另外还定义了三个与之对应的标签列表COVID_label、NORMAL_label和Viral_Pneumonia_label。
通过遍历指定目录下的三个子目录,将每个图片的路径名存放到对应的列表中,并给相应的标签赋值。同时,展示了一个例子的图片和对应的标签。
接下来,将不同类型的图片路径和标签进行合并,并打乱顺序。然后将图片路径列表和标签列表转换为numpy数组,并进行随机打乱顺序。
最后,根据指定的比例(ratio),将全部样本分成训练集和测试集。并返回训练集和测试集的图片路径和标签。
#预处理图片,获取训练和测试的数据集
COVID =[] #新冠肺炎患者的胸透图片
COVID_label = []
NORMAL = [] #正常人的胸透图片
NORMAL_label = []
Viral_Pneumonia = [] #病毒性肺炎患者的胸透图片
Viral_Pneumonia_label = []
# 获取所以图片的路径名
# 对应的列表中,同时贴上标签,存放到label列表中
def get_files(file_path, ratio):
for file in os.listdir(file_path + '/COVID'):
COVID.append(file_path + '/COVID' + '/' + file)
COVID_label.append(0) # 0为新冠肺炎患者
for file in os.listdir(file_path + '/NORMAL'):
NORMAL.append(file_path + '/NORMAL' + '/' + file)
NORMAL_label.append(1) # 1为正常人
for file in os.listdir(file_path + '/Viral_Pneumonia'):
Viral_Pneumonia.append(file_path + '/Viral_Pneumonia' + '/' + file)
Viral_Pneumonia_label.append(2) # 2为病毒性肺炎患者
#检测是否读取成功
for i in range(1):
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
img = plt.imread(NORMAL[