pytorch保存模型与加载:
模型的保存
torch.save(net,PATH)#保存模型的整个网络,包括网络的整个结构和参数
torch.save(net.state_dict,PATH)#只保存网络中的参数
模型的加载
分别对应上边的加载方法。
model_dict=torch.load(PATH)
model_dict=net.load_state_dict(torch.load(PATH))
在自定义的网络中的使用:
import torch
import torch.nn as nn
class neuralModel(nn.Module):
def __init__(self,device):super(neuralModel,self).__init__()
self.device=device#初始化函数
def dump(self,filename):#保存模型参数
torch.save(self.state_dict(),filename)
def load(self,filename):
state_dict=torch.load(open(filename,"rb"),map_location=self.device)
self.load_state_dict(state_dict,strict=True)
其中map_location为改变设备(gpu0,gpu1,cpu…)
参考链接
pytorch------cpu与gpu load时相互转化 torch.load(map_location=)
[Pytorch]Pytorch 保存模型与加载模型(转)