保存网络:
1. torch.save(net1, 'net.pkl') #保存整个网络
2. torch.save(net1.state_dict(), 'net_params.pkl') #只保存网络中的参数,(速度快,占内存少)
提取网络:
1. net2 = torch.load('net.pkl')
2. net3 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
net3.load_state_dict(torch.load('net_params.pkl'))