import torch
import torchvision
import torch.nn as nn
#model = torchvision.models.resnet50(pretrained=True)
model=torchvision.models.resnet18() # 加载模型
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 196) # make the change
model.load_state_dict(torch.load("resnet18.pth"))
model.eval() # 注意,将模型设置为eval模式,再保存
example = torch.rand(1, 3, 224, 224) # 输入模型的尺寸
traced_script_module = torch.jit.trace(model, example)
# 保存模型
traced_script_module.save("resnet18.pt")
Pytorch 模型转化为torchscript 模型——用于C++部署
最新推荐文章于 2024-06-11 21:01:17 发布