线性回归线性回归 是分析一个变量与另外一个(多)个变量之间关系的方法
因变量:y
自变量:x
关系:线性
y=wx+b -> 求解w,b
求解步骤:
1.确定模型 Module: y=wx+b
2.选择损失函数 MSE: 均方差等
3.求解梯度并更新w,b w=w-LR*w.grad b=b-LR*w.grad
LR:步长,即学习率 -> 迭代更新使损失函数值较小即可
import torch
import matplotlib.pyplot as plt
torch.manual_seed(10) # 为CPU设置种子用于生成随机数,以使得结果是确定的
lr = 0.05 # 学习率
# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2*x + (5 + torch.randn(20, 1))
# 构建线性回归参数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
for iteration in range(1000):
# 前向传播
wx = torch.mul(w, x)
y_pred = torch.add(wx, b)
# 计算MES loss
loss = (0.5 * (y-y_pred) ** 2).mean()
# 反向传播
loss.backward()
# 更新参数
b.data.sub_(lr * b.grad)
w.data.sub_(lr * w.grad)
# 清零张量的梯度
w.grad.zero_()
b.grad