- 自定义数据,采用torch实现简单的线性回归,假设基础模型为:,通过迭代更新获得预测权重。
- 准备训练数据。
- 计算预测值
- 计算损失值,同时参数梯度设置为0,最后采用反向传播
- 更新参数
- 画图分析
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 24 08:04:00 2023
@author: 茶墨先生
"""
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axisartist.axislines import Subplot
"准备数据"
#y=5x+7.4
x=torch.rand([800,1])#torch.Size([800, 1])
y_true=x*5+7.4#torch.Size([800, 1])
"计算预测值,设置参数"
w=torch.rand([1,1],requires_grad=True)
b=torch.tensor(0.,requires_grad=True)
"定义学习率"
LR=0.01
"循环操作,更新参数"
for i in range(400):
"计算损失"
y_preidct=torch.matmul(x,w)+b
# pri