PyTorch学习笔记-8.PyTorch深度体验

8.PyTorch深度体验

8.1.图像分类预测

模型如何完成图像分类?

将图像转换为tensor --> 模型 --> 输出向量 --> 取向量的最大值作为预测结果

 

代码基本步骤:
1. 获取数据与模型
2. 数据变换,如RGB → 4D-Tensor
3. 前向传播
4. 输出保存预测结果

注意事项:
1. 确保 model处于eval状态而非training
2. 设置torch.no_grad(),减少内存消耗
3. 数据预处理需保持一致, RGB or BGR

 

代码实现:

以模型微调中的分类模型为例,进行预测

# -*- coding: utf-8 -*-
import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# 配置可视化开关
vis = True
# vis = False
# 设置可视化时每行有几张图片
vis_row = 4

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# 对数据预处理,注意和模型的预处理方式保持一致
inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 标签的分类
classes = ["ants""bees"]

# 将图片进行预处理得到张量,即数据转换为模型读取的形式
def img_transform(img_rgb, transform=None):

    if transform is None:
        raise ValueError("找不到transform!必须有transform对img进行处理")

    img_t = transform(img_rgb)
    return img_t

# 获取文件夹下format格式的文件名
def get_img_name(img_dir, format="jpg"):

    file_names = os.listdir(img_dir)
    img_names = list(filter(lambda x: x.endswith(format), file_names))

    if len(img_names) < 1:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names


def get_model(m_path, vis_model=False):

    # 创建resnet18模型
    resnet18 = models.resnet18()
    # 获取全连接层的输入参数
    num_ftrs = resnet18.fc.in_features
    # 自定义全连接层,设置输入为2,即二分类
    resnet18.fc = nn.Linear(num_ftrs, 2)

    # 根据模型路径加载模型,为检查点
    checkpoint = torch.load(m_path)
    # 将参数加载到模型中
    resnet18.load_state_dict(checkpoint['model_state_dict'])

    # 打印模型的信息,如每一层的类型、shape 和 参数量等
    if vis_model:
        # 需要导入torchsummary包,安装:pip install torchsummary
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")

    return resnet18


if __name__ == "__main__":
    # 指定数据路径
    img_dir = os.path.join("..""..""data/hymenoptera_data/val/ants")
    # 指定模型路径,为蚂蚁蜜蜂之前保存的检查点
    model_path = "./checkpoint_24_epoch.pkl"
    # 用来统计总预测时间
    time_total = 0
    # 定义两个list,分别用来存放图片和对应的预测类型
    img_list, img_pred = list(), list()

    # 1. data
    # 获取指定路径下所有文件名
    img_names = get_img_name(img_dir)
    # 获取文件的数量
    num_img = len(img_names)

    # 2. model
    # 获取模型
    resnet18 = get_model(model_path, True)
    # 将模型放到指定设备上
    resnet18.to(device)
    # 将模型设置为测试模式
    resnet18.eval()

    # 下面的所有运算无需保存梯度
    with torch.no_grad():
        # 遍历文件名
        for idx, img_name in enumerate(img_names):

            # 拼接每个文件的全路径
            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img 根据路径读取rgb图片
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor 将rgb图像转换为张量
            img_tensor = img_transform(img_rgb, inference_transform)
            # 3d --> 4d
            img_tensor.unsqueeze_(0)
            # 将数据放到指定设备上
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            # 统计运行时间
            time_tic = time.time()

pytorch-09.ipynb是一个使用PyTorch库进行深度学习实践的笔记本文件。PyTorch是一个基于Python的深度学习框架,它提供了方便简洁的API接口,使得深度学习模型的构建和训练变得更加容易。 在这个笔记本文件中,我推测可能包括以下内容: 1. 张量的基本概念和操作:张量是PyTorch中最基本的数据类型,类似于Numpy中的多维数组。这个笔记本可能会介绍如何创建和操作张量,以及张量在深度学习中的应用。 2. 自动梯度计算:PyTorch通过自动梯度计算(Autograd)模块实现了计算图和反向传播。这个笔记本可能会介绍如何使用PyTorch的autograd模块来计算张量的导数,并利用导数进行模型参数的更新。 3. 模型构建和训练:深度学习模型的构建和训练是PyTorch的核心功能。这个笔记本可能会介绍如何使用PyTorch构建各种类型的神经网络模型(如全连接网络、卷积神经网络和循环神经网络)并进行训练。 4. 数据加载和预处理:在深度学习中,数据的加载和预处理是非常重要的一步。这个笔记本可能会介绍如何使用PyTorch的数据加载器和数据转换工具进行数据的加载和处理。 5. 模型性能评估和调优:在实际应用中,评估模型性能和进行调优是不可或缺的步骤。这个笔记本可能会介绍如何使用PyTorch进行模型性能的评估,并介绍一些常见的调优方法,如学习率调整、正则化和dropout等。 总之,这个笔记本文件可能会提供一些关于PyTorch库的基本操作和深度学习模型构建的实践指南,帮助读者更好地理解和应用PyTorch进行深度学习任务。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值