torch.save(model,“模型名字”)
torch.load 是需要依赖源码的,pickle.load 调用 就会找模型定义的类,就会根据路径找,
所以需要在同一个 工程项目里面,否则报错, 还是需要工程文件 相关代码文件
不依赖 源码代码 只用一个模型在任意环境加载的办法 torchscript
c++环境中调用pytorch模型
trace 和 script
trace_module = torch.jit.trace(model,torch.rand(1,1,224,224))
print(trace_module.code) # 查看模型结构
output = trace_module (torch.ones(1, 3, 224, 224)) # 测试
print(output)
trace_modult('model.pt') # 模型保存
# 此时应该用script方法 模型定义有 if else 等控制语句
script_module = torch.jit.script(model)
print(script_module.code)
output = script_module(torch.rand(1,1,224,224))
print(output)
script_modult('model.pt') # 模型保存