一、整体流程
1. 数据集下载地址:https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia/download
2. 数据集展示
案例主要流程:
第一步:加载预训练模型ResNet,该模型已在ImageNet上训练过。
第二步:冻结预训练模型中低层卷积层的参数(权重)。
第三步:用可训练参数的多层替换分类层。
第四步:在训练集上训练分类层。
第五步:微调超参数,根据需要解冻更多层。
ResNet 网络结构图
二、显示图片功能
#1加载库
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import os
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
#2、定义一个方法:显示图片
def img_show(inp, title=None):
plt.figure(figsize=(14,3))
inp = inp.numpy().transpose((1,2,0)) #转成numpy,然后转置
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224,0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001)
plt.show()
def main():
pass
#3、定义超参数
BATCH_SIZE = 8
DEVICE = torch.device("gpu" if torch.cuda.is_available() else "cpu")
#4、图片转换 使用字典进行转换
data_transforms = {
'train': transforms.Compose([
transforms.Resize(300),
transforms.RandomResizedCrop(300) ,#随机裁剪
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(256),
transforms.ToTensor(), #转为张量
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]) #正则化
]),
'val': transforms.Compose([
transforms.Resize(300),
transforms.CenterCrop(256),
transforms.ToTensor(), #转为张量
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]) #正则化
])
}
#5、操作数据集
# 5.1、数据集路径
data_path = "D:/chest_xray/"
#5.2、加载数据集的train val
img_datasets = { x : datasets.ImageFolder(os.path.join(data_path,x),
data_transforms[x]) for x in ["train","val"]}
#5.3、为数据集创建一个迭代器,读取数据
dataloaders = {x : DataLoader(img_datasets[x], shuffle=True,
batch_size= BATCH_SIZE) for x in ["train","val"]
}
# 5.4、训练集和验证集的大小(图片的数量)
data_sizes = {x : len(img_datasets[x]) for x in ["train","val"]}
# 5.5、获取标签类别名称 NORMAL 正常 -- PNEUMONIA 感染
target_names = img_datasets['train'].classes
#6 显示一个batch_size 的图片(8张图片)
#6.1 读取8张图片
datas ,targets = next(iter(dataloaders['train'])) #iter把对象变为可迭代对象,next去迭代
#6.2、将若干正图片平成一副图像
out = make_grid(dat