目录
模型是如何将图像分类的?
3-d 张量→字符串
- 类别名与标签的转换:
label_name = {"ants": 0, "bees": 1}
- 取输出向量最大值的标号 :
_, predicted = torch.max(outputs.data, 1)
- 复杂运算:
outptus = resnet18(img_tensor)
可以看出,计算机读入了一个三维张量,通过模型的复杂运算,得到了一个输出向量,模型仅仅只干了这么一件事情,而后面得到标签以及类别名与标签的转换不属于模型的范畴,是我们人为的去理解这个输出向量的物理意义从而制定的转换规则
![](https://i-blog.csdnimg.cn/blog_migrate/254445d95e0fa9d9e3c559b5452946ae.png)
![](https://i-blog.csdnimg.cn/blog_migrate/b535ecc4dc08a8fc9540c8f5f84944d7.png)
![](https://i-blog.csdnimg.cn/blog_migrate/1c463e948b2b7bf186e99a2f30b96bfd.png)
图像分类的Inference(推理)
- 步骤
- 获取数据与标签
- 选择模型,损失函数,优化器
- 写训练代码
- 写inference代码
Inference代码基本步骤
- 获取数据与模型
- 数据变换,如RGB→4D-Tensor
- 前向传播
- 输出保存预测结果
Inference阶段注意事项
- 确保
model处于eval状态而非training
设置torch.no_grad()
,减少内存消耗数据预处理需保持一致
,RGB o rBGR?
Resnet18模型Inference代码
模型将数据从图片转换到向量的代码可以分为如下几个步骤:
核心代码
# step 1/4 : path --> imgRGB
img_rgb = Image.open(path_img).convert('RGB')
# step 2/4 : img --> tensor
img_tensor = img_transform(img_rgb, inference_transform)
img_tensor.unsqueeze_(0) #变成4D
img_tensor = img_tensor.to(device) #张量没有inplace,所以要重新赋值
# step 3/4 : tensor --> vector,前向传播
time_tic = time.time()
outputs = resnet18(img_tensor)
time_toc = time.time()
# step 4/4 : visualization,预测结果
_, pred_int = torch.max(outputs.data, 1) # output = torch.max(input, dim)
pred_str = classes[int(pred_int)]
notice:
torch.max(input, dim) 函数
- 输入
input
是softmax函数输出的一个tensor
dim
是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值
- 输出
函数会返回两个tensor,第一个tensor是每行的最大值
;第二个tensor是每行最大值的索引
。
全部代码
# -*- coding: utf-8 -*-
"""
# @brief : inference demo
"""
import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
# config
vis = True
# vis = False
vis_row = 4
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
inference_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop