1.只保存模型参数
保存模型参数
torch.save(net.state_dict(), 'net_parameter.pkl')
加载模型参数
#定义模型结构
model = create_net()
#加载模型参数
model.load_state_dict(torch.load('net_parameter.pkl'))
2.保存完整模型
即保存模型结构又保存模型参数
torch.save(net, 'net_model.pkl')
加载模型:
net_loaded = torch.load('net_model.pkl')