本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py
这篇文章首先会简单介绍一下 PyTorch
中提供的图像分类的网络,然后重点介绍 ResNet
的使用,以及 ResNet
的源码。
模型概览
在torchvision.model
中,有很多封装好的模型。
可以分类 3 类:
- 经典网络
- alexnet
- vgg
- resnet
- inception
- densenet
- googlenet
- 轻量化网络
- squeezenet
- mobilenet
- shufflenetv2
- 自动神经结构搜索方法的网络
- mnasnet
ResNet18 使用
以 ResNet 18
为例。
首先加载训练好的模型参数:
resnet18 = models.resnet18()
# 修改全连接层的输出
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 2)
# 加载模型参数
checkpoint = torch.load(m_path)
resnet18.load_state_dict(checkpoint['model_state_dict'])
然后比较重要的是把模型放到 GPU 上,并且转换到`eval`模式:
resnet18.to(device)
resnet18.eval()
在 inference 时,主要流程如下:
-
代码要放在
with torch.no_grad():
下。torch.no_grad()
会关闭反向传播,可以减少内存、加快速度。 -
根据路径读取图片,把图片转换为 tensor,然后使用
unsqueeze_(0)
方法把形状扩大为 B × C × H × W B \times C \times H \times W B×C×H×W,再把 tensor 放到 GPU 上 。 -
模型的输出数据
outputs
的形状是 1 × 2 1 \times 2 1×2,表示batch_size
为 1,分类数量为 2。torch.max(outputs,0)
是返回outputs
中每一列最大的元素和索引,torch.max(outputs,1)
是返回outputs
中每一行最大的元素和索引。这里使用
_, pred_int = torch.max(outputs.data, 1)
返回最大元素的索引,然后根据索引获得 label:pred_str = classes[int(pred_int)]
。
关键代码如下:
with torch.no_grad():
for idx, img_name in enumerate(img_names):
path_img = os.path.join(img_dir, img_name)
# step 1/4 : path --> img
img_rgb = Image.open(path_img).convert('RGB')
# step 2/4 : img --> tensor
img_tensor = img_transform(img_rgb, inference_transform)
img_tensor.unsqueeze_(0)
img_tensor = img_tensor.to(device)
# step 3/4 : tensor --> vector
outputs = resnet18(img_tensor)
# step 4/4 : get label
_, pred_int = torch.max(outputs.data, 1)
pred_str = classes[int(pred_int)]
全部代码如下所示:
import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
import enviroments
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
# config
vis = True
# vis = False
vis_row = 4
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
inference_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
classes = ["ants", "bees"]
def img_transform(img_rgb, transform=None):
"""
将数据转换为模型读取的形式
:param img_rgb: PIL Image
:param transform: torchvision.transform
:return: tensor
"""
if transform is None:
raise ValueError("找不到transform!必须有transform对img进行处理")
img_t = transform(img_rgb)
return img_t
def get_img_name(img_dir, format="jpg"):
"""
获取文件夹下format格式的文件名
:param img_dir: str
:param format: str
:return: list
"""
file_names = os.listdir(img_dir)
# 使用 list(filter(lambda())) 筛选出 jpg 后缀的文件
img_names = list(filter(lambda x: x.endswith(format), file_names))
if len(img_names) < 1:
raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
return img_names
def get_model(m_path, vis_model=False):
resnet18 = models.resnet18()
# 修改全连接层的输出
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 2)
# 加载模型参数
checkpoint = torch.load(m_path)
resnet18.load_state_dict(checkpoint['model_state_dict'])
if vis_model:
from torchsummary import summary
summary(resnet18, input_size=(3, 224, 224), device="cpu")
return resnet18
if __name__ == "__main__":
img_dir = os.path.join(enviroments.hymenoptera_data_dir,"val/bees")
model_path = "./checkpoint_14_epoch.pkl"
time_total = 0
img_list, img_pred = list(), list()
# 1. data
img_names = get_img_name(img_dir)
num_img = len(img_names)
# 2. model
resnet18 = get_model(model_path, True)
resnet18.to(device)
resnet18.eval()
with torch.no_grad():
for idx, img_name in enumerate(img_names):
path_img = os.path.join(img_dir, img_name)
# step 1/4 : path --> img
img_rgb = Image.open(path_img).convert('RGB')
# step 2/4 : img --> tensor
img_tensor = img_transform(img_rgb, inference_transform)
img_tensor.unsqueeze_(0)
img_tensor = img_tensor.to(device)
# step 3/4 : tensor --> vector
time_tic = time.time()
outputs = resnet18(img_tensor)
time_toc = time.time()
# step 4/4 : visualization