1.方法
保存训练好的整个网络:torch.save(net1,‘net1.pth’)
只保存网络的参数:torch.save(net1.state_dict(),‘net1_params.pth’)
调用整个训练好的网络:net2 = torch.load(‘net1.pth’)
只调用网络的参数:net3.load_state_dict(torch.load(‘net1_params.pth’))
注意:只保存网络中的参数速度快, 占内存少,但是只调用网络的参数时,新网络需要提前搭建和net1网络相同的架构,再使用上面的调用指令。详细可以见下面的实例。
其中,net1是训练好的网络的名称,’ '内部是保存的文件名称(后缀是.pth或者.pkl)
2.实例
2.1 实验结果
以我的[pytorch学习笔记二]数据的拟合为例,将训练好的net1保存,使用net2调用整个网络,net3只调用net2的参数,最终拟合的效果一模一样。
2.2完整代码
# 1.导入必要的模块
import torch
import torch.nn.functional as F # F中包含很多函数比如激励函数
import matplotlib.pyplot as plt #用于绘图
# 2.生成要拟合的数据点
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.siz