PyTorch 反向传播实例：梯度下降

1. 手写梯度并手动更新权重

# Toy dataset for fitting y = w * x; every target is exactly twice its input,
# so training should drive w towards 2.0.
x_data = [1.0, 2.0, 3.0, 4.0]
y_data = [2.0 * value for value in x_data]

# Initial guess for the weight; the training loop below rebinds it each step.
w = 1.0

def grad(x, y, weight=None):
    """Return d(loss)/dw for a single sample of the model y_hat = w * x.

    With loss = (w*x - y)**2, the chain rule gives
    d(loss)/dw = 2 * x * (w*x - y)   (not 2*w*(w*x - y)).

    Args:
        x: input value of the sample.
        y: target value of the sample.
        weight: weight at which to evaluate the gradient. Defaults to the
            module-level ``w`` so existing ``grad(x, y)`` calls keep working.

    Returns:
        The scalar gradient of the squared error with respect to the weight.
    """
    if weight is None:
        weight = w  # fall back to the global weight being trained
    return 2 * x * (weight * x - y)

def loss(x, y, weight=None):
    """Return the squared error (y - w*x)**2 for a single sample.

    Args:
        x: input value of the sample.
        y: target value of the sample.
        weight: weight of the model y_hat = weight * x. Defaults to the
            module-level ``w`` so existing ``loss(x, y)`` calls keep working.

    Returns:
        The squared residual as a float.
    """
    if weight is None:
        weight = w  # fall back to the global weight being trained
    residual = y - weight * x
    return residual * residual

# Plain stochastic gradient descent: 30 passes over the data, one weight
# update per sample with a learning rate of 0.01.
for epoch in range(30):
    for sample_x, sample_y in zip(x_data, y_data):
        step = 0.01 * grad(sample_x, sample_y)
        w -= step
        # Report the post-update prediction, loss and weight.
        print(sample_x, sample_y, w * sample_x, loss(sample_x, sample_y), w)

输出：

1.0 2.0 1.02 0.9603999999999999 1.02
2.0 4.0 2.1968 3.2515302399999997 1.0984
3.0 6.0 3.782064 4.919240100095999 1.260688
4.0 8.0 5.9890713600000005 4.043833995172248 1.4972678400000001
1.0 2.0 1.5073224832 0.24273113556021422 1.5073224832
2.0 4.0 3.093473369088 0.8217905325526613 1.546736684544
3.0 6.0 4.88497224397824 1.2432868966989221 1.62832408132608
4.0 8.0 6.989041501206938 1.0220370862819224 1.7472603753017344
1.0 2.0 1.7523151677956996 0.06134777610407243 1.7523151677956996
2.0 4.0 3.5442599087440874 0.20769903077794757 1.7721299543720437
3.0 6.0 5.439439687755227 0.3142278636639573 1.8131465625850758
4.0 8.0 7.491758650231406 0.2583092696146019 1.8729396625578516
1.0 2.0 1.8754808693066947 0.015505013908616454 1.8754808693066947
2.0 4.0 3.770884799524318 0.05249377508901196 1.885442399762159
3.0 6.0 5.7181883034149115 0.079417832332166 1.9060627678049704
4.0 8.0 7.744490728429519 0.06528498785847768 1.9361226821073798
1.0 2.0 1.9374002284652323 0.003918731396205113 1.9374002284652323
2.0 4.0 3.8848164203760276 0.013267257014991991 1.9424082101880138
3.0 6.0 5.8583241970625135 0.020072033137981508 1.9527747323541713
4.0 8.0 7.871547272003346 0.016500103329782457 1.9678868180008364
1.0 2.0 1.9685290816408196 0.000990418702370196 1.9685290816408196
2.0 4.0 3.942093510219108 0.003353161558744533 1.971046755109554
3.0 6.0 5.928775017569503 0.005072998122224593 1.9762583391898343
4.0 8.0 7.935422682596349 0.004170229923051861 1.9838556706490873
1.0 2.0 1.9841785572361055 0.00025031805113119034 1.9841785572361055
2.0 4.0 3.970888545314434 0.0008474767939097611 1.985444272657217
3.0 6.0 5.964192910736754 0.0012821476415060546 1.988064303578918
4.0 8.0 7.967534905734657 0.0010539823456576249 1.9918837264336642
1.0 2.0 1.992046051904991 6.3265290298098e-05 1.992046051904991
2.0 4.0 3.9853647355051836 0.0002141909668332349 1.9926823677525918
3.0 6.0 5.981998624671376 0.00032404951372199694 1.9939995415571252
4.0 8.0 7.983678753035381 0.00026638310248008516 1.9959196882588452
1.0 2.0 1.9960012944936683 1.5989645726367822e-05 1.9960012944936683
2.0 4.0 3.9926423818683494 5.413454457119351e-05 1.9963211909341747
3.0 6.0 5.99095012969807 8.190015248175754e-05 1.9969833765660232
4.0 8.0 7.991794784259583 6.732556534678497e-05 1.9979486960648958
1.0 2.0 1.9979897221435978 4.041217059940929e-06 1.9979897221435978
2.0 4.0 3.99630108874422 1.3681944478137058e-05 1.99815054437211
3.0 6.0 5.995450339155391 2.0699413800969838e-05 1.9984834463851302
4.0 8.0 7.995874974167554 1.701583811834411e-05 1.9989687435418886
1.0 2.0 1.9989893686710507 1.0213756830538208e-06 1.9989893686710507
2.0 4.0 3.998140438354733 3.457969512547346e-06 1.9990702191773666
3.0 6.0 5.997712739176322 5.231562075532798e-06 1.9992375797254407
4.0 8.0 7.997926216853199 4.300576539955676e-06 1.9994815542132998
1.0 2.0 1.9994919231290338 2.5814210681081686e-07 1.9994919231290338
2.0 4.0 3.9990651385574223 8.73965916818469e-07 1.9995325692787111
3.0 6.0 5.9988501204256295 1.3222230355545188e-06 1.999616706808543
4.0 8.0 7.998957442519237 1.0869261006946292e-06 1.9997393606298093
1.0 2.0 1.999744573417213 6.524273919423822e-08 1.999744573417213
2.0 4.0 3.999530015087672 2.2088581781592946e-07 1.999765007543836
3.0 6.0 5.999421918557837 3.3417815377348864e-07 1.9998073061859456
4.0 8.0 7.999475872825772 2.747092947642348e-07 1.999868968206443
1.0 2.0 1.9998715888423142 1.648942541820837e-08 1.9998715888423142
2.0 4.0 3.999763723469858 5.58265986959534e-08 1.999881861734929
3.0 6.0 5.999709379867925 8.446006116737636e-08 1.9999031266226417
4.0 8.0 7.999736504413585 6.94299240601421e-08 1.9999341261033963
1.0 2.0 1.9999354435813284 4.1675311917088825e-09 1.9999354435813284
2.0 4.0 3.999881216189644 1.4109593602641154e-08 1.999940608094822

2. 利用 PyTorch 自动求导更新权重

在定义好模型之后，使用 PyTorch 框架就不需要我们手动求导了，而是通过反向传播自动把梯度往回传递。每次迭代通常包括两个过程：forward（前向计算损失）和 backward（反向传播计算梯度）：

import torch
from torch import nn
from torch.autograd import Variable
# Fit y = w * x again, this time letting autograd compute the gradient.
x_data = [1.0, 2.0, 3.0, 4.0]
y_data = [2.0, 4.0, 6.0, 8.0]

# Trainable weight with initial guess 1.0. `torch.tensor(..., requires_grad=True)`
# replaces the deprecated `Variable` wrapper (merged into Tensor in PyTorch 0.4).
w = torch.tensor([1.0], requires_grad=True)

# No hand-written grad() is needed here: calling backward() on the loss fills
# w.grad with d(loss)/dw = 2 * x * (w*x - y), the formula derived manually above.

def loss(x, y, weight=None):
    """Return the squared error (y - w*x) * (y - w*x) as a tensor.

    Because the weight tensor has requires_grad=True, the result carries an
    autograd graph, so calling backward() on it populates the weight's .grad.

    Args:
        x: input value of the sample.
        y: target value of the sample.
        weight: weight tensor of the model y_hat = weight * x. Defaults to the
            module-level ``w`` so existing ``loss(x, y)`` calls keep working.
    """
    if weight is None:
        weight = w  # fall back to the global weight being trained
    residual = y - weight * x
    return residual * residual

learning_rate = 0.01
for epoch in range(10):
    for x, y in zip(x_data, y_data):
        sample_loss = loss(x, y)  # forward pass: builds the autograd graph
        sample_loss.backward()    # backward pass: accumulates d(loss)/dw into w.grad
        # Update the leaf tensor with autograd tracking disabled. This replaces
        # the discouraged `.data` manipulation, whose changes autograd cannot see.
        with torch.no_grad():
            w -= learning_rate * w.grad
        # backward() accumulates into .grad, so clear it before the next sample.
        w.grad.zero_()

        print(x, y, w * x, loss(x, y), w)
1.0 2.0 tensor([1.0200], grad_fn=<MulBackward0>) tensor([0.9604], grad_fn=<MulBackward0>) tensor([1.0200], requires_grad=True)
2.0 4.0 tensor([2.1968], grad_fn=<MulBackward0>) tensor([3.2515], grad_fn=<MulBackward0>) tensor([1.0984], requires_grad=True)
3.0 6.0 tensor([3.7821], grad_fn=<MulBackward0>) tensor([4.9192], grad_fn=<MulBackward0>) tensor([1.2607], requires_grad=True)
4.0 8.0 tensor([5.9891], grad_fn=<MulBackward0>) tensor([4.0438], grad_fn=<MulBackward0>) tensor([1.4973], requires_grad=True)
1.0 2.0 tensor([1.5073], grad_fn=<MulBackward0>) tensor([0.2427], grad_fn=<MulBackward0>) tensor([1.5073], requires_grad=True)
2.0 4.0 tensor([3.0935], grad_fn=<MulBackward0>) tensor([0.8218], grad_fn=<MulBackward0>) tensor([1.5467], requires_grad=True)
3.0 6.0 tensor([4.8850], grad_fn=<MulBackward0>) tensor([1.2433], grad_fn=<MulBackward0>) tensor([1.6283], requires_grad=True)
4.0 8.0 tensor([6.9890], grad_fn=<MulBackward0>) tensor([1.0220], grad_fn=<MulBackward0>) tensor([1.7473], requires_grad=True)
1.0 2.0 tensor([1.7523], grad_fn=<MulBackward0>) tensor([0.0613], grad_fn=<MulBackward0>) tensor([1.7523], requires_grad=True)
2.0 4.0 tensor([3.5443], grad_fn=<MulBackward0>) tensor([0.2077], grad_fn=<MulBackward0>) tensor([1.7721], requires_grad=True)
3.0 6.0 tensor([5.4394], grad_fn=<MulBackward0>) tensor([0.3142], grad_fn=<MulBackward0>) tensor([1.8131], requires_grad=True)
4.0 8.0 tensor([7.4918], grad_fn=<MulBackward0>) tensor([0.2583], grad_fn=<MulBackward0>) tensor([1.8729], requires_grad=True)
1.0 2.0 tensor([1.8755], grad_fn=<MulBackward0>) tensor([0.0155], grad_fn=<MulBackward0>) tensor([1.8755], requires_grad=True)
2.0 4.0 tensor([3.7709], grad_fn=<MulBackward0>) tensor([0.0525], grad_fn=<MulBackward0>) tensor([1.8854], requires_grad=True)
3.0 6.0 tensor([5.7182], grad_fn=<MulBackward0>) tensor([0.0794], grad_fn=<MulBackward0>) tensor([1.9061], requires_grad=True)
4.0 8.0 tensor([7.7445], grad_fn=<MulBackward0>) tensor([0.0653], grad_fn=<MulBackward0>) tensor([1.9361], requires_grad=True)
1.0 2.0 tensor([1.9374], grad_fn=<MulBackward0>) tensor([0.0039], grad_fn=<MulBackward0>) tensor([1.9374], requires_grad=True)
2.0 4.0 tensor([3.8848], grad_fn=<MulBackward0>) tensor([0.0133], grad_fn=<MulBackward0>) tensor([1.9424], requires_grad=True)
3.0 6.0 tensor([5.8583], grad_fn=<MulBackward0>) tensor([0.0201], grad_fn=<MulBackward0>) tensor([1.9528], requires_grad=True)
4.0 8.0 tensor([7.8715], grad_fn=<MulBackward0>) tensor([0.0165], grad_fn=<MulBackward0>) tensor([1.9679], requires_grad=True)
1.0 2.0 tensor([1.9685], grad_fn=<MulBackward0>) tensor([0.0010], grad_fn=<MulBackward0>) tensor([1.9685], requires_grad=True)
2.0 4.0 tensor([3.9421], grad_fn=<MulBackward0>) tensor([0.0034], grad_fn=<MulBackward0>) tensor([1.9710], requires_grad=True)
3.0 6.0 tensor([5.9288], grad_fn=<MulBackward0>) tensor([0.0051], grad_fn=<MulBackward0>) tensor([1.9763], requires_grad=True)
4.0 8.0 tensor([7.9354], grad_fn=<MulBackward0>) tensor([0.0042], grad_fn=<MulBackward0>) tensor([1.9839], requires_grad=True)
1.0 2.0 tensor([1.9842], grad_fn=<MulBackward0>) tensor([0.0003], grad_fn=<MulBackward0>) tensor([1.9842], requires_grad=True)
2.0 4.0 tensor([3.9709], grad_fn=<MulBackward0>) tensor([0.0008], grad_fn=<MulBackward0>) tensor([1.9854], requires_grad=True)
3.0 6.0 tensor([5.9642], grad_fn=<MulBackward0>) tensor([0.0013], grad_fn=<MulBackward0>) tensor([1.9881], requires_grad=True)
4.0 8.0 tensor([7.9675], grad_fn=<MulBackward0>) tensor([0.0011], grad_fn=<MulBackward0>) tensor([1.9919], requires_grad=True)
1.0 2.0 tensor([1.9920], grad_fn=<MulBackward0>) tensor([6.3264e-05], grad_fn=<MulBackward0>) tensor([1.9920], requires_grad=True)
2.0 4.0 tensor([3.9854], grad_fn=<MulBackward0>) tensor([0.0002], grad_fn=<MulBackward0>) tensor([1.9927], requires_grad=True)
3.0 6.0 tensor([5.9820], grad_fn=<MulBackward0>) tensor([0.0003], grad_fn=<MulBackward0>) tensor([1.9940], requires_grad=True)
4.0 8.0 tensor([7.9837], grad_fn=<MulBackward0>) tensor([0.0003], grad_fn=<MulBackward0>) tensor([1.9959], requires_grad=True)
1.0 2.0 tensor([1.9960], grad_fn=<MulBackward0>) tensor([1.5989e-05], grad_fn=<MulBackward0>) tensor([1.9960], requires_grad=True)
2.0 4.0 tensor([3.9926], grad_fn=<MulBackward0>) tensor([5.4134e-05], grad_fn=<MulBackward0>) tensor([1.9963], requires_grad=True)
3.0 6.0 tensor([5.9910], grad_fn=<MulBackward0>) tensor([8.1901e-05], grad_fn=<MulBackward0>) tensor([1.9970], requires_grad=True)
4.0 8.0 tensor([7.9918], grad_fn=<MulBackward0>) tensor([6.7321e-05], grad_fn=<MulBackward0>) tensor([1.9979], requires_grad=True)
1.0 2.0 tensor([1.9980], grad_fn=<MulBackward0>) tensor([4.0410e-06], grad_fn=<MulBackward0>) tensor([1.9980], requires_grad=True)
2.0 4.0 tensor([3.9963], grad_fn=<MulBackward0>) tensor([1.3681e-05], grad_fn=<MulBackward0>) tensor([1.9982], requires_grad=True)
3.0 6.0 tensor([5.9955], grad_fn=<MulBackward0>) tensor([2.0698e-05], grad_fn=<MulBackward0>) tensor([1.9985], requires_grad=True)
4.0 8.0 tensor([7.9959], grad_fn=<MulBackward0>) tensor([1.7013e-05], grad_fn=<MulBackward0>) tensor([1.9990], requires_grad=True)

https://blog.csdn.net/m0_37306360/article/details/79307354

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值