2.梯度下降线性回归

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

1 Prepare the training set

# Training set: three points that lie exactly on the line y = 2x,
# so the optimal weight for the model y = w * x is w = 2.0.
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

2 Initial guess of weight.

w = 1

3 Define the model:

# Forward pass of the linear model (no bias term).
def forward(x):
    """Predict the target for input ``x`` using the current global weight ``w``."""
    return w * x

4 Define the cost function

# Mean squared error of the model over the whole training set.
def cost(xs, ys):
    """Return the mean squared error between forward(x) and y over (xs, ys).

    Args:
        xs: input values (non-empty; same length as ys).
        ys: target values.
    """
    # Renamed the accumulator: the original local `cost` shadowed the
    # function's own name, which is confusing and error-prone.
    total = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        total += (y_pred - y) ** 2
    return total / len(xs)

5 Define the gradient function

# Backward pass: analytic gradient of the MSE cost with respect to w,
#   d/dw mean((w*x - y)^2) = mean(2 * x * (w*x - y)).
def gradient(xs, ys):
    """Return d(cost)/dw for the current global weight ``w``.

    Args:
        xs: input values (non-empty; same length as ys).
        ys: target values.
    """
    grad = 0
    for x, y in zip(xs, ys):
        # Use forward(x) for the prediction instead of duplicating `x * w`,
        # keeping this consistent with cost() and the model definition.
        grad += 2 * x * (forward(x) - y)
    return grad / len(xs)

6 Train the model with gradient descent

# Training: 100 epochs of full-batch gradient descent, learning rate 0.01.
mse_list = []
for epoch in range(100):
    cost_val = cost(x_data, y_data)      # MSE under the current weight
    grad_val = gradient(x_data, y_data)  # gradient of the cost w.r.t. w
    w -= 0.01 * grad_val                 # descend along the negative gradient
    mse_list.append(cost_val)            # record the curve for plotting later
    print("epoch=", epoch, "cost_val=", cost_val, "w=", w)
epoch= 0 cost_val= 4.666666666666667 w= 1.0933333333333333
epoch= 1 cost_val= 3.8362074074074086 w= 1.1779555555555554
epoch= 2 cost_val= 3.1535329869958857 w= 1.2546797037037036
epoch= 3 cost_val= 2.592344272332262 w= 1.3242429313580246
epoch= 4 cost_val= 2.1310222071581117 w= 1.3873135910979424
epoch= 5 cost_val= 1.7517949663820642 w= 1.4444976559288012
epoch= 6 cost_val= 1.440053319920117 w= 1.4963445413754464
epoch= 7 cost_val= 1.1837878313441108 w= 1.5433523841804047
epoch= 8 cost_val= 0.9731262101573632 w= 1.5859728283235668
epoch= 9 cost_val= 0.7999529948031382 w= 1.6246153643467005
epoch= 10 cost_val= 0.6575969151946154 w= 1.659651263674342
epoch= 11 cost_val= 0.5405738908195378 w= 1.6914171457314033
epoch= 12 cost_val= 0.44437576375991855 w= 1.7202182121298057
epoch= 13 cost_val= 0.365296627844598 w= 1.7463311789976905
epoch= 14 cost_val= 0.3002900634939416 w= 1.7700069356245727
epoch= 15 cost_val= 0.2468517784170642 w= 1.7914729549662791
epoch= 16 cost_val= 0.2029231330489788 w= 1.8109354791694263
epoch= 17 cost_val= 0.16681183417217407 w= 1.8285815011136133
epoch= 18 cost_val= 0.1371267415488235 w= 1.8445805610096762
epoch= 19 cost_val= 0.11272427607497944 w= 1.8590863753154396
epoch= 20 cost_val= 0.09266436490145864 w= 1.872238313619332
epoch= 21 cost_val= 0.07617422636521683 w= 1.8841627376815275
epoch= 22 cost_val= 0.06261859959338009 w= 1.8949742154979183
epoch= 23 cost_val= 0.051475271914629306 w= 1.904776622051446
epoch= 24 cost_val= 0.04231496130368814 w= 1.9136641373266443
epoch= 25 cost_val= 0.03478477885657844 w= 1.9217221511761575
epoch= 26 cost_val= 0.02859463421027894 w= 1.9290280837330496
epoch= 27 cost_val= 0.023506060193480772 w= 1.9356521292512983
epoch= 28 cost_val= 0.01932302619282764 w= 1.9416579305211772
epoch= 29 cost_val= 0.015884386331668398 w= 1.9471031903392007
epoch= 30 cost_val= 0.01305767153735723 w= 1.952040225907542
epoch= 31 cost_val= 0.010733986344664803 w= 1.9565164714895047
epoch= 32 cost_val= 0.008823813841374291 w= 1.9605749341504843
epoch= 33 cost_val= 0.007253567147113681 w= 1.9642546069631057
epoch= 34 cost_val= 0.005962754575689583 w= 1.9675908436465492
epoch= 35 cost_val= 0.004901649272531298 w= 1.970615698239538
epoch= 36 cost_val= 0.004029373553099482 w= 1.9733582330705144
epoch= 37 cost_val= 0.0033123241439168096 w= 1.975844797983933
epoch= 38 cost_val= 0.0027228776607060357 w= 1.9780992835054327
epoch= 39 cost_val= 0.002238326453885249 w= 1.980143350378259
epoch= 40 cost_val= 0.001840003826269386 w= 1.9819966376762883
epoch= 41 cost_val= 0.0015125649231412608 w= 1.983676951493168
epoch= 42 cost_val= 0.0012433955919298103 w= 1.9852004360204722
epoch= 43 cost_val= 0.0010221264385926248 w= 1.9865817286585614
epoch= 44 cost_val= 0.0008402333603648631 w= 1.987834100650429
epoch= 45 cost_val= 0.0006907091659248264 w= 1.9889695845897222
epoch= 46 cost_val= 0.0005677936325753796 w= 1.9899990900280147
epoch= 47 cost_val= 0.0004667516012495216 w= 1.9909325082920666
epoch= 48 cost_val= 0.000383690560742734 w= 1.9917788075181404
epoch= 49 cost_val= 0.00031541069384432885 w= 1.9925461188164473
epoch= 50 cost_val= 0.0002592816085930997 w= 1.9932418143935788
epoch= 51 cost_val= 0.0002131410058905752 w= 1.9938725783835114
epoch= 52 cost_val= 0.00017521137977565514 w= 1.994444471067717
epoch= 53 cost_val= 0.0001440315413480261 w= 1.9949629871013967
epoch= 54 cost_val= 0.0001184003283899171 w= 1.9954331083052663
epoch= 55 cost_val= 9.733033217332803e-05 w= 1.9958593515301082
epoch= 56 cost_val= 8.000985883901657e-05 w= 1.9962458120539648
epoch= 57 cost_val= 6.57716599593935e-05 w= 1.9965962029289281
epoch= 58 cost_val= 5.406722767150764e-05 w= 1.9969138906555615
epoch= 59 cost_val= 4.444566413387458e-05 w= 1.997201927527709
epoch= 60 cost_val= 3.65363112808981e-05 w= 1.9974630809584561
epoch= 61 cost_val= 3.0034471708953996e-05 w= 1.9976998600690001
epoch= 62 cost_val= 2.4689670610172655e-05 w= 1.9979145397958935
epoch= 63 cost_val= 2.0296006560253656e-05 w= 1.9981091827482769
epoch= 64 cost_val= 1.6684219437262796e-05 w= 1.9982856590251044
epoch= 65 cost_val= 1.3715169898293847e-05 w= 1.9984456641827613
epoch= 66 cost_val= 1.1274479219506377e-05 w= 1.9985907355257035
epoch= 67 cost_val= 9.268123006398985e-06 w= 1.9987222668766378
epoch= 68 cost_val= 7.61880902783969e-06 w= 1.9988415219681517
epoch= 69 cost_val= 6.262999634617916e-06 w= 1.9989496465844576
epoch= 70 cost_val= 5.1484640551938914e-06 w= 1.9990476795699081
epoch= 71 cost_val= 4.232266273994499e-06 w= 1.9991365628100501
epoch= 72 cost_val= 3.479110977946351e-06 w= 1.999217150281112
epoch= 73 cost_val= 2.859983851026929e-06 w= 1.999290216254875
epoch= 74 cost_val= 2.3510338359374262e-06 w= 1.9993564627377531
epoch= 75 cost_val= 1.932654303533636e-06 w= 1.9994165262155628
epoch= 76 cost_val= 1.5887277332523938e-06 w= 1.999470983768777
epoch= 77 cost_val= 1.3060048068548734e-06 w= 1.9995203586170245
epoch= 78 cost_val= 1.0735939958924364e-06 w= 1.9995651251461022
epoch= 79 cost_val= 8.825419799121559e-07 w= 1.9996057134657994
epoch= 80 cost_val= 7.254887315754342e-07 w= 1.9996425135423248
epoch= 81 cost_val= 5.963839812987369e-07 w= 1.999675878945041
epoch= 82 cost_val= 4.902541385825727e-07 w= 1.999706130243504
epoch= 83 cost_val= 4.0301069098738336e-07 w= 1.9997335580874436
epoch= 84 cost_val= 3.312926995781724e-07 w= 1.9997584259992822
epoch= 85 cost_val= 2.723373231729343e-07 w= 1.9997809729060159
epoch= 86 cost_val= 2.2387338352920307e-07 w= 1.9998014154347876
epoch= 87 cost_val= 1.8403387118941732e-07 w= 1.9998199499942075
epoch= 88 cost_val= 1.5128402140063082e-07 w= 1.9998367546614149
epoch= 89 cost_val= 1.2436218932547864e-07 w= 1.9998519908930161
epoch= 90 cost_val= 1.0223124683409346e-07 w= 1.9998658050763347
epoch= 91 cost_val= 8.403862850836479e-08 w= 1.9998783299358769
epoch= 92 cost_val= 6.908348768398496e-08 w= 1.9998896858085284
epoch= 93 cost_val= 5.678969725349543e-08 w= 1.9998999817997325
epoch= 94 cost_val= 4.66836551287917e-08 w= 1.9999093168317574
epoch= 95 cost_val= 3.8376039345125727e-08 w= 1.9999177805941268
epoch= 96 cost_val= 3.154680994333735e-08 w= 1.9999254544053418
epoch= 97 cost_val= 2.593287985380858e-08 w= 1.9999324119941766
epoch= 98 cost_val= 2.131797981222471e-08 w= 1.9999387202080534
epoch= 99 cost_val= 1.752432687141379e-08 w= 1.9999444396553017
# Predict an unseen input with the trained weight.
print("x=4, y=", forward(4))

# Plot the training curve: cost against epoch index.
epochs = range(len(mse_list))
plt.plot(epochs, mse_list)
plt.xlabel("epoch")
plt.ylabel("cost")
plt.show()
x=4, y= 7.999777758621207

[Figure: training cost versus epoch, decreasing toward zero]

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值