import torch
from torchvision.models.mobilenetv2 import mobilenet_v2
def export_torchscript(
    checkpoint_path="net.pth",
    output_path="net.torchscript",
    input_shape=(1, 3, 224, 224),
    model=None,
):
    """Load trained weights into a model, trace it, and save it as TorchScript.

    Args:
        checkpoint_path: Checkpoint file whose contents are passed directly to
            ``model.load_state_dict``.
            NOTE(review): assumed to be a plain state_dict, not a pickled
            whole model or a dict wrapping one. Confirm against the training
            script.
        output_path: Where the traced TorchScript module is written.
        input_shape: Shape of the random example input used for tracing.
            The default ``(1, 3, 224, 224)`` matches the original script.
        model: Module to load the weights into. Defaults to a freshly
            constructed torchvision MobileNetV2 without pretrained weights.

    Returns:
        The traced module that was saved to ``output_path``.
    """
    if model is None:
        # A no-argument call means "no pretrained weights" on both the old
        # (``pretrained=False``) and new (``weights=None``) torchvision APIs.
        # It also avoids the deprecation warning ``pretrained=`` emits on
        # torchvision >= 0.13.
        model = mobilenet_v2()

    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # machine. The weights are copied into the CPU model either way, and the
    # example input below is a CPU tensor.
    # NOTE: torch.load unpickles its input, so only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Trace in inference mode so layers such as BatchNorm/Dropout use their
    # eval-time behaviour.
    model.eval()

    # Tracing records the operations executed on this example input. Control
    # flow that depends on input values is not captured.
    example_input = torch.randn(*input_shape)
    traced = torch.jit.trace(model, example_input)
    traced.save(output_path)
    return traced


if __name__ == "__main__":
    export_torchscript()
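# pt -> TorchScript: convert a PyTorch .pth checkpoint into a TorchScript module.
# (Source page note: latest recommended article published 2024-04-23 16:01:20.)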