[PyTorch 学习笔记] 8.1 图像分类简述与 ResNet 源码分析

本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py

这篇文章主要介绍了 图像分类的 inference,其中会着重介绍 ResNet

模型概览

torchvision.model中,有很多封装好的模型。


可以分类 3 类:
  • 经典网络
    • alexnet
    • vgg
    • resnet
    • inception
    • densenet
    • googlenet
  • 轻量化网络
    • squeezenet
    • mobilenet
    • shufflenetv2
  • 自动神经结构搜索方法的网络
    • mnasnet

ResNet18 使用

ResNet 18 为例。

首先加载训练好的模型参数:

resnet18 = models.resnet18()

# 修改全连接层的输出
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 2)

# 加载模型参数
checkpoint = torch.load(m_path)
resnet18.load_state_dict(checkpoint['model_state_dict'])
然后比较重要的是把模型放到 GPU 上,并且转换到`eval`模式:
resnet18.to(device)
resnet18.eval()

在 inference 时,主要流程如下:

  • 代码要放在with torch.no_grad():下。torch.no_grad()会关闭反向传播,可以减少内存、加快速度。

  • 根据路径读取图片,把图片转换为 tensor,然后使用unsqueeze_(0)方法把形状扩大为 B × C × H × W B \times C \times H \times W B×C×H×W,再把 tensor 放到 GPU 上 。

  • 模型的输出数据outputs的形状是 1 × 2 1 \times 2 1×2,表示 batch_size 为 1,分类数量为 2。torch.max(outputs,0)是返回outputs每一列最大的元素和索引,torch.max(outputs,1)是返回outputs每一行最大的元素和索引。

    这里使用_, pred_int = torch.max(outputs.data, 1)返回最大元素的索引,然后根据索引获得 label:pred_str = classes[int(pred_int)]

关键代码如下:

    with torch.no_grad():
        for idx, img_name in enumerate(img_names):

            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor.unsqueeze_(0)
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            outputs = resnet18(img_tensor)

            # step 4/4 : get label
            _, pred_int = torch.max(outputs.data, 1)
            pred_str = classes[int(pred_int)]

全部代码如下所示:

import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
import enviroments
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# config
vis = True
# vis = False
vis_row = 4

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

classes = ["ants", "bees"]


def img_transform(img_rgb, transform=None):
    """
    将数据转换为模型读取的形式
    :param img_rgb: PIL Image
    :param transform: torchvision.transform
    :return: tensor
    """

    if transform is None:
        raise ValueError("找不到transform!必须有transform对img进行处理")

    img_t = transform(img_rgb)
    return img_t


def get_img_name(img_dir, format="jpg"):
    """
    获取文件夹下format格式的文件名
    :param img_dir: str
    :param format: str
    :return: list
    """
    file_names = os.listdir(img_dir)
    # 使用 list(filter(lambda())) 筛选出 jpg 后缀的文件
    img_names = list(filter(lambda x: x.endswith(format), file_names))

    if len(img_names) < 1:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names


def get_model(m_path, vis_model=False):

    resnet18 = models.resnet18()

    # 修改全连接层的输出
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, 2)

    # 加载模型参数
    checkpoint = torch.load(m_path)
    resnet18.load_state_dict(checkpoint['model_state_dict'])


    if vis_model:
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")

    return resnet18


if __name__ == "__main__":

    img_dir = os.path.join(enviroments.hymenoptera_data_dir,"val/bees")
    model_path = "./checkpoint_14_epoch.pkl"
    time_total = 0
    img_list, img_pred = list(), list()

    # 1. data
    img_names = get_img_name(img_dir)
    num_img = len(img_names)

    # 2. model
    resnet18 = get_model(model_path, True)
    resnet18.to(device)
    resnet18.eval()

    with torch.no_grad():
        for idx, img_name in enumerate(img_names):

            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor.unsqueeze_(0)
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            time_tic = time.time()
            outputs = resnet18(img_tensor)
            time_toc = time.time()

            # step 4/4 : visualization
            _, pred_int = torch.max(outputs.data, 1)
            pred_str &#
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值