手动实现一个线性回归

# -*- coding: utf-8 -*-
# @Time     :2020/3/1 11:38
# @Author   :XiaoMa
# @File     :5.py
import torch as t
from matplotlib import pyplot as plt
#设置随机数种子
t.manual_seed(1000)
from IPython import display

def get_fake_data(batch_size=8):
    '''产生随机数据:y=x*2+3,加上了一些噪声'''
    x=t.rand(batch_size,1)*20
    y=x*2+(1+t.randn(batch_size,1))*3
    return x,y

# x,y=get_fake_data()
# plt.scatter(x.squeeze().numpy(),y.squeeze().numpy())
# plt.show()

#随机初始化参数
w=t.rand(1,1)
b=t.zeros(1,1)
lr=0.001
for ii in range(20000):
    x,y=get_fake_data()


    #forward:计算loss
    y_pred=x.mm(w) + b.expand_as(y)
    loss=0.5*(y_pred-y)**2  #均方误差
    loss=loss.sum()

    #backward:手动计算梯度
    dloss=1
    dy_pred=dloss*(y_pred-y)

    dw=x.t().mm(dy_pred)
    db=dy_pred.sum()

    #更新参数
    w.sub_(lr*dw)
    b.sub_(lr*db)

    if ii %1000==0:
        #画图
        display.clear_output(True)
        x=t.arange(0,20).view(-1,1)
        y=x.float().mm(w) + b.expand_as(x)
        plt.plot(x.numpy(),y.numpy())   #预测值

        x2,y2=get_fake_data(batch_size=20)
        plt.scatter(x2.numpy(),y2.numpy())  #true data

        plt.xlim(0,20)
        plt.ylim(0,41)
        plt.show()
        plt.pause(0.5)
        print('w:',w.item(),'b',b.item())

程序学习的结果:

w: 1.59645676612854 b 0.14356936514377594
w: 1.95401132106781 b 2.5674352645874023
w: 1.9625165462493896 b 2.9557881355285645
w: 1.8890539407730103 b 3.001418352127075
w: 2.0090134143829346 b 2.9461421966552734
w: 2.0715832710266113 b 3.1086935997009277
w: 1.9367080926895142 b 2.9967336654663086
w: 2.0456297397613525 b 2.9070372581481934
w: 1.8773678541183472 b 2.946399450302124
w: 2.1052446365356445 b 3.148202419281006
w: 1.81801438331604 b 3.0318338871002197
w: 2.119907855987549 b 2.948784351348877
w: 2.113875150680542 b 3.1145777702331543
w: 1.9084047079086304 b 2.946380376815796
w: 2.09352970123291 b 2.8815550804138184
w: 2.078930616378784 b 3.0505666732788086
w: 1.9736963510513306 b 3.040769100189209
w: 1.9155480861663818 b 2.9243123531341553
w: 2.0103886127471924 b 3.0951311588287354
w: 1.8826669454574585 b 3.0347161293029785

w接近2,b接近3。

拟合的结果:

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值