在pytorch平台上,训练好模型。在推断时,推断文件路径发生了改变,结果报模型无法识别。
net=torch.load('XXX.pt')
错误如下
return _load(f,map_location,pickle_module,**pickle_load_args)
result = unpickler.load()
ModuleNotFoundError: No module named 'models'
根据网上资料:重新生成了’.tjm’可以解决我的问题,具体如下:
def convert_model(model, input=torch.tensor(torch.rand(size=(1,3,112,112)))):
model = torch.jit.trace(model, input)
torch.jit.save(model,'XXX/model.tjm')
这里input根据自己的输入做相应的修改
模型载入方式如下:
model = torch.jit.load('XXX/model.tjm')
此外,提供了其它方法。不过这里未做验证。
具体可以参考:
添加链接描述