参考:https://blog.51cto.com/u_16213376/7938863
在pytorch中,保存训练好的模型可以选择保存整个模型,也可以选择保存模型参数。
import torch
import torch.nn as nn
#定义示例模型
class MyModel(nn.Module):
def _init_(self):
super(MyModel,self).__init_()
self.fc =nn.Linear(10,1)
def forward(self,x):
return self.fc(x)
# 创建模型实例并进行训练
model =MyModel()
optimizer =torch.optim.sGD(model.parameters(),lr=0.1)
criterion =nn.MSELoss()
# 训练过程...
1.保存整个模型
# 保存整个模型
torch.save(model,'model.pth')
加载模型
#加载整个模型
model = torch.load('model.pth',map_location = 'cuda:0')
2.保存模型参数
# 保存模型参数
torch.save(model.state_dict(),'model_params.pth')
加载模型
# 创建模型实例
model = MyModel()
# 加载模型参数
model.load_state_dict(torch.load('model_params.pth'))