pytorch保存模型与加载:
模型的保存
torch.save(net,PATH)#保存模型的整个网络,包括网络的整个结构和参数
torch.save(net.state_dict,PATH)#只保存网络中的参数
模型的加载
分别对应上边的加载方法。
model_dict=torch.load(PATH)
model_dict=net.load_state_dict(torch.load(PATH))
pytorch保存模型与加载:
模型的保存
torch.save(net,PATH)#保存模型的整个网络,包括网络的整个结构和参数
torch.save(net.state_dict,PATH)#只保存网络中的参数
模型的加载
分别对应上边的加载方法。
model_dict=torch.load(PATH)
model_dict=net.load_state_dict(torch.load(PATH))