8.PyTorch深度体验
8.1.图像分类预测
模型如何完成图像分类?
将图像转换为tensor --> 模型 --> 输出向量 --> 取向量的最大值作为预测结果
代码基本步骤:
1. 获取数据与模型
2. 数据变换,如RGB → 4D-Tensor
3. 前向传播
4. 输出保存预测结果
注意事项:
1. 确保 model处于eval状态而非training
2. 设置torch.no_grad(),减少内存消耗
3. 数据预处理需保持一致, RGB or BGR
代码实现:
以模型微调中的分类模型为例,进行预测
# -*- coding: utf-8 -*-
import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
# 配置可视化开关
vis = True
# vis = False
# 设置可视化时每行有几张图片
vis_row = 4
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
# 对数据预处理,注意和模型的预处理方式保持一致
inference_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
# 标签的分类
classes = ["ants", "bees"]
# 将图片进行预处理得到张量,即数据转换为模型读取的形式
def img_transform(img_rgb, transform=None):
if transform is None:
raise ValueError("找不到transform!必须有transform对img进行处理")
img_t = transform(img_rgb)
return img_t
# 获取文件夹下format格式的文件名
def get_img_name(img_dir, format="jpg"):
file_names = os.listdir(img_dir)
img_names = list(filter(lambda x: x.endswith(format), file_names))
if len(img_names) < 1:
raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
return img_names
def get_model(m_path, vis_model=False):
# 创建resnet18模型
resnet18 = models.resnet18()
# 获取全连接层的输入参数
num_ftrs = resnet18.fc.in_features
# 自定义全连接层,设置输入为2,即二分类
resnet18.fc = nn.Linear(num_ftrs, 2)
# 根据模型路径加载模型,为检查点
checkpoint = torch.load(m_path)
# 将参数加载到模型中
resnet18.load_state_dict(checkpoint['model_state_dict'])
# 打印模型的信息,如每一层的类型、shape 和 参数量等
if vis_model:
# 需要导入torchsummary包,安装:pip install torchsummary
from torchsummary import summary
summary(resnet18, input_size=(3, 224, 224), device="cpu")
return resnet18
if __name__ == "__main__":
# 指定数据路径
img_dir = os.path.join("..", "..", "data/hymenoptera_data/val/ants")
# 指定模型路径,为蚂蚁蜜蜂之前保存的检查点
model_path = "./checkpoint_24_epoch.pkl"
# 用来统计总预测时间
time_total = 0
# 定义两个list,分别用来存放图片和对应的预测类型
img_list, img_pred = list(), list()
# 1. data
# 获取指定路径下所有文件名
img_names = get_img_name(img_dir)
# 获取文件的数量
num_img = len(img_names)
# 2. model
# 获取模型
resnet18 = get_model(model_path, True)
# 将模型放到指定设备上
resnet18.to(device)
# 将模型设置为测试模式
resnet18.eval()
# 下面的所有运算无需保存梯度
with torch.no_grad():
# 遍历文件名
for idx, img_name in enumerate(img_names):
# 拼接每个文件的全路径
path_img = os.path.join(img_dir, img_name)
# step 1/4 : path --> img 根据路径读取rgb图片
img_rgb = Image.open(path_img).convert('RGB')
# step 2/4 : img --> tensor 将rgb图像转换为张量
img_tensor = img_transform(img_rgb, inference_transform)
# 3d --> 4d
img_tensor.unsqueeze_(0)
# 将数据放到指定设备上
img_tensor = img_tensor.to(device)
# step 3/4 : tensor --> vector
# 统计运行时间
time_tic = time.time()