- 导入模块
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
2. 超参数设置
关键词: 超参数
模型超参数:根据经验进行设定,在开始学习过程之前设置值的参数,而不是通过训练得到的参数,比如迭代次数、隐藏层的层数、每层神经元的个数、学习率r等。
模型参数:由模型通过学习得到的变量,比如权重w和偏置b。
# 超参数设置
input_size = 1
output_size = 1
num_epochs = 100
learning_rate = 0.001
3. 输入数据
# 小数据
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
[9.779], [6.182], [7.59], [2.167], [7.042],
[10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
[3.366], [2.596], [2.53], [1.221], [2.827],
[3.465], [1.65], [2.904], [1.3]], dtype=np.float32)
4. 线性回归
关键词: 线性回归 torch.nn.Linear
线性回归(Linear regression)是一种以线性模型来建模自变量与因变量关系的方法。通常来说,当自变量只有一个的情况被称为简单线性回归,自变量大于一个的情况被称为多重线性回归。在线性回归模型中, 模型的未知参数由数据中估计得到。最常用的拟合方法是最小二乘法,但是也有许多其他的拟合方法。因此需要甄别的是,使用最小二乘法求解并不是构成线性回归模型的必要条件。
具体理论细节可以参考《动手学深度学习-线性回归》,也推荐阅读西瓜书第三章:线性模型。
# 线性回归模型
model = nn.Linear(input_size, output_size)
也可以看看源码实现(推荐)。
5. 损失函数设置
关键词: 损失函数
损失函数作用:衡量模型模型预测的好坏。简单一点说就是:损失函数是用来表现预测与实际数据的差距程度。因此:一般来说损失函数值越小,模型表现就越好。
pytorch常用损失函数。
# loss
criterion = nn.MSELoss()
6. 优化器设置
关键词: 优化器 torch.optim
理解优化器,看这篇博客。
# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
7. 模型训练
# 保存loss值,画loss曲线
losses = []
# 训练模型
for epoch in range(num_epochs):
train_loss = 0
# numpy_arrays -> torch_tensors
inputs = torch.from_numpy(x_train)
targets = torch.from_numpy(y_train)
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, targets)
# 反向传播+优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 保存loss值,画loss曲线
train_loss += loss.item()
losses.append(train_loss / len(x_train))
if (epoch + 1) % 5 == 0:
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))
8. loss曲线和回归曲线
关键词: matplolib
一般会用matplotlib库,画对应的曲线,matplotlib教程。
# 画loss曲线和回归曲线
predicted = model(torch.from_numpy(x_train)).detach().numpy()
fig = plt.figure()
ax1 = fig.add_subplot(121)
plt.plot(np.arange(len(losses)), losses, label='loss')
plt.legend()
ax2 = fig.add_subplot(122)
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()
9. 保存模型
# 保存模型
torch.save(model.state_dict(), 'linear_regression_model.pt')
参考链接:
Welcome to PyTorch Tutorialspytorch.org