保存和加载模型
只保存模型的参数
保存
torch.save(model.state_dict(),'xxx.pth')
加载
model = net() #首先要先定义网络模型
state_dict = torch.load('xxx.pth') # 读取pth文件中的参数
model.load_state_dict(state_dict['model']) #将参数导入模型
这种方法操作比较麻烦,但是比较节省内存。
official example
class MyModule(torch.nn.Module):
m = MyModule()
m.state_dict()
torch.save(m.state_dict(), 'mymodule.pt')
m_state_dict = torch.load('mymodule.pt')
new_m = MyModule()
new_m.load_state_dict(m_state_dict)
附加保存一些其他信息
torch.save({
'epoch': epoch + opt.start,
'model_state_dict': model.state_dict()
'optimizer_state_dict': optimizer.state_dict(),
'loss': epoch_losses.avg},
}, os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch + opt.start)))
保存整个模型
保存
torch.save(net, 'xxx.pt')
加载
test = torch.load('xxx.pt') #注意其中pt文件的路径
这种方式是将整个的网络模型进行保存,使用不便,但是加载方便,适合于简单测试。