import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
# Toy 1-D regression data: 1000 evenly spaced points on [-1, 1] laid out as a
# column of shape (1000, 1), and targets x**2 plus uniform noise in [0, 0.2).
x = torch.linspace(-1, 1, 1000).unsqueeze(1)
y = x ** 2 + 0.2 * torch.rand(x.size())
def save():
    """Train a small regression net on the global (x, y) data, plot it, save it.

    Trains ``net1`` (Linear(1, 10) -> ReLU -> Linear(10, 1)) for 100 SGD steps
    (lr=0.5) on an MSE loss, draws its fitted curve into the first panel
    (subplot 131) of figure 1, and then persists the network in two ways:

    * ``net.pkl``      -- the entire pickled ``nn.Module`` object
    * ``net_para.pkl`` -- only its ``state_dict`` (the parameter tensors)
    """
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    for _ in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The last in-loop prediction was computed *before* the final
    # optimizer.step(), so it does not reflect the weights saved below.
    # Recompute it with the trained weights so this panel is directly
    # comparable with the Net2/Net3 panels drawn from the reloaded copies.
    with torch.no_grad():
        prediction = net1(x)

    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.detach().numpy(), y.detach().numpy())
    plt.plot(x.detach().numpy(), prediction.numpy(), 'r-', lw=5)

    # Two ways to persist the model:
    torch.save(net1, 'net.pkl')  # save the entire module object (pickled)
    torch.save(net1.state_dict(), 'net_para.pkl')  # save only the parameters
def restore_net():
    """Reload the whole network pickled by ``save()`` and plot its fit.

    Loads ``net.pkl`` (a complete ``nn.Module`` object, not just tensors),
    runs it on the global ``x`` and draws the result into the second panel
    (subplot 132) of the current figure.
    """
    import inspect

    # A fully pickled module is rejected by torch.load's weights-only
    # unpickler, which is the default since PyTorch 2.6, so opt out
    # explicitly. Releases older than 1.13 have no ``weights_only``
    # parameter at all, hence the signature check instead of always
    # passing it.
    # SECURITY: weights_only=False lets the pickle execute arbitrary code;
    # only load files you created yourself (as save() does here).
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    net2 = torch.load('net.pkl', **load_kwargs)

    # Inference only: no autograd graph is needed to plot the prediction.
    with torch.no_grad():
        prediction = net2(x)
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.detach().numpy(), y.detach().numpy())
    plt.plot(x.detach().numpy(), prediction.numpy(), 'r-', lw=5)
def restore_para():
    """Rebuild the architecture, load the saved parameters into it, and plot.

    Constructs ``net3`` with the same layers as ``net1`` in ``save()``,
    copies in the ``state_dict`` stored in ``net_para.pkl``, draws its fit
    into the third panel (subplot 133), and finally shows the figure.
    """
    # A state_dict holds only parameter tensors, so the module structure has
    # to be recreated here; load_state_dict (strict by default) raises if the
    # layer names or shapes do not match what was saved.
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )
    net3.load_state_dict(torch.load('net_para.pkl'))

    # Inference only: no autograd graph is needed to plot the prediction.
    with torch.no_grad():
        prediction = net3(x)
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.detach().numpy(), y.detach().numpy())
    plt.plot(x.detach().numpy(), prediction.numpy(), 'r-', lw=5)
    # Display the figure built up by the preceding plotting calls.
    plt.show()
# Demo driver.
# Step 1: train net1 and save it twice -- as the whole network and as
# parameters only.
save()
# Step 2: load the whole saved network back as net2.
restore_net()
# Step 3: build a fresh network (net3) and copy the saved parameters into it.
restore_para()
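# 莫烦Pytorch之保存加载网络 (Morvan PyTorch tutorial: saving and loading a network)
# 最新推荐文章于 2023-01-18 13:18:25 发布 (blog-page residue: "latest recommended article published 2023-01-18 13:18:25")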