PyTorch模型转TorchScript实战:一步步教你完成转换

最近入坑了PyTorch,在学习PyTorch Mobile的安卓部分。要想将训练好模型迁移到手机上使用,需要将模型转化为TorchScript,它是PyTorch模型(子类nn.Module)的中间表示,可以在高性能环境(例如C ++)中运行。

转换的方法有两种,一种是通过追踪转换另一中是通过注释转换,本文使用的是通过追踪转换的方法。


import torch
import torchvision
import torch.nn as nn

# 加载模型(根据自己模型修改结构和参数)
model_ft = torchvision.models.mobilenet_v2()
num_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(num_ftrs, 2, bias=True)
model_ft.load_state_dict(torch.load('/home/well/0.94118mobilenet.pt'))
model_ft.eval()

# 给模型的forward()方法一个示例输入
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model_ft, example)
# 保存模型
traced_script_module.save("/home/well/model.pt")

print("Finished Transformation")


RuntimeError: Error(s) in loading state_dict for MobileNetV2:

Missing key(s) in state_dict: “features.0.0.weight”, … (这里表示在state_dict中找不到这些参数)

Unexpected key(s) in state_dict: “module.features.0.0.weight”, …(预料外的参数)

从上面的错误信息中可以发现:这话是参数不匹配的问题,我们传入模型的参数格式是module.features.XXX而模型需要的参数格式是features.XXX。

经过一番查找研究后发现:原因是在多GPU训练的时候,nn.DataParallel(model)对模型进行了包装,所以使用model.state_dict()保存模型,保存的参数格式是module.features.XXX;而我们在转换的过程中加载模型参数的格式的features.XXX,所以报错。


解决方法:在训练的时候使用

model.module.state_dict()

代替原来的

model.state_dict()

保存模型,这样我们就可以将PyTorch模型转换为TorchScript格式。

得到的TorchScript模型:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值