pytorch手动实现线性回归


一、前言

简单的使用pytorch拟合一条直线


二、实现

import torch
import matplotlib.pyplot as plt

# 学习率
learning_rate = 0.01

# 数据集(500条训练, 50条验证)
train_x = torch.rand([500, 1])
train_y = train_x*2 + 1
val_x = torch.rand([50, 1])
val_y = val_x*2 + 1
print(f"train_x: {train_x}")
print(f"train_y: {train_y}")
print(f"val_x: {val_x}")
print(f"val_y: {val_y}")

# 随机初始化网络权重(w, b)
w = torch.rand([1, 1], requires_grad=True, dtype=torch.float32)      # requires_grad=True 表示该参数需要计算梯度
b = torch.tensor(0, requires_grad=True, dtype=torch.float32)         # requires_grad=True 表示该参数需要计算梯度
print(f"w: {w}")
print(f"b: {b}")

# 初始化画布(绘制训练精度)
fig = plt.figure(figsize=(5, 4))

# 训练
for i in range(2000):

    # 前向传播
    predict_y = torch.matmul(train_x, w) + b

    # 计算损失(均方误差)
    loss = (train_y - predict_y).pow(2).mean()

    # 梯度清零
    if w.grad is not None:
        w.grad.data.zero_()
    if b.grad is not None:
        b.grad.data.zero_()

    # 反向传播,计算梯度
    loss.backward()

    # 更新梯度
    w.data = w.data - learning_rate*w.grad
    b.data = b.data - learning_rate*b.grad

    # 打印损失值
    print(f"epochs:{i + 1}, loss: {loss}")

    # 绘制训练精度
    if i % 50 == 0 and i != 0:      # 每隔50步校验精度并绘制图像
        with torch.no_grad():       # 验证时不记录梯度

            # 绘制验证集真实值(直线)
            plt.plot(val_x.numpy().reshape(-1), val_y.numpy().reshape(-1))
            # 送入验证集,前向传播计算验证集预测结果
            test_y = torch.matmul(val_x, w) + b
            # 绘制验证集预测值(散点)
            plt.scatter(val_x.numpy().reshape(-1), test_y.detach().numpy().reshape(-1), c='g')

            plt.draw()      # 绘制
            plt.pause(0.1)  # 暂停
            fig.clf()       # 清除当前帧


三、训练过程

在这里插入图片描述

  • 15
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 13
    评论
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

什么都干的派森

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值