1、网络构建
2、网络训练
3、网络结构和参数的保存
4、保存文件的重新导入
import torch
import matplotlib.pyplot as plt
import numpy as np
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + x.pow(5) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
n_input = 6
n_hidden = 12
n_output = 1
# 构建网络架构
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# 优化器
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
# 网络训练
loss_reply = np.empty(50)
for t in range(50):
prediction = net1(x) # 前向传播
loss = loss_func(prediction, y) # 计算误差
loss_reply[t] =loss.data.numpy()
optimizer.zero_grad() # 梯度清零
loss.backward() # 反向传递
optimizer.step() # 参数更新优化
# 保存网络结构及其参数 单独保存
torch.save(net1, '