保存模型与加载模型
常用的两种保存与加载模型方式
# 1.保存整个网络
torch.save(net, PATH)
#针对上面保存方法,加载的方法是:
model_dict=torch.load(PATH)
# 如果有多块GPU,训练和测试使用的不是同一块GPU,则加载的方法是
model_dict=torch.load(PATH, map_location = {'cuda:3', 'cuda:0'})
# 上面'cuda:3'是训练时使用的GPU编号,'cuda:0'是测试时使用的GPU编号
# 2.保存网络中的参数, 速度快,占空间少
torch.save(net.state_dict(),PATH)
#针对上面的保存方法,加载的方法是:
model_dict=model.load_state_dict(torch.load(PATH))