# Author:Richard
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Regression toy data: a noisy parabola sampled at 100 evenly spaced points.
x = torch.linspace(-1, 1, 100).unsqueeze(1)  # turn the 1-D vector into a 2-D column (100, 1)
y = x ** 2 + 0.2 * torch.rand(x.shape)  # x squared plus uniform noise drawn from [0, 0.2)
# Saving and restoring a neural network
def save():
    """Train a small regression net on the module-level ``x``/``y`` and save it.

    The network is ``Linear(1, 10) -> ReLU -> Linear(10, 1)``, trained for
    100 SGD steps against an MSE loss. Two files are written afterwards:

    * ``net1.pkl`` -- the entire pickled module;
    * ``net1_parameters.pkl`` -- only the ``state_dict``.

    Finally the data and the fitted curve are drawn in the first of three
    side-by-side subplots of figure 1.
    """
    layers = [
        nn.Linear(1, 10),
        nn.ReLU(),
        nn.Linear(10, 1),
    ]
    model = nn.Sequential(*layers)
    sgd = torch.optim.SGD(model.parameters(), lr=0.45)
    criterion = nn.MSELoss()

    for _ in range(100):
        prediction = model(x)
        loss = criterion(prediction, y)
        sgd.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        sgd.step()

    # Persist the model in both forms: whole module, then parameters only.
    torch.save(model, 'net1.pkl')
    torch.save(model.state_dict(), 'net1_parameters.pkl')

    # Plot the result. `prediction` is the output of the last forward pass,
    # i.e. computed just before the final optimizer step.
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    xs = x.data.numpy()
    plt.scatter(xs, y.data.numpy())
    plt.plot(xs, prediction.data.numpy(), 'r-', lw=5)
def restore_net():
    """Reload the whole pickled network from ``net1.pkl`` and plot its fit.

    The restored module is evaluated on the module-level ``x`` and the
    result is drawn, together with the ``x``/``y`` scatter, in the middle
    subplot (132) of the current figure.
    """
    # `weights_only=False` is required to unpickle a complete nn.Module:
    # recent PyTorch releases default to weights_only=True, which refuses
    # arbitrary pickled objects and would make this call raise.
    # NOTE(security): full unpickling can execute arbitrary code. It is
    # acceptable here only because net1.pkl is produced by save() in this
    # same script; never do this with a file from an untrusted source.
    net2 = torch.load('net1.pkl', weights_only=False)
    prediction = net2(x)

    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
def restore_params():
    """Rebuild the network and load its parameters from ``net1_parameters.pkl``.

    A fresh ``Linear(1, 10) -> ReLU -> Linear(10, 1)`` model is created (the
    same layer layout used when the parameters were saved), its weights are
    replaced by the stored ``state_dict``, and its prediction on the
    module-level ``x`` is drawn in the right-hand subplot (133).
    """
    layers = [
        nn.Linear(1, 10),
        nn.ReLU(),
        nn.Linear(10, 1),
    ]
    net3 = nn.Sequential(*layers)

    # Overwrite the freshly initialised weights with the saved ones.
    state = torch.load('net1_parameters.pkl')
    net3.load_state_dict(state)
    prediction = net3(x)

    plt.subplot(133)
    plt.title('Net3')
    xs = x.data.numpy()
    plt.scatter(xs, y.data.numpy())
    plt.plot(xs, prediction.data.numpy(), 'r-', lw=5)
# Run the stages in order: save() trains the net, writes both files and opens
# the figure; the two restore steps then read those files and fill the
# remaining subplots.
for stage in (save, restore_net, restore_params):
    stage()
plt.show()
# CNN: saving and restoring
# (blog page footer: latest recommended article published 2024-07-13 15:44:59)