基于Pytorch肺部感染识别案例(采用ResNet网络结构)

本文介绍了基于Pytorch利用ResNet网络结构进行肺部感染识别的案例,详细阐述了从数据集下载到模型训练的整体流程,包括加载预训练模型、冻结低层卷积层、替换分类层、训练和微调超参数。同时,展示了如何在Jupyter Notebook中显示数据集图片。
摘要由CSDN通过智能技术生成

一、整体流程

1. 数据集下载地址:https://www.kaggle.com/paultimothymooney/chest-xray-pneumonia/download

2. 数据集展示

案例主要流程:

第一步:加载预训练模型ResNet,该模型已在ImageNet上训练过。

第二步:冻结预训练模型中低层卷积层的参数(权重)。

第三步:用可训练参数的多层替换分类层。

第四步:在训练集上训练分类层。

第五步:微调超参数,根据需要解冻更多层。

ResNet 网络结构图

二、显示图片功能

#1加载库
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import os
from torchvision.utils import make_grid

from torch.utils.data import DataLoader
#2、定义一个方法:显示图片
def img_show(inp, title=None):
    plt.figure(figsize=(14,3))
    inp = inp.numpy().transpose((1,2,0)) #转成numpy,然后转置
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224,0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)
    plt.show()
def main():
    pass
    #3、定义超参数
    BATCH_SIZE = 8
    DEVICE = torch.device("gpu" if torch.cuda.is_available() else "cpu")

    #4、图片转换    使用字典进行转换
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(300),
            transforms.RandomResizedCrop(300) ,#随机裁剪
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(256),
            transforms.ToTensor(), #转为张量
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]) #正则化

        ]),
        'val': transforms.Compose([
            transforms.Resize(300),
            transforms.CenterCrop(256),
            transforms.ToTensor(), #转为张量
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]) #正则化
        ])
    }

    #5、操作数据集
    # 5.1、数据集路径
    data_path = "D:/chest_xray/"
    #5.2、加载数据集的train val
    img_datasets = { x : datasets.ImageFolder(os.path.join(data_path,x),
                        data_transforms[x]) for x in ["train","val"]}
    #5.3、为数据集创建一个迭代器,读取数据
    dataloaders = {x : DataLoader(img_datasets[x], shuffle=True,
                        batch_size= BATCH_SIZE) for x in ["train","val"]
                                  }
    # 5.4、训练集和验证集的大小(图片的数量)
    data_sizes = {x : len(img_datasets[x]) for x in ["train","val"]}
    # 5.5、获取标签类别名称 NORMAL 正常 -- PNEUMONIA 感染
    target_names = img_datasets['train'].classes
    #6 显示一个batch_size 的图片(8张图片)
    #6.1 读取8张图片
    datas ,targets = next(iter(dataloaders['train'])) #iter把对象变为可迭代对象,next去迭代
    #6.2、将若干正图片平成一副图像
    out = make_grid(dat
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值