一.pytorch保存训练好的模型
假设你的模型定义为:
class Net(torch.nn.Module):
......
两种方式:
仅仅保存和加载模型参数:
#保存
PATH="./model.pkl"
the_model = Net()
torch.save(the_model.state_dict(), PATH)
#加载
the_model = Net()
the_model.load_state_dict(torch.load(PATH))
保存和加载整个模型
#保存
the_model=Net()
torch.save(the_model, PATH)
#加载
the_model=Net()
the_model.eval()#加这个是为了和训练时dropout等的设置保持一致
the_model = torch.load(PATH)
参考链接:
https://blog.csdn.net/u011276025/article/details/78507950
https://blog.csdn.net/u011276025/article/details/72817353
二. pytorch自己定义损失函数
两种方式,这个链接讲的很清楚: