【深度之眼】【Pytorch打卡第19天】:图像分类Resnet的Inference

模型是如何将图像分类的?

3-d 张量→字符串

  1. 类别名与标签的转换: label_name = {"ants": 0, "bees": 1}
  2. 取输出向量最大值的标号 : _, predicted = torch.max(outputs.data, 1)
  3. 复杂运算: outptus = resnet18(img_tensor)

可以看出,计算机读入了一个三维张量,通过模型的复杂运算,得到了一个输出向量,模型仅仅只干了这么一件事情,而后面得到标签以及类别名与标签的转换不属于模型的范畴,是我们人为的去理解这个输出向量的物理意义从而制定的转换规则

图像分类的Inference(推理)

  • 步骤
  1. 获取数据与标签
  2. 选择模型,损失函数,优化器
  3. 写训练代码
  4. 写inference代码

Inference代码基本步骤

  1. 获取数据与模型
  2. 数据变换,如RGB→4D-Tensor
  3. 前向传播
  4. 输出保存预测结果

Inference阶段注意事项

  1. 确保 model处于eval状态而非training
  2. 设置torch.no_grad(),减少内存消耗
  3. 数据预处理需保持一致,RGB o rBGR?

Resnet18模型Inference代码

模型将数据从图片转换到向量的代码可以分为如下几个步骤:

核心代码

            # step 1/4 : path --> imgRGB
            img_rgb = Image.open(path_img).convert('RGB')

            # step 2/4 : img --> tensor
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor.unsqueeze_(0) #变成4D
            img_tensor = img_tensor.to(device) #张量没有inplace,所以要重新赋值

            # step 3/4 : tensor --> vector,前向传播
            time_tic = time.time()
            outputs = resnet18(img_tensor)
            time_toc = time.time()

            # step 4/4 : visualization,预测结果
            _, pred_int = torch.max(outputs.data, 1) # output = torch.max(input, dim)
            pred_str = classes[int(pred_int)]

notice:
torch.max(input, dim) 函数

  • 输入
    input是softmax函数输出的一个tensor
    dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值
  • 输出
    函数会返回两个tensor,第一个tensor是每行的最大值第二个tensor是每行最大值的索引

全部代码

# -*- coding: utf-8 -*-
"""
# @brief      : inference demo
"""

import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# config
vis = True
# vis = False
vis_row = 4

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

inference_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值