PyTorch模型转换为TorchScript格式

最近入坑了PyTorch,在学习PyTorch Mobile的安卓部分。要想将训练好模型迁移到手机上使用,需要将模型转化为TorchScript,它是PyTorch模型(子类nn.Module)的中间表示,可以在高性能环境(例如C ++)中运行。

转换的方法有两种,一种是通过追踪转换另一中是通过注释转换,本文使用的是通过追踪转换的方法。


import torch
import torchvision
import torch.nn as nn

# 加载模型(根据自己模型修改结构和参数)
model_ft = torchvision.models.mobilenet_v2()
num_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(num_ftrs, 2, bias=True)
model_ft.load_state_dict(torch.load('/home/well/0.94118mobilenet.pt'))
model_ft.eval()

# 给模型的forward()方法一个示例输入
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model_ft, example)
# 保存模型
traced_script_module.save("/home/well/model.pt")

print("Finished Transformation")

错误信息
RuntimeError: Error(s) in loading state_dict for MobileNetV2:
Missing key(s) in state_dict: “features.0.0.weight”, … (这里表示在state_dict中找不到这些参数)
Unexpected key(s) in state_dict: “module.features.0.0.weight”, …(预料外的参数)

从上面的错误信息中可以发现:这话是参数不匹配的问题,我们传入模型的参数格式是module.features.XXX而模型需要的参数格式是features.XXX。

经过一番查找研究后发现:原因是在多GPU训练的时候,nn.DataParallel(model)对模型进行了包装,所以使用model.state_dict()保存模型,保存的参数格式是module.features.XXX;而我们在转换的过程中加载模型参数的格式的features.XXX,所以报错。


解决方法:在训练的时候使用model.module.state_dict() 代替原来的 model.state_dict() 保存模型,这样我们就可以将PyTorch模型转换为TorchScript格式。
完成转换
得到的TorchScript模型:
模型

PyTorch模型转换TorchScript的步骤如下: 1. 定义和训练PyTorch模型:首先,您需要定义和训练一个PyTorch模型,这可以通过使用标准的PyTorch代码来完成。 2. 导出PyTorch模型:然后,您需要将PyTorch模型导出为TorchScript格式。这可以通过使用PyTorchtorch.jit模块中的trace函数来完成。trace函数接受一个输入样本并生成一个TorchScript模块。 3. 运行TorchScript模型:一旦您导出了TorchScript模型,您可以像普通的Python模块一样使用它。您可以加载模块并使用其forward方法来运行模型。此外,TorchScript模块还可以与C++和Java等其他语言一起使用。 下面是一个简单的示例代码,演示了如何将PyTorch模型转换TorchScript: ``` import torch import torchvision # 定义和训练PyTorch模型 model = torchvision.models.resnet18() example_input = torch.rand(1, 3, 224, 224) traced_script_module = torch.jit.trace(model, example_input) # 导出PyTorch模型TorchScript模块 traced_script_module.save("resnet18.pt") # 加载TorchScript模块并运行模型 loaded_script_module = torch.jit.load("resnet18.pt") output = loaded_script_module(example_input) print(output) ``` 在这个示例代码中,我们首先定义了一个ResNet-18模型,并使用一个随机的输入样本来跟踪模型。然后,我们将跟踪后的模型保存为TorchScript格式。最后,我们加载了TorchScript模块并使用它来运行模型
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值