.pth文件转为.onnx格式

将训练好的模型参数保存为.pth格式,为了部署以及查看需要将其转为.onnx格式,转换的代码如下

以googlenet为例,定义的函数中,第一个参数传入.pth保存的路径,第二个参数传入.onnx要保存的路径

'''
author:long
date:2024/5/22
'''
import torch.onnx
from models.googlenet import googleNet
import torch
import torch.onnx


# 定义所需模型版本
def convert_goodlenet_pth_to_onnx(pth_path, onnx_path):
    # 加载预训练的模型权重
    googlenet = googleNet()
    googlenet.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))

    # 将模型设置为评估模式
    googlenet.eval()

    # 定义一个适合作为模型输入的张量,这里假设输入为单个 RGB 图像,大小为 224x224(训练部分设定的是32x32,与训练图像大小一致)
    dummy_input = torch.randn(1, 3, 32, 32)

    # 将模型导出为 ONNX 格式
    torch.onnx.export(googlenet,
                      dummy_input,
                      onnx_path,
                      opset_version=11,  # 使用合适的 ONNX OpSet 版本
                      input_names=['input'],  # 输入张量名称
                      output_names=['output'],  # 输出张量名称
                      dynamic_axes={'input': {0: 'batch_size'},  # 批次维度可变
                                    'output': {0: 'batch_size'}})  # 输出批次维度可变


# 使用函数进行转换
convert_goodlenet_pth_to_onnx(
    'E:/pytorch-cifar100-master/checkpoint/googlenet/Tuesday_02_April_2024_09h_17m_49s/googlenet-1-iterative.pth',
    'E:/pytorch-cifar100-master/onnx/model_onnx/googlenet_model.onnx')

  • 4
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值