《PyTorch深度学习实践》lecture5: Linear Regression

代码如下:

import torch

# Training pairs for the target relation y = 2x, shaped (3, 1) so each
# row is one sample with a single feature.
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

class LinearModel(torch.nn.Module):
    """Single-feature linear model computing y = w * x + b."""

    def __init__(self):
        super().__init__()
        # One input feature mapped to one output feature.
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        """Apply the affine transform to a batch ``x`` of shape (N, 1)."""
        return self.linear(x)

model = LinearModel()

# Sum (not mean) the squared errors over the batch; the old
# size_average=False argument is deprecated in favor of reduction='sum'.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(100):
    # Forward pass over the whole dataset (full-batch gradient descent).
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print('epoch:', epoch, '\tloss =', loss.item())

    # Gradients accumulate across .backward() calls, so they must be
    # cleared before computing this step's gradients.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()    # update values of w and b
    print('w =', model.linear.weight.item(), '\tb =', model.linear.bias.item())


# Inference on an unseen input; no gradient graph is needed here.
x_test = torch.tensor([[4.0]])
with torch.no_grad():
    y_test = model(x_test)
print('y_pred = ', y_test.data)

结果如下:

epoch: 0 loss = 39.28617477416992
w = 0.9208120107650757 b = -0.09137517213821411
epoch: 1 loss = 17.5134334564209
w = 1.2339496612548828 b = 0.04360988736152649
epoch: 2 loss = 7.820481300354004
w = 1.4432106018066406 b = 0.1329193264245987
epoch: 3 loss = 3.505107879638672
w = 1.5831613540649414 b = 0.1917589008808136
epoch: 4 loss = 1.5836862325668335
w = 1.6768651008605957 b = 0.2302739918231964
epoch: 5 loss = 0.7279895544052124
w = 1.7397099733352661 b = 0.2552337348461151
epoch: 6 loss = 0.3467276096343994
w = 1.7819631099700928 b = 0.271154522895813
epoch: 7 loss = 0.17667557299137115
w = 1.8104748725891113 b = 0.2810496985912323
epoch: 8 loss = 0.10065186023712158
w = 1.8298159837722778 b = 0.286929726600647
epoch: 9 loss = 0.06649214774370193
w = 1.8430359363555908 b = 0.2901360094547272
epoch: 10 loss = 0.050973616540431976
w = 1.8521695137023926 b = 0.2915635406970978
epoch: 11 loss = 0.04375807195901871
w = 1.858574390411377 b = 0.2918093800544739
epoch: 12 loss = 0.04024308919906616
w = 1.8631564378738403 b = 0.2912718951702118
epoch: 13 loss = 0.03837998956441879
w = 1.8665200471878052 b = 0.2902168035507202
epoch: 14 loss = 0.0372563898563385
w = 1.8690683841705322 b = 0.28882139921188354
epoch: 15 loss = 0.03646640107035637
w = 1.8710706233978271 b = 0.28720390796661377
epoch: 16 loss = 0.0358290858566761
w = 1.872706413269043 b = 0.2854432165622711
epoch: 17 loss = 0.03526361659169197
w = 1.8740954399108887 b = 0.2835918366909027
epoch: 18 loss = 0.034734323620796204
w = 1.8753176927566528 b = 0.28168487548828125
epoch: 19 loss = 0.03422508016228676
w = 1.8764265775680542 b = 0.2797456681728363
epoch: 20 loss = 0.033728890120983124
w = 1.877457618713379 b = 0.27778974175453186
epoch: 21 loss = 0.03324216231703758
w = 1.8784347772598267 b = 0.27582743763923645
epoch: 22 loss = 0.032763510942459106
w = 1.8793737888336182 b = 0.27386561036109924
epoch: 23 loss = 0.0322922021150589
w = 1.8802852630615234 b = 0.27190881967544556
epoch: 24 loss = 0.03182794898748398
w = 1.8811763525009155 b = 0.26996007561683655
epoch: 25 loss = 0.031370509415864944
w = 1.8820518255233765 b = 0.2680213153362274
epoch: 26 loss = 0.030919615179300308
w = 1.8829147815704346 b = 0.26609382033348083
epoch: 27 loss = 0.03047514520585537
w = 1.8837673664093018 b = 0.26417842507362366
epoch: 28 loss = 0.030037226155400276
w = 1.8846111297607422 b = 0.2622756361961365
epoch: 29 loss = 0.029605546966195107
w = 1.8854469060897827 b = 0.26038575172424316
epoch: 30 loss = 0.02918008156120777
w = 1.8862755298614502 b = 0.25850898027420044
epoch: 31 loss = 0.02876061201095581
w = 1.8870972394943237 b = 0.2566453814506531
epoch: 32 loss = 0.02834739163517952
w = 1.887912631034851 b = 0.25479498505592346
epoch: 33 loss = 0.027939941734075546
w = 1.8887217044830322 b = 0.2529577612876892
epoch: 34 loss = 0.02753838524222374
w = 1.8895246982574463 b = 0.25113368034362793
epoch: 35 loss = 0.027142547070980072
w = 1.8903217315673828 b = 0.24932269752025604
epoch: 36 loss = 0.026752525940537453
w = 1.8911129236221313 b = 0.24752472341060638
epoch: 37 loss = 0.026368016377091408
w = 1.8918983936309814 b = 0.24573969841003418
epoch: 38 loss = 0.02598903886973858
w = 1.8926780223846436 b = 0.24396750330924988
epoch: 39 loss = 0.025615666061639786
w = 1.8934520483016968 b = 0.2422080934047699
epoch: 40 loss = 0.025247540324926376
w = 1.8942204713821411 b = 0.24046136438846588
epoch: 41 loss = 0.024884670972824097
w = 1.8949834108352661 b = 0.23872721195220947
epoch: 42 loss = 0.024527017027139664
w = 1.8957407474517822 b = 0.2370055615901947
epoch: 43 loss = 0.024174446240067482
w = 1.896492600440979 b = 0.2352963238954544
epoch: 44 loss = 0.02382710948586464
w = 1.897239089012146 b = 0.23359942436218262
epoch: 45 loss = 0.023484615609049797
w = 1.8979802131652832 b = 0.23191477358341217
epoch: 46 loss = 0.023147130385041237
w = 1.8987159729003906 b = 0.23024225234985352
epoch: 47 loss = 0.022814491763710976
w = 1.8994464874267578 b = 0.22858180105686188
epoch: 48 loss = 0.022486625239253044
w = 1.9001716375350952 b = 0.2269333153963089
epoch: 49 loss = 0.02216344140470028
w = 1.9008915424346924 b = 0.2252967208623886
epoch: 50 loss = 0.021844908595085144
w = 1.9016063213348389 b = 0.22367194294929504
epoch: 51 loss = 0.021531004458665848
w = 1.9023159742355347 b = 0.22205887734889984
epoch: 52 loss = 0.0212215818464756
w = 1.9030205011367798 b = 0.22045743465423584
epoch: 53 loss = 0.02091650851070881
w = 1.9037199020385742 b = 0.21886752545833588
epoch: 54 loss = 0.020615965127944946
w = 1.9044142961502075 b = 0.21728909015655518
epoch: 55 loss = 0.020319659262895584
w = 1.9051035642623901 b = 0.21572202444076538
epoch: 56 loss = 0.020027615129947662
w = 1.9057879447937012 b = 0.2141662836074829
epoch: 57 loss = 0.0197397843003273
w = 1.9064674377441406 b = 0.2126217633485794
epoch: 58 loss = 0.01945609599351883
w = 1.907141923904419 b = 0.2110883742570877
epoch: 59 loss = 0.019176488742232323
w = 1.9078116416931152 b = 0.20956604182720184
epoch: 60 loss = 0.018900904804468155
w = 1.90847647190094 b = 0.20805467665195465
epoch: 61 loss = 0.01862926036119461
w = 1.9091365337371826 b = 0.20655421912670135
epoch: 62 loss = 0.01836152747273445
w = 1.9097918272018433 b = 0.2050645798444748
epoch: 63 loss = 0.018097639083862305
w = 1.9104423522949219 b = 0.203585684299469
epoch: 64 loss = 0.017837535589933395
w = 1.911088228225708 b = 0.202117457985878
epoch: 65 loss = 0.017581157386302948
w = 1.9117294549942017 b = 0.2006598263978958
epoch: 66 loss = 0.017328530550003052
w = 1.9123660326004028 b = 0.1992127001285553
epoch: 67 loss = 0.01707947999238968
w = 1.9129979610443115 b = 0.19777601957321167
epoch: 68 loss = 0.016834083944559097
w = 1.9136254787445068 b = 0.19634971022605896
epoch: 69 loss = 0.016592107713222504
w = 1.9142483472824097 b = 0.19493368268013
epoch: 70 loss = 0.01635364629328251
w = 1.9148668050765991 b = 0.19352786242961884
epoch: 71 loss = 0.016118638217449188
w = 1.9154807329177856 b = 0.19213217496871948
epoch: 72 loss = 0.015886958688497543
w = 1.9160902500152588 b = 0.19074656069278717
epoch: 73 loss = 0.01565859466791153
w = 1.9166953563690186 b = 0.18937093019485474
epoch: 74 loss = 0.015433606691658497
w = 1.9172961711883545 b = 0.1880052238702774
epoch: 75 loss = 0.015211805701255798
w = 1.917892575263977 b = 0.1866493672132492
epoch: 76 loss = 0.014993196353316307
w = 1.9184846878051758 b = 0.18530328571796417
epoch: 77 loss = 0.014777696691453457
w = 1.9190726280212402 b = 0.18396693468093872
epoch: 78 loss = 0.014565300196409225
w = 1.9196562767028809 b = 0.1826401948928833
epoch: 79 loss = 0.014355980791151524
w = 1.9202357530593872 b = 0.18132303655147552
epoch: 80 loss = 0.014149629510939121
w = 1.9208109378814697 b = 0.18001537024974823
epoch: 81 loss = 0.013946276158094406
w = 1.921381950378418 b = 0.17871712148189545
epoch: 82 loss = 0.013745900243520737
w = 1.9219489097595215 b = 0.1774282604455948
epoch: 83 loss = 0.013548349030315876
w = 1.9225118160247803 b = 0.1761486977338791
epoch: 84 loss = 0.01335364393889904
w = 1.9230706691741943 b = 0.17487835884094238
epoch: 85 loss = 0.013161790557205677
w = 1.9236254692077637 b = 0.1736171841621399
epoch: 86 loss = 0.012972550466656685
w = 1.9241762161254883 b = 0.17236508429050446
epoch: 87 loss = 0.012786151841282845
w = 1.9247230291366577 b = 0.17112202942371368
epoch: 88 loss = 0.012602370232343674
w = 1.925265908241272 b = 0.1698879450559616
epoch: 89 loss = 0.012421232648193836
w = 1.925804853439331 b = 0.16866275668144226
epoch: 90 loss = 0.01224280335009098
w = 1.9263399839401245 b = 0.16744641959667206
epoch: 91 loss = 0.012066814117133617
w = 1.9268711805343628 b = 0.16623882949352264
epoch: 92 loss = 0.011893432587385178
w = 1.9273985624313354 b = 0.16503995656967163
epoch: 93 loss = 0.01172245480120182
w = 1.9279221296310425 b = 0.16384974122047424
epoch: 94 loss = 0.011553991585969925
w = 1.9284420013427734 b = 0.1626681089401245
epoch: 95 loss = 0.011387962847948074
w = 1.9289580583572388 b = 0.16149497032165527
epoch: 96 loss = 0.011224319227039814
w = 1.929470419883728 b = 0.16033031046390533
epoch: 97 loss = 0.011063016951084137
w = 1.9299790859222412 b = 0.15917403995990753
epoch: 98 loss = 0.010903997346758842
w = 1.9304840564727783 b = 0.15802611410617828
epoch: 99 loss = 0.010747263208031654
w = 1.9309853315353394 b = 0.15688644349575043
y_pred =  tensor([[7.8808]])
Process finished with exit code 0

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值