从零开始学习pytorch之线性回归

代码实现

代码实现较简单,且注释很完整就不在一一赘述。

import torch
import matplotlib.pyplot as plt
# 随机数种子,保证每次随机数产生是一样的
torch.manual_seed(10)
# 学习率
lr = 0.1
# 创建数据集20个点(x,y)
x = torch.rand(20,1)*10
y = 2*x + (5 + torch.randn(20, 1))
# 初始化可训练指标w和b
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)

# 循环迭代1000次
for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)
    # 损失函数采用MSE
    loss = (0.5*(y-y_pred)**2).mean()
    # 反向传播
    loss.backward()
    # 自更新过程
    b.data.sub_(lr*b.grad)
    w.data.sub_(lr*w.grad)
    # 每20次迭代画一次回归图像
    if iteration % 20 == 0:
        # 画出散点(x,y)
        plt.scatter(x.data.numpy(), y.data.numpy())
        # 画线性回归曲线
        plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r', lw=5)
        # 显示loss值,文字位置在(2,20),字体大小20,颜色为红色
        plt.text(2, 20, 'Loss:%.4f'%loss.data.numpy(),fontdict={'size':20, 'color': 'red'})
        plt.xlim(1.5, 10)
        plt.ylim(8, 28)
        plt.title('Iteration: {}\nw:{} b:{}'.format(iteration, w.data.numpy(), b.data.numpy()))
        plt.pause(0.5)
        # 损失值小于1则停止迭代
        if loss.data.numpy()<1:
            break

运行结果

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值