将训练好的模型参数保存为.pth格式,为了部署以及查看需要将其转为.onnx格式,转换的代码如下
以googlenet为例,定义的函数中,第一个参数传入.pth保存的路径,第二个参数传入.onnx要保存的路径
'''
author:long
date:2024/5/22
'''
import torch.onnx
from models.googlenet import googleNet
import torch
import torch.onnx
# 定义所需模型版本
def convert_goodlenet_pth_to_onnx(pth_path, onnx_path):
# 加载预训练的模型权重
googlenet = googleNet()
googlenet.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
# 将模型设置为评估模式
googlenet.eval()
# 定义一个适合作为模型输入的张量,这里假设输入为单个 RGB 图像,大小为 224x224(训练部分设定的是32x32,与训练图像大小一致)
dummy_input = torch.randn(1, 3, 32, 32)
# 将模型导出为 ONNX 格式
torch.onnx.export(googlenet,
dummy_input,
onnx_path,
opset_version=11, # 使用合适的 ONNX OpSet 版本
input_names=['input'], # 输入张量名称
output_names=['output'], # 输出张量名称
dynamic_axes={'input': {0: 'batch_size'}, # 批次维度可变
'output': {0: 'batch_size'}}) # 输出批次维度可变
# 使用函数进行转换
convert_goodlenet_pth_to_onnx(
'E:/pytorch-cifar100-master/checkpoint/googlenet/Tuesday_02_April_2024_09h_17m_49s/googlenet-1-iterative.pth',
'E:/pytorch-cifar100-master/onnx/model_onnx/googlenet_model.onnx')