保存
torch.save(net1, 'net1.pkl') # 保存整个网络
torch.save(net1.state(), 'net_params.pkl') # 只保存网络中的参数
提取
- 提取整个神经网络
def restore_net():
net2 = torch.load('net1.pkl') # 提取网络
prediction = net2(x)
- 提取网络参数
def restore_params():
# 新建 net3 要求与原网络保持相同的结构
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# 将保存的参数复制到 net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)