Pytorch实现线性回归实战

随机初始化一个二维数据集,数据中只包含X轴Y轴坐标,然后使用Pytorch训练一个一元线性回归模型模拟这些数据。
随机生成二维数据并可视化

import numpy as np
import random
import matplotlib.pyplot as plt
x = np.arange(20)
y = np.array([5 * x[i] + random.randint(1, 20) for i in range(len(x))])
print(x, y)
plt.xlabel("X")
plt.ylabel("Y")
plt.scatter(x, y)
plt.show()

可视化结果:
在这里插入图片描述
现在要做的事是找出一条直线,最大限度逼近这些点,使误差最小。要达到这一目的,需要借助pytorch训练一个一元线性回归模型,首先使用from_numpy方法将上面生成的数据转换成Tensor。

import torch
x_train = torch.from_numpy(x).float()
y_train = torch.from_numpy(y).float()

然后,借助pytorch的nn.Modlue搭建线性模型,新建LinearRegression类继承nn.Modlue。

class LinearRegression(torch.nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        return self.linear(x)

因为随机生成的输入数据x和输出数据y都是一维的,所以在__int__方法中声明的linear模型的输入输出都是1。在forward方法中简单地调用linear模型,将x传入即可。下面开始创建模型,优化器采用SGD,损失函数采用MSELoss。

model = LinearRegression()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), 0.001)
epochs = 10
for i in range(epochs):
    input_data = x_train.unsqueeze(1)
    target = y_train.unsqueeze(1)
    out = model(input_data)
    loss = criterion(out, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("Epoch:[{}/{}],loss:[{:.4f}]".format(i+1, epochs, loss.item()))
    if (i+1) % 2 == 0:
        predict = model(input_data)
        plt.plot(x_train.data.numpy(), predict.squeeze(1).data.numpy(), "r")
        loss = criterion(predict, target)
        plt.title("Loss:{:4f}".format(loss.item()))
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.scatter(x_train, y_train)
        plt.show()

每迭代2个epoch打印出拟合的直线图,如下:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述在这里插入图片描述

在这里插入图片描述

从上图可以看出,随着迭代次数的增加,损失值越来越小,直线拟合得越来越好,至此就完成了简单的一元线性回归。

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值