概念
线性回归是分析一个变量与另外一(多)个变量之间关系的方法
- 因变量:y
- 自变量:x
- 关系:线性
- 表达式:y = wx + b
- 目的:求解w和b
求解步骤:
- 确定模型
Model:y = wx + b - 选择损失函数
均方差MSE: 1 m ∑ i = 1 m ( y i − y i ^ ) 2 \frac{1}{m}\sum_{i=1}^{m}(y_i - \hat{y_i})^2 m1∑i=1m(yi−yi^)2 - 求解梯度并更新w,b
梯度下降法:
w = w – LR * w.grad
b = b – LR * w.grad
LR为步长,学习率
import torch
import matplotlib.pyplot as plt
torch.manual_seed(10) # 初始化随机数种子,保证结果可以复现
lr = 0.1 # 学习率
# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) # torch.randn(20, 1)加入噪声
# 初始化w和b
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
# 开始迭代
for i in range(1000):
# 前向传播
wx = torch.mul(w, x)
y_pre = torch.add(wx, b) # 预测值
# 计算损失
loss = (0.5 * (y - y_pre) ** 2).mean() # 乘以0.5是为了求导过程中消除平方2的影响,mean()求均值
# 反向传播
loss.backward() # 自动求导,得到梯度
# 更新参数
b.data.sub_(lr * b.grad)
w.data.sub_(lr * w.grad)
# 绘图
if loss.data.numpy() < 1:
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), y_pre.data.numpy(), "r-", lw=5)
plt.text(2, 10, "loss=%.4f" % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.xlim(1.5, 10)
plt.ylim(8, 28)
plt.title("i:{} w:{} b:{}".format(i, w.data.numpy(), b.data.numpy()))
plt.pause(0.5)
break