前言
本文主要通过实战的方式，记录几种常见的模型推理方法：常规 PyTorch 模型、TorchScript 以及 ONNX。
模型训练
首先，使用 PyTorch 在 FashionMNIST 数据集上训练一个最简单的十分类神经网络，代码如下：
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# Load the FashionMNIST training and test splits (downloaded on first use).
DATA_ROOT = "./Datasets/"
training_data = datasets.FashionMNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loaders.
batch_size = 16
# Shuffle the training split so every epoch visits the batches in a new
# order; evaluation order does not affect the metrics, so the test loader
# stays sequential.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
class NeuralNetwork(nn.Module):
    """Fully connected 10-class classifier.

    ``forward`` flattens every dimension after the batch dimension (the result
    must have 28*28 = 784 features) and returns raw, unnormalized logits --
    suitable for ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        hidden_units = 512
        # Attribute names and layer order define the state_dict keys
        # (linear_relu_stack.0/.2/.4); keep them stable so saved weights load.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 10),
        )

    def forward(self, x):
        return self.linear_relu_stack(self.flatten(x))
# Build the classifier directly on the selected device (the torch.optim docs
# recommend moving a model before constructing its optimizer).
model = NeuralNetwork().to(device)

# Cross-entropy over raw logits, minimized with plain SGD at lr = 0.001.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch of ``model`` over ``dataloader``.

    Args:
        dataloader: yields ``(X, y)`` batches; ``len(dataloader.dataset)`` is
            used as the total sample count in the progress messages.
        model: network to optimize; it is switched to training mode here.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.

    Every 100 batches, prints the current batch loss and how many samples
    have been processed so far (including the current batch).
    """
    size = len(dataloader.dataset)
    # Send batches to wherever the model's parameters live, instead of relying
    # on a module-level ``device`` variable.
    device = next(model.parameters()).device
    model.train()
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count this batch before reporting, so the first message reads e.g.
        # "16/60000" rather than "0/60000" (also exact for a short last batch).
        seen += len(X)
        if batch % 100 == 0:
            loss_value = loss.item()
            print(f"loss: {loss_value:>7f} [{seen:>5d}/{size:>5d}]")
# 定义验证方法(在验证数据集中进行验证)
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
# Main loop: one pass over the training data, then one evaluation pass, per epoch.
epochs = 100
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
模型推理
常规 PyTorch 模型
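以下是 PyTorch 官方入门文档给出的模型持久化及加载方法：使用 torch.save() 保存模型的参数字典（state_dict）。注意，这里保存的只是参数，并不包含模型结构；加载时需要先在代码中构造同样的 NeuralNetwork，再通过 load_state_dict() 载入参数，推理以 Python 动态图（eager）方式执行。如下：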
# (Continues from the training code above: needs `model`, `NeuralNetwork`,
# `torch` and `test_data` to already be defined.)
import os

# Save the trained weights. Only the state_dict is stored, not the
# architecture, so loading requires constructing NeuralNetwork() again.
model_path = "./model"
os.makedirs(model_path, exist_ok=True)
model_file = os.path.join(model_path, 'model.pth')
torch.save(model.state_dict(), model_file)
print("Saved PyTorch Model State to model.pth")

# Load the weights into a fresh network and run inference.
model = NeuralNetwork()
# map_location="cpu": the weights may have been saved from a CUDA model, while
# this freshly built network is on the CPU.
# weights_only=True: a state_dict holds only tensors, so refuse to unpickle
# arbitrary objects.
model.load_state_dict(torch.load(model_file, map_location="cpu", weights_only=True))
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0).item()], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
TorchScript
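TorchScript 是一种从 PyTorch 代码创建可序列化、可优化模型的方法，得到的是静态图模型。TorchScript 模型可以在 Python 进程中保存，并加载到不依赖 Python 的进程中（例如 C++ 环境）。下面使用 torch.jit.trace() 生成 TorchScript 模型。需要注意：trace 只记录示例输入实际执行过的算子，模型中依赖输入数据的控制流（如 if/for 分支）不会被记录，这种情况应改用 torch.jit.script()。使用方法如下：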
# --- Save the model as TorchScript ---
# Trace in eval mode so the recorded graph uses inference behaviour.
model.eval()
# Tracing runs the model once on an example input and records the executed
# ops; the dummy matches one 1x28x28 FashionMNIST image with a batch dim.
dummy_input = torch.rand(1, 1, 28, 28)
# Generate the IR
with torch.no_grad():
    jit_model = torch.jit.trace(model, dummy_input)
# Serialize the traced model
jit_model.save('./model/jit_model.pt')
# --- Load the TorchScript model and run inference ---
import time

# Load the serialized model
jit_model = torch.jit.load('./model/jit_model.pt')
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    # The first call(s) of a loaded TorchScript module can include one-off
    # optimization overhead, so warm up once before timing.
    jit_model(x)
    start_time = time.perf_counter()
    pred = jit_model(x)
    print(f"spend time: {time.perf_counter() - start_time}")
print(pred[0].argmax(0))
ONNX
开放神经网络交换（Open Neural Network Exchange，简称 ONNX）是微软和 Facebook 提出的一种用于表示深度学习模型的开放格式。所谓“开放”，是指 ONNX 定义了一组与环境、平台均无关的标准格式，以增强各种 AI 模型之间的互操作性。
目前，工业上使用 ONNX 的意义，往往是把它作为模型适配推理引擎的中间格式，例如先把 PyTorch 模型转为 ONNX，再用 TensorRT 完成部署。
以下代码演示如何得到一个超分辨率（SRCNN）ONNX 模型，并分别用 PyTorch 和 ONNX Runtime 进行推理：
'''
得到一个超分辨率模型,由常规Pytorch模型转化为Onnx模型
'''
import os
import cv2
import numpy as np
import requests
import torch
import torch.onnx
from torch import nn
class SuperResolutionNet(nn.Module):
    """SRCNN-style super-resolution network.

    The input is first enlarged by ``upscale_factor`` with bicubic
    interpolation. Three convolutions (9x9 -> 1x1 -> 5x5, ReLU after the first
    two) then refine it. Their padding preserves the spatial size, so the
    output H and W are ``upscale_factor`` times those of the input.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=upscale_factor, mode='bicubic', align_corners=False)
        # Names conv1/conv2/conv3 must match the keys of the loaded state_dict.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        upsampled = self.img_upsampler(x)
        features = self.relu(self.conv1(upsampled))
        features = self.relu(self.conv2(features))
        return self.conv3(features)
# Download the checkpoint and the test image (skipped when already on disk).
urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth',
        'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png']
names = ['srcnn.pth', 'face.png']
for url, name in zip(urls, names):
    if os.path.exists(name):
        continue
    response = requests.get(url, timeout=60)
    # Fail loudly on HTTP errors. Otherwise an error page would be written to
    # disk and, because the file then exists, never downloaded again.
    response.raise_for_status()
    with open(name, 'wb') as f:
        f.write(response.content)
def init_torch_model(checkpoint_path='srcnn.pth', upscale_factor=3):
    """Build a SuperResolutionNet in eval mode and load SRCNN weights into it.

    Args:
        checkpoint_path: checkpoint file whose ``'state_dict'`` entry holds the
            weights. Defaults to the downloaded ``srcnn.pth``.
        upscale_factor: upsampling factor passed to SuperResolutionNet.

    Returns:
        The model with weights loaded, switched to eval mode.
    """
    torch_model = SuperResolutionNet(upscale_factor=upscale_factor)
    # NOTE(review): torch.load unpickles the file, and this checkpoint is
    # downloaded from the internet; only load checkpoints from trusted sources.
    # map_location="cpu" lets weights saved from a GPU load on CPU-only hosts.
    state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
    # Adapt the checkpoint: drop the first dotted component of every key
    # (e.g. "<prefix>.conv1.weight" -> "conv1.weight") to match this model.
    state_dict = {key.partition('.')[2]: value for key, value in state_dict.items()}
    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    return torch_model
model = init_torch_model()
input_img = cv2.imread('face.png')
# cv2.imread signals failure by returning None instead of raising.
if input_img is None:
    raise FileNotFoundError("Could not read image 'face.png'")
input_img = input_img.astype(np.float32)
# HWC to NCHW: move channels first, then add a batch dimension.
input_img = np.transpose(input_img, [2, 0, 1])
input_img = np.expand_dims(input_img, 0)

# Inference (no autograd graph is needed)
with torch.no_grad():
    torch_output = model(torch.from_numpy(input_img)).numpy()

# NCHW to HWC, clamp to the 0..255 range and convert to uint8 for saving.
torch_output = np.squeeze(torch_output, 0)
torch_output = np.clip(torch_output, 0, 255)
torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)
# Save the result image
cv2.imwrite("face_torch.png", torch_output)
# Convert to an ONNX model file with torch.onnx.export(). Export traces the
# model on a sample input. Batch, height and width are declared dynamic so the
# graph is not locked to this 256x256 dummy and accepts the real test image.
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    torch.onnx.export(
        model,
        x,
        "srcnn.onnx",
        opset_version=11,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch', 2: 'height', 3: 'width'},
            'output': {0: 'batch', 2: 'height', 3: 'width'},
        },
    )
# Validate the exported ONNX model file
import onnx

onnx_model = onnx.load("srcnn.onnx")
try:
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as err:
    # Report why validation failed instead of discarding the reason.
    print(f"Model incorrect: {err}")
else:
    print("Model correct")
# Load the ONNX model and run it with ONNX Runtime
import onnxruntime

# Pass `providers` explicitly: since onnxruntime 1.9, builds that ship
# non-CPU execution providers (e.g. onnxruntime-gpu) require it.
ort_session = onnxruntime.InferenceSession(
    "srcnn.onnx", providers=["CPUExecutionProvider"])
ort_inputs = {'input': input_img}
ort_output = ort_session.run(['output'], ort_inputs)[0]
# NCHW to HWC, clamp to the 0..255 range and convert to uint8 for saving.
ort_output = np.squeeze(ort_output, 0)
ort_output = np.clip(ort_output, 0, 255)
ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
cv2.imwrite("face_ort.png", ort_output)
参考文档
Save and Load the Model — PyTorch Tutorials 2.1.1+cu121 documentation