1.pytorch——TwoLayerNet

什么是PyTorch?

PyTorch是一个基于Python的科学计算库,它有以下特点:

  • 类似于NumPy,但是它可以使用GPU
  • 可以用它定义深度学习模型,可以灵活地进行深度学习模型的训练和使用

Torch

Torch类似与NumPy的ndarray,唯一的区别是Torch可以在GPU上加速运算。

import torch

构造一个未初始化的5x3矩阵:

x = torch.empty(5,3)
x
tensor([[1.9349e-19, 4.5445e+30, 4.7429e+30],
        [7.1354e+31, 7.1118e-04, 1.7444e+28],
        [7.3909e+22, 4.5828e+30, 3.2483e+33],
        [1.9690e-19, 6.8589e+22, 1.3340e+31],
        [1.1708e-19, 7.2128e+22, 9.2216e+29]])

构建一个随机初始化的矩阵:

x = torch.rand(5,3)
x
tensor([[0.6541, 0.2212, 0.4101],
        [0.0580, 0.9050, 0.7793],
        [0.9319, 0.8730, 0.2262],
        [0.4927, 0.4342, 0.1903],
        [0.0228, 0.4848, 0.7103]])

构建一个全部为0,类型为long的矩阵:

x = torch.zeros(5,3,dtype=torch.long)
x
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])
x = torch.zeros(5,3).long()
x.dtype
torch.int64

从数据直接直接构建tensor:

x = torch.tensor([5.5,3])
x
tensor([5.5000, 3.0000])

也可以从一个已有的tensor构建一个tensor。这些方法会重用原来tensor的特征,例如,数据类型,除非提供新的数据。

x = x.new_ones(5,3,dtype=torch.double)
x
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)
x = torch.randn_like(x, dtype=torch.float)
x
tensor([[ 1.1003,  1.5814,  0.5290],
        [ 1.1359, -1.6799,  0.5144],
        [ 1.3869, -0.9930,  2.4564],
        [ 1.1065,  0.4230, -1.4289],
        [-0.2244, -0.6768, -0.2486]])

得到tensor的形状:

x.shape
torch.Size([5, 3])
注意

``torch.Size`` 返回的是一个tuple

Operations

有很多种tensor运算。我们先介绍加法运算。

y = torch.rand(5,3)
y
tensor([[0.0396, 0.5141, 0.2989],
        [0.0204, 0.1849, 0.1429],
        [0.0464, 0.3222, 0.5240],
        [0.1020, 0.3913, 0.4771],
        [0.8843, 0.6870, 0.9798]])
x
tensor([[ 1.1003,  1.5814,  0.5290],
        [ 1.1359, -1.6799,  0.5144],
        [ 1.3869, -0.9930,  2.4564],
        [ 1.1065,  0.4230, -1.4289],
        [-0.2244, -0.6768, -0.2486]])
x + y
tensor([[ 1.1399,  2.0955,  0.8278],
        [ 1.1563, -1.4949,  0.6573],
        [ 1.4333, -0.6708,  2.9804],
        [ 1.2085,  0.8143, -0.9518],
        [ 0.6599,  0.0102,  0.7312]])

另一种着加法的写法

torch.add(x, y)
tensor([[ 1.1399,  2.0955,  0.8278],
        [ 1.1563, -1.4949,  0.6573],
        [ 1.4333, -0.6708,  2.9804],
        [ 1.2085,  0.8143, -0.9518],
        [ 0.6599,  0.0102,  0.7312]])

加法:把输出作为一个变量

result = torch.empty(5,3)
# torch.add(x, y, out=result)
result = x + y
result
tensor([[ 2.2402,  3.6768,  1.3568],
        [ 2.2921, -3.1748,  1.1717],
        [ 2.8203, -1.6638,  5.4367],
        [ 2.3151,  1.2373, -2.3807],
        [ 0.4355, -0.6666,  0.4826]])

in-place加法

y.add_(x)
y
tensor([[ 1.1399,  2.0955,  0.8278],
        [ 1.1563, -1.4949,  0.6573],
        [ 1.4333, -0.6708,  2.9804],
        [ 1.2085,  0.8143, -0.9518],
        [ 0.6599,  0.0102,  0.7312]])
注意

任何in-place的运算都会以``_``结尾。 举例来说:``x.copy_(y)``, ``x.t_()``, 会改变 ``x``。

各种类似NumPy的indexing都可以在PyTorch tensor上面使用。

x[1:, 1:]
tensor([[-1.6799,  0.5144],
        [-0.9930,  2.4564],
        [ 0.4230, -1.4289],
        [-0.6768, -0.2486]])

Resizing: 如果你希望resize/reshape一个tensor,可以使用torch.view

x = torch.randn(4,4)
y = x.view(16)
z = x.view(-1,8)
z

tensor([[ 0.6418,  0.6175, -0.4498,  0.3640,  1.1340,  0.4898, -1.6624, -0.1805],
        [ 0.1124, -0.5237,  0.6692, -0.2532,  0.5561,  0.8064,  2.0955,  0.2386]])

如果你有一个只有一个元素的tensor,使用.item()方法可以把里面的value变成Python数值。

x = torch.randn(1)
x
tensor([-1.1493])
x.item()
-1.1493233442306519
z.transpose(1,0)
tensor([[-0.5683, -0.2612],
        [ 1.3885, -0.4682],
        [-2.0829, -1.0596],
        [-0.7613,  0.7447],
        [-1.9115,  0.7603],
        [ 0.3732, -0.4281],
        [-0.2055,  0.5495],
        [-1.2300,  0.1025]])

更多阅读

各种Tensor operations, 包括transposing, indexing, slicing,
mathematical operations, linear algebra, random numbers在
<https://pytorch.org/docs/torch>.

Numpy和Tensor之间的转化

在Torch Tensor和NumPy array之间相互转化非常容易。

Torch Tensor和NumPy array会共享内存,所以改变其中一项也会改变另一项。

把Torch Tensor转变成NumPy Array

a = torch.ones(5)
a
tensor([1., 1., 1., 1., 1.])
b = a.numpy()
b
array([1., 1., 1., 1., 1.], dtype=float32)

共享空间的,改变numpy array里面的值。

b[1] = 2
b
array([1., 2., 1., 1., 1.], dtype=float32)
a
tensor([1., 2., 1., 1., 1.])

把NumPy ndarray转成Torch Tensor

import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
[2. 2. 2. 2. 2.]
b
tensor([2., 2., 2., 2., 2.], dtype=torch.float64)

所有CPU上的Tensor都支持转成numpy或者从numpy转成Tensor。

CUDA Tensors

使用.to方法,Tensor可以被移动到别的device上。

if torch.cuda.is_available():
    device = torch.device("cuda")
    y = torch.ones_like(x, device=device)
    x = x.to(device)
    z = x + y
    print(z)
    print(z.to("cpu", torch.double))
    
tensor([[ 1.6418,  1.6175,  0.5502,  1.3640],
        [ 2.1340,  1.4898, -0.6624,  0.8195],
        [ 1.1124,  0.4763,  1.6692,  0.7468],
        [ 1.5561,  1.8064,  3.0955,  1.2386]], device='cuda:0')
tensor([[ 1.6418,  1.6175,  0.5502,  1.3640],
        [ 2.1340,  1.4898, -0.6624,  0.8195],
        [ 1.1124,  0.4763,  1.6692,  0.7468],
        [ 1.5561,  1.8064,  3.0955,  1.2386]], dtype=torch.float64)
y.to("cpu").data.numpy()
y.cpu().data.numpy()
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)
model = model.cuda()
model
---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

<ipython-input-51-46b061b39714> in <module>
----> 1 model = model.cuda()
      2 model


NameError: name 'model' is not defined

用numpy实现两层神经网络

一个全连接ReLU神经网络,一个隐藏层,没有bias。用来从x预测y,使用L2 Loss。

  • h = W 1 X h = W_1X h=W1X
  • a = m a x ( 0 , h ) a = max(0, h) a=max(0,h)
  • y h a t = W 2 a y_{hat} = W_2a yhat=W2a

这一实现完全使用numpy来计算前向神经网络,loss,和反向传播。

  • forward pass
  • loss
  • backward pass

numpy ndarray是一个普通的n维array。它不知道任何关于深度学习或者梯度(gradient)的知识,也不知道计算图(computation graph),只是一种用来计算数学运算的数据结构。

N, D_in , H , D_out =64, 1000, 100, 10

#随即创建一些训练数据
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    #forward pass
    h = x.dot(w1) # N * H
    h_relu = np.maximum(h, 0) # N * H
    y_pred = h_relu.dot(w2) # N * D_out
    
    #compute loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)
    
    #backward pass
    #compute the gradient
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h<0] = 0
    grad_w1 = x.T.dot(grad_h)
    
    
    #update weights of w1 and w2
    
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
0 25334215.287635807
1 20374858.530391537
2 19154101.065471698
3 19263551.278936088
4 19040296.233091436
5 17497664.237746056
6 14454281.222012255
7 10739506.65697521
8 7289156.248603398
9 4694186.699477899
10 2967377.5432831226
11 1906598.6969797632
12 1272580.0355528235
13 894000.0129839063
14 661958.8451412251
15 513929.85777407093
16 414565.65026462806
17 344251.30231047235
18 291915.349179358
19 251272.61332286143
20 218689.38517518342
21 191872.31077442865
22 169362.80764838745
23 150215.5373069444
24 133741.25360490504
25 119476.14944955244
26 107058.39578879708
27 96181.16752371681
28 86613.81510498092
29 78166.28884161207
30 70693.74561783693
31 64060.97710734237
32 58150.32434206815
33 52871.67172027933
34 48146.78406894338
35 43905.06551627368
36 40094.12110174494
37 36664.65013049503
38 33571.74288191136
39 30777.79875115835
40 28250.69532578625
41 25962.43073056632
42 23887.88342833115
43 22004.995571225365
44 20292.761371698038
45 18733.866168958553
46 17313.46061632143
47 16017.043635489896
48 14831.764806976491
49 13746.90374076042
50 12752.355161261938
51 11839.74445628274
52 11001.589083176736
53 10230.7931073741
54 9521.456005977543
55 8868.059024709655
56 8265.353465171358
57 7708.9968200856965
58 7194.965305001253
59 6719.88398089461
60 6280.014637857247
61 5873.129862867207
62 5496.1852484231185
63 5146.449936644871
64 4822.037217732374
65 4520.565905319106
66 4240.231010609592
67 3979.3998353274173
68 3736.5178442950655
69 3510.205564729258
70 3299.3723255959476
71 3102.6808444579488
72 2919.0242293898646
73 2747.502199084057
74 2587.1646097547086
75 2437.1843961922536
76 2296.8253142022004
77 2165.390833302422
78 2042.505303111871
79 1927.4749099049677
80 1819.5763639515817
81 1718.352923188214
82 1623.305474358471
83 1534.068552526478
84 1450.2114687374653
85 1371.3957764145089
86 1297.246676100891
87 1227.549598101267
88 1162.0820189477251
89 1100.455249012592
90 1042.396519545738
91 987.6797432452919
92 936.0742187050621
93 887.3843860600244
94 841.4477011631107
95 798.0888454513869
96 757.1349524674558
97 718.4542064444188
98 681.9041221982977
99 647.361042720042
100 614.6870231687881
101 583.7802862791442
102 554.5424255030925
103 526.8644268120976
104 500.66751341991204
105 475.85869622059784
106 452.36508877166375
107 430.10280640965664
108 409.0167786007788
109 389.0186491466657
110 370.0582005223282
111 352.077347719442
112 335.0209419968366
113 318.8370269875089
114 303.4810577624329
115 288.9089112188498
116 275.0702763381073
117 261.93280421734073
118 249.45379278036242
119 237.59770777591865
120 226.33454331032152
121 215.63080034005475
122 205.45876226415658
123 195.78849916183148
124 186.59920797306393
125 177.85998667985905
126 169.55186901635133
127 161.6436700363447
128 154.12076088592244
129 146.9621078219012
130 140.15104955748464
131 133.66782801380995
132 127.49674846132552
133 121.62244347239431
134 116.0303405166996
135 110.70587289034268
136 105.6339388788144
137 100.80261419185247
138 96.2007333550909
139 91.81610334699522
140 87.637806791804
141 83.65640393465992
142 79.86372898222554
143 76.2485209088295
144 72.80314714392738
145 69.51708876891166
146 66.38407006492528
147 63.396514570521504
148 60.54755446492679
149 57.83076301257011
150 55.23904791601904
151 52.767356398979004
152 50.40979919893606
153 48.160677518772125
154 46.01470873017509
155 43.96655200833007
156 42.012503712950114
157 40.14704239244712
158 38.36803093694154
159 36.66907807188697
160 35.047325312613346
161 33.49877691503397
162 32.02027749102027
163 30.60868880721456
164 29.26129685151092
165 27.974691607167713
166 26.74596406381223
167 25.5723248213664
168 24.451276263231552
169 23.380343159032208
170 22.35745285666475
171 21.380202482632278
172 20.44713699488127
173 19.556006429926953
174 18.704465801697285
175 17.890743833340114
176 17.11297673816101
177 16.369754580709795
178 15.659362564112072
179 14.980507348363846
180 14.331595087247905
181 13.711428279017573
182 13.118542690030957
183 12.551750103002467
184 12.01001844935579
185 11.492115809997001
186 10.996927381783173
187 10.523731880888228
188 10.070988744209336
189 9.63788516493964
190 9.223781012420899
191 8.827778574335346
192 8.449043172899524
193 8.086866684034884
194 7.740526640765878
195 7.4093664511673225
196 7.092460933701248
197 6.789323524882154
198 6.499335484406904
199 6.221873382357317
200 5.956678490554909
201 5.702771118847534
202 5.459799126526686
203 5.227341233970224
204 5.004978720435554
205 4.792261435483557
206 4.588637708750726
207 4.393767343322392
208 4.207315596093808
209 4.028866636266113
210 3.8580624013445286
211 3.6946174796161793
212 3.5381977948322265
213 3.388628044152935
214 3.2453254036487804
215 3.108177116148267
216 2.9768802053857755
217 2.8511979626340747
218 2.7308894745331367
219 2.6157180750843265
220 2.505447622574158
221 2.3998758740385338
222 2.2988235037189426
223 2.2020783005939286
224 2.1094250086315345
225 2.020795880778742
226 1.9359125554032537
227 1.8545913583539053
228 1.7767230511111665
229 1.7021794684827787
230 1.6307750165625343
231 1.5624125374421238
232 1.496933791586528
233 1.4342299107952625
234 1.3741832897358774
235 1.3166701094275295
236 1.2616093438118203
237 1.2088893167299952
238 1.1583677956055443
239 1.1099717063870067
240 1.06363375874979
241 1.0192290768117487
242 0.9766972428617544
243 0.9359655062235948
244 0.896946596515714
245 0.8595614977713766
246 0.8237640693325716
247 0.7894665578082823
248 0.7566069404031577
249 0.7251473096927634
250 0.6949900244461539
251 0.6660982640975344
252 0.6384134623778901
253 0.6118854314667179
254 0.5864715012376099
255 0.5621237195751553
256 0.5387973936169649
257 0.5164492130995048
258 0.4950304601107148
259 0.4745048777091669
260 0.4548517138164056
261 0.4360108987974487
262 0.41795529679841203
263 0.40064930542150146
264 0.3840651840432217
265 0.3681741601089238
266 0.35294795814725355
267 0.3383586387542944
268 0.3243737257828931
269 0.3109685201131789
270 0.29812203400735643
271 0.28581729529120814
272 0.27402028980935
273 0.2627120717125866
274 0.2518716080701229
275 0.2414825249503969
276 0.23152434315028145
277 0.22197985994093244
278 0.21283371662507466
279 0.20406477117942012
280 0.19565853924141252
281 0.1876038317723699
282 0.17988707487745573
283 0.1724850219203954
284 0.16538763550345265
285 0.15858574272733406
286 0.15206382812592006
287 0.14581138343379163
288 0.1398211360116848
289 0.13407721694570685
290 0.12856899487487736
291 0.1232890039417609
292 0.11822861057825637
293 0.11337745584346702
294 0.10872585818801292
295 0.10426529162831612
296 0.09998831941461715
297 0.09588778786809343
298 0.0919563479440833
299 0.08818826277368867
300 0.0845740561187151
301 0.08110950808568776
302 0.07778737321052365
303 0.0746040562320697
304 0.07155009442247652
305 0.0686217812218026
306 0.06581370329679459
307 0.06312087123092486
308 0.060538771913194234
309 0.058063362291236195
310 0.0556898089966511
311 0.05341296056894223
312 0.051229998008646846
313 0.04913823912933135
314 0.04713074565826449
315 0.045205659738035926
316 0.04336027165138762
317 0.041589739810210205
318 0.039891703371243196
319 0.0382637503195946
320 0.03670271927720279
321 0.03520534281000358
322 0.03376921814784341
323 0.03239316262092825
324 0.03107268665937482
325 0.029805831219471234
326 0.0285911747685064
327 0.027426457229259342
328 0.02630902754660684
329 0.02523731288561953
330 0.02420943257621564
331 0.023223789373129283
332 0.02227841991357767
333 0.02137213826815304
334 0.020502261593485886
335 0.019667948636840227
336 0.01886770463433915
337 0.018100148170853404
338 0.017364059467052127
339 0.01665794531451906
340 0.015980745797478504
341 0.015331104817876067
342 0.01470834083643782
343 0.014110854808525872
344 0.013537526587532987
345 0.012987522212931316
346 0.012459986342134726
347 0.011953858609721152
348 0.011468496336575587
349 0.011002995464967424
350 0.010556302237280091
351 0.010127815877428937
352 0.009717180350438195
353 0.009322978705188698
354 0.008944682136274952
355 0.008581869400019542
356 0.00823386117592606
357 0.007899937327970195
358 0.007579613746223686
359 0.0072723449497759905
360 0.0069776103346947515
361 0.0066949305132009695
362 0.00642371218519191
363 0.006163431577827152
364 0.005913800196430723
365 0.005674282407486013
366 0.005444427944936273
367 0.005223943446347
368 0.005012427217802753
369 0.004809467831321666
370 0.004614823062870745
371 0.004428190289122751
372 0.004248980286427807
373 0.004077040240836071
374 0.003912107956869145
375 0.003753905814461201
376 0.003602045390968732
377 0.0034563612830528535
378 0.0033165935885892963
379 0.0031824888668309003
380 0.0030539079823533044
381 0.0029304405792940505
382 0.0028120129591956387
383 0.002698355985630814
384 0.002589314753388437
385 0.0024846831437707584
386 0.002384319356652268
387 0.0022879938270884363
388 0.002195584658723335
389 0.002106976166248489
390 0.0020219009196118373
391 0.0019402609513182963
392 0.0018619191512416616
393 0.0017867618697377814
394 0.0017146399382584516
395 0.0016454302797030704
396 0.0015790217470142889
397 0.0015153127866680803
398 0.001454233871695838
399 0.0013955589316798833
400 0.0013392636424871228
401 0.0012852397298847525
402 0.001233401763519252
403 0.0011836581487782486
404 0.0011359311808256198
405 0.001090149637126118
406 0.001046205401478894
407 0.0010040557494509772
408 0.0009635978870914478
409 0.0009247647044230064
410 0.000887494988920557
411 0.0008517289752057796
412 0.0008174130268166096
413 0.000784476596845799
414 0.0007528717966589211
415 0.0007225533729159693
416 0.000693472546986185
417 0.0006655394315200123
418 0.0006387409048014005
419 0.0006130254124941684
420 0.0005883457567994732
421 0.000564660382725073
422 0.0005419298458532028
423 0.0005201136067256363
424 0.000499188194677741
425 0.0004791040300932137
426 0.0004598272525339332
427 0.0004413308778895235
428 0.0004235752254291101
429 0.000406533409697829
430 0.00039017815354769336
431 0.00037448638364704273
432 0.00035942530194821266
433 0.00034498053221839256
434 0.0003311058115281321
435 0.0003177915421709333
436 0.00030501336631993714
437 0.00029275072124081397
438 0.0002809848089917374
439 0.0002696905184368078
440 0.00025885005281103796
441 0.0002484496260651409
442 0.00023846970720111386
443 0.00022888740350260136
444 0.00021968907893003906
445 0.00021086168963789113
446 0.0002023913659737589
447 0.0001942617426980819
448 0.00018645841193800977
449 0.00017897071658953535
450 0.00017178717232827704
451 0.0001648874852404425
452 0.00015826594535648398
453 0.00015191149884200494
454 0.00014581353294725518
455 0.00013995921752531823
456 0.00013434034659954317
457 0.0001289471134513853
458 0.00012377364303487812
459 0.00011880636343506332
460 0.00011403835476002077
461 0.00010946214356958942
462 0.00010506926719671814
463 0.00010085288260742702
464 9.680597411858776e-05
465 9.292254941511825e-05
466 8.919686492947959e-05
467 8.561968099665454e-05
468 8.218521402125142e-05
469 7.888894391076981e-05
470 7.572466482028333e-05
471 7.268721556866596e-05
472 6.977307152461972e-05
473 6.697503860908728e-05
474 6.429013277390669e-05
475 6.171340499052229e-05
476 5.923935484865936e-05
477 5.6864834685891744e-05
478 5.4585026232469776e-05
479 5.2396789700975556e-05
480 5.029655757472796e-05
481 4.828062177178192e-05
482 4.6346139268741796e-05
483 4.448996177321761e-05
484 4.2707215701204954e-05
485 4.099645252965485e-05
486 3.9353763534080844e-05
487 3.777692152287511e-05
488 3.626353812892362e-05
489 3.481092277564877e-05
490 3.341684177193813e-05
491 3.207892006099393e-05
492 3.079401388119994e-05
493 2.9560543931996793e-05
494 2.837685769264627e-05
495 2.724061228910188e-05
496 2.6149668297050352e-05
497 2.510248786347572e-05
498 2.409741010054895e-05
499 2.3133157983955057e-05
N, D_in, H, D_out = 64, 1000, 100, 10

# 随机创建一些训练数据
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for it in range(500):
    # Forward pass
    h = x.dot(w1) # N * H
    h_relu = np.maximum(h, 0) # N * H
    y_pred = h_relu.dot(w2) # N * D_out
    
    # compute loss
    loss = np.square(y_pred - y).sum()
    print(it, loss)
    
    # Backward pass
    # compute the gradient
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h<0] = 0
    grad_w1 = x.T.dot(grad_h)
    
    # update weights of w1 and w2
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
0 31597565.807145506
1 28079489.807411373
2 30286103.139230084
3 32938103.40633983
4 31060312.953437418
5 23591961.5869837
6 14104364.598467454
7 7197577.685750365
8 3543090.2497073784
9 1907582.096372563
10 1187431.494283868
11 845038.7177934409
12 657544.9834643572
13 538540.5455826041
14 453339.8572324506
15 387548.24541849294
16 334570.40055113274
17 290893.55421129346
18 254315.84619625227
19 223394.5451185491
20 197085.5431816586
21 174540.98269047114
22 155167.78421306403
23 138391.0928132686
24 123796.47606366324
25 111032.14761959936
26 99844.14276953277
27 89989.7336206896
28 81283.90690007518
29 73566.9720319214
30 66706.31111640025
31 60587.472024584335
32 55118.330152509254
33 50218.25779192675
34 45826.492954582056
35 41875.52631039062
36 38314.32724728195
37 35097.074012477475
38 32185.457519968124
39 29550.381743841663
40 27159.697622440388
41 24987.219865371393
42 23010.015641716767
43 21207.31209197924
44 19561.708775736704
45 18057.610663428164
46 16681.445966191917
47 15421.002824830968
48 14265.03723138246
49 13204.413056592552
50 12230.152704431828
51 11334.172249338571
52 10510.268409824534
53 9751.638747394154
54 9052.382271594579
55 8404.80902496801
56 7807.724665527439
57 7256.137171418995
58 6746.643323543885
59 6278.073964758552
60 5844.415220830577
61 5443.275746045911
62 5071.8440722835585
63 4727.754156400912
64 4408.718364551612
65 4112.8945903226595
66 3838.158138747932
67 3583.108135677134
68 3346.049053541465
69 3125.7419391915187
70 2920.781163223124
71 2730.16855154828
72 2552.9562062928844
73 2387.9197681552478
74 2234.2227956971974
75 2091.0181360303523
76 1957.5372778287879
77 1833.014662671584
78 1716.9430962994284
79 1608.6091190391367
80 1507.4271557342722
81 1413.0225572402596
82 1324.8829829115525
83 1242.4587019581013
84 1165.4083659970038
85 1093.3747437183165
86 1026.0540296880386
87 963.0716992094443
88 904.100583400209
89 848.929946568556
90 797.3215970994083
91 748.9641394730661
92 703.6509566710063
93 661.1934137494043
94 621.4452258036113
95 584.1626277231894
96 549.2113852157568
97 516.4275790098668
98 485.66996837552006
99 456.8102214641775
100 429.7313781582611
101 404.32626421882264
102 380.4880276813551
103 358.0965489158332
104 337.0938082473366
105 317.35049282230085
106 298.81827196450575
107 281.41389381891895
108 265.06177691775133
109 249.69955797456583
110 235.26111240676062
111 221.6867026714699
112 208.9141030544568
113 196.9048193627038
114 185.61605609732348
115 174.99669790748334
116 164.99865686346237
117 155.58910720376468
118 146.74340569350755
119 138.41264064917692
120 130.56702168032868
121 123.16860867120181
122 116.20489583021345
123 109.64842556496555
124 103.46934646200087
125 97.65163544280708
126 92.17000854089764
127 87.00786546938292
128 82.13893750059043
129 77.55014558047262
130 73.22513768095547
131 69.15023567888296
132 65.30762054250481
133 61.68395578851212
134 58.26591921614971
135 55.0411147197825
136 52.00298877484919
137 49.134613079066376
138 46.42796035392694
139 43.874034530150276
140 41.465657176586205
141 39.19272755503084
142 37.0466209165635
143 35.020414780290295
144 33.106894349783396
145 31.302130782815656
146 29.596825036623674
147 27.987016151882997
148 26.466267728345514
149 25.029839129918095
150 23.67422730062444
151 22.392583585687476
152 21.181706581352515
153 20.03738142267923
154 18.956815018139288
155 17.936102580526693
156 16.970811358510225
157 16.05833414716448
158 15.19618460875444
159 14.381217470125671
160 13.610990598953908
161 12.88234679323368
162 12.193592209502157
163 11.542164372236021
164 10.926553038800062
165 10.344368737013935
166 9.793433582468172
167 9.272394760593668
168 8.779376198753301
169 8.313533210481083
170 7.872528831246813
171 7.455235607679949
172 7.060316598510679
173 6.6867721614966005
174 6.333422394310859
175 5.999098644669621
176 5.682446803860494
177 5.382956709710728
178 5.099404136947674
179 4.8311023634757015
180 4.577096717045771
181 4.3365399841061905
182 4.108819233704848
183 3.89315507377259
184 3.68925164639867
185 3.4959895376765298
186 3.3129915538331796
187 3.1396033858447794
188 2.975490960171686
189 2.8201009812974402
190 2.6729452915275496
191 2.5334751684057046
192 2.401467773706234
193 2.2763626955753717
194 2.157888955317783
195 2.0456384254644644
196 1.939274612360872
197 1.8384913218606285
198 1.7430432680946177
199 1.6526026559477704
200 1.566959709939377
201 1.4857279776722736
202 1.4087779921903971
203 1.3358387392110767
204 1.2667130555957786
205 1.2012530763870417
206 1.1391708894083343
207 1.0803292476907056
208 1.0245588612411622
209 0.9716853960695578
210 0.9216001289958042
211 0.8741177312141296
212 0.8290772444743862
213 0.7863821620779994
214 0.7459055889255748
215 0.707554992966772
216 0.6711944436231445
217 0.6367046789490501
218 0.6039954202261099
219 0.5729884021329463
220 0.5435880218709657
221 0.5157212857988482
222 0.4892730686439616
223 0.4642017304458196
224 0.44042454248466983
225 0.4178739970026196
226 0.3964943131001646
227 0.37620976269576895
228 0.35696843680115864
229 0.3387261365997246
230 0.32141459951807316
231 0.30500245845185026
232 0.2894409444968663
233 0.2746695828103276
234 0.26065286604645044
235 0.24735822622896284
236 0.2347463061674509
237 0.22279141696686938
238 0.21144364672950422
239 0.2006747899963951
240 0.19046058952465703
241 0.1807674202555531
242 0.17157485535765893
243 0.1628496501778489
244 0.1545694422654706
245 0.14671150347587303
246 0.13925867049082277
247 0.1321870792764574
248 0.12548036262635948
249 0.11911011080471992
250 0.11306557887226854
251 0.10732795980066841
252 0.10188461953413473
253 0.09672167883220786
254 0.09182050570844008
255 0.08716687187971987
256 0.08275218497178044
257 0.07856208686100846
258 0.07458539303111852
259 0.0708119823038599
260 0.06722859685292391
261 0.06382713283520294
262 0.06059838065937535
263 0.05753427486204631
264 0.05462658865371925
265 0.05186765873673116
266 0.04924735450307155
267 0.04675956929510704
268 0.04439755489648213
269 0.042156351172044026
270 0.04002957404324313
271 0.03800914716659075
272 0.036091132108693025
273 0.03427056689650086
274 0.032542762201626094
275 0.030902743188357812
276 0.02934517860345364
277 0.02786596503804529
278 0.026461996681195574
279 0.02512875994623632
280 0.02386312845683796
281 0.02266188522168401
282 0.021520788745339185
283 0.02043765283074779
284 0.01940915039923162
285 0.018432392568252874
286 0.01750516967002941
287 0.01662514130784462
288 0.01578884103934663
289 0.014994948301492703
290 0.014241064000583805
291 0.01352518720783814
292 0.012845953602895234
293 0.012200625873809745
294 0.011587626221406064
295 0.011005427255654662
296 0.010452676608382696
297 0.009927873085670197
298 0.009429590975425558
299 0.008956123157375269
300 0.008506642515476847
301 0.008079935494658193
302 0.0076746269735815205
303 0.007289830365893044
304 0.006924121303848243
305 0.006576865037395724
306 0.006247035922304687
307 0.005933786295552544
308 0.0056363019303708375
309 0.005353894911356894
310 0.005085619117416969
311 0.004830740294557994
312 0.004588760252137112
313 0.004358836564065572
314 0.00414062925804401
315 0.003933296103277345
316 0.003736299493238851
317 0.0035492088034091767
318 0.003371490219027961
319 0.0032027496086899903
320 0.003042505708346286
321 0.002890247802024095
322 0.0027456237187764598
323 0.002608248993457238
324 0.0024777843443794384
325 0.0023538415554850305
326 0.002236178950835067
327 0.0021243586101917086
328 0.002018146153604408
329 0.0019172312221077215
330 0.0018213573448919137
331 0.0017303478245103857
332 0.0016438805621220819
333 0.0015617051431819253
334 0.0014836505924261324
335 0.0014095003642586694
336 0.0013390749706005164
337 0.0012722291791287827
338 0.0012086810431179047
339 0.0011482990840535352
340 0.0010909388241306622
341 0.0010364662962211726
342 0.000984722666912955
343 0.0009355741893337296
344 0.0008888644685272739
345 0.0008444823665394344
346 0.0008023462714322948
347 0.000762290514637186
348 0.0007242573235923483
349 0.0006881240448970477
350 0.0006537828388570909
351 0.0006211610712098699
352 0.0005901717584344897
353 0.0005607269764501985
354 0.0005327748223605441
355 0.0005062023205985773
356 0.0004809548607569782
357 0.00045696578545951886
358 0.0004341744648880087
359 0.00041252806747616626
360 0.00039196444635461485
361 0.00037242670540796406
362 0.0003538593551810756
363 0.00033622032762124895
364 0.00031946458386265385
365 0.0003035478949458371
366 0.0002884213865236895
367 0.00027404772239252224
368 0.00026038986688670633
369 0.00024741594715849233
370 0.00023508880967038626
371 0.00022338182712731665
372 0.000212253178193716
373 0.00020168492305352754
374 0.00019164001123476342
375 0.00018209342887737804
376 0.00017302454145199216
377 0.0001644113491920083
378 0.00015622321069192402
379 0.00014844287222834496
380 0.0001410516804368488
381 0.0001340283152828722
382 0.00012735836417921862
383 0.00012101971417390681
384 0.00011499633993774859
385 0.00010927331984090808
386 0.00010383438436228924
387 9.86660406097472e-05
388 9.375762644419199e-05
389 8.90913951809421e-05
390 8.465752527437288e-05
391 8.044521992690333e-05
392 7.644219291182101e-05
393 7.263948656548763e-05
394 6.902769236266978e-05
395 6.559323292504226e-05
396 6.233184160465803e-05
397 5.9230891088905416e-05
398 5.62842733724913e-05
399 5.3485256759678115e-05
400 5.082598562036606e-05
401 4.8298380575814275e-05
402 4.589619853164927e-05
403 4.361345900905317e-05
404 4.144458093309367e-05
405 3.9384154728070174e-05
406 3.7426315780041754e-05
407 3.556620511394618e-05
408 3.379807331218156e-05
409 3.21178084349399e-05
410 3.052148783830991e-05
411 2.900474948874023e-05
412 2.7562961455702106e-05
413 2.619282252750542e-05
414 2.4890794975888806e-05
415 2.3653534334418585e-05
416 2.247793411563736e-05
417 2.136134550582711e-05
418 2.029981700742502e-05
419 1.9291873413005388e-05
420 1.8333220506315805e-05
421 1.7422220985376215e-05
422 1.655663802389037e-05
423 1.573425937368422e-05
424 1.4952543718513352e-05
425 1.4209597230589208e-05
426 1.350359420485985e-05
427 1.2832736713890408e-05
428 1.2195425112390544e-05
429 1.1589859620511499e-05
430 1.1014299921115875e-05
431 1.0467343026433166e-05
432 9.947439271582704e-06
433 9.453425890996714e-06
434 8.984081678182295e-06
435 8.537940147651637e-06
436 8.11392645957772e-06
437 7.710968777639817e-06
438 7.328102767560694e-06
439 6.964220021801154e-06
440 6.6185490000439925e-06
441 6.2899803676130456e-06
442 5.977738330326263e-06
443 5.681027094530374e-06
444 5.398980314719876e-06
445 5.130984899456161e-06
446 4.876380962459173e-06
447 4.634328388284737e-06
448 4.40427971480433e-06
449 4.185640019760787e-06
450 3.977854464494969e-06
451 3.7804457177215285e-06
452 3.5928894928403697e-06
453 3.4146159079952964e-06
454 3.2451421320008377e-06
455 3.084088194721186e-06
456 2.931063870707191e-06
457 2.7856599653480605e-06
458 2.6474936586859393e-06
459 2.516105409962482e-06
460 2.3912437872486217e-06
461 2.2725863498251613e-06
462 2.1598108188357582e-06
463 2.052678848145079e-06
464 1.9508790859971743e-06
465 1.854093283742596e-06
466 1.762126874632385e-06
467 1.6747050180043608e-06
468 1.5916420057443094e-06
469 1.5127160770604112e-06
470 1.4376921968104319e-06
471 1.3663700797364631e-06
472 1.2985849710780375e-06
473 1.2341721306445045e-06
474 1.1729493448734245e-06
475 1.1148195245938046e-06
476 1.0595404086568715e-06
477 1.0069866732585408e-06
478 9.570407994143825e-07
479 9.09573565597039e-07
480 8.644740799175941e-07
481 8.216239623684362e-07
482 7.808815811019122e-07
483 7.421601513031236e-07
484 7.053590767649718e-07
485 6.703892758912358e-07
486 6.37158053228801e-07
487 6.055827079503043e-07
488 5.755541192093355e-07
489 5.470174452130455e-07
490 5.198947315098625e-07
491 4.941164996709988e-07
492 4.696258018927648e-07
493 4.463522901604017e-07
494 4.2423000053859065e-07
495 4.032036961090774e-07
496 3.8321534344879454e-07
497 3.6421733681246917e-07
498 3.4617490268759104e-07
499 3.290185529707219e-07

PyTorch: Tensors

这次我们使用PyTorch tensors来创建前向神经网络,计算损失,以及反向传播。

一个PyTorch Tensor很像一个numpy的ndarray。但是它和numpy ndarray最大的区别是,PyTorch Tensor可以在CPU或者GPU上运算。如果想要在GPU上运算,就需要把Tensor换成cuda类型。

N, D_in , H , D_out =64, 1000, 100, 10

#随即创建一些训练数据
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

w1 = torch.randn(D_in, H)
w2 = torch.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    #forward pass
    h = x.mm(w1) # N * H
    h_relu = h.clamp(min=0) # N * H
    y_pred = h_relu.mm(w2) # N * D_out
    
    #compute loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)
    
    #backward pass
    #compute the gradient
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.T)
    grad_h = grad_h_relu.clone()
    grad_h[h<0] = 0
    grad_w1 = x.t().mm(grad_h)
    
    
    #update weights of w1 and w2
    
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
0 30821202.0
1 28674406.0
2 29092262.0
3 27564054.0
4 22621974.0
5 15395313.0
6 9156566.0
7 5077433.0
8 2906762.0
9 1812965.75
10 1257466.0
11 951011.125
12 763287.5625
13 635287.375
14 540419.875
15 466047.875
16 405620.75
17 355436.46875
18 313124.40625
19 277171.9375
20 246383.953125
21 219822.3125
22 196799.703125
23 176721.328125
24 159117.734375
25 143634.25
26 129976.3515625
27 117897.53125
28 107162.078125
29 97593.40625
30 89045.3984375
31 81386.140625
32 74498.921875
33 68299.8125
34 62703.71875
35 57658.16015625
36 53090.56640625
37 48945.359375
38 45175.859375
39 41741.79296875
40 38604.15234375
41 35736.73828125
42 33113.71484375
43 30711.556640625
44 28507.4375
45 26482.947265625
46 24621.123046875
47 22907.12890625
48 21328.83984375
49 19873.146484375
50 18529.009765625
51 17286.509765625
52 16136.8095703125
53 15072.66015625
54 14086.2421875
55 13171.8076171875
56 12323.322265625
57 11535.865234375
58 10804.150390625
59 10124.171875
60 9491.26953125
61 8902.451171875
62 8353.650390625
63 7842.248046875
64 7365.56640625
65 6920.96044921875
66 6506.15966796875
67 6119.27685546875
68 5757.8681640625
69 5420.1455078125
70 5104.1171875
71 4808.48583984375
72 4531.66455078125
73 4272.287109375
74 4029.109130859375
75 3801.191650390625
76 3587.4287109375
77 3386.970947265625
78 3198.674560546875
79 3022.67236328125
80 2857.218505859375
81 2701.724853515625
82 2555.427490234375
83 2417.804931640625
84 2288.317138671875
85 2166.52734375
86 2051.769775390625
87 1943.5867919921875
88 1841.6810302734375
89 1745.5994873046875
90 1654.9727783203125
91 1569.553955078125
92 1488.9031982421875
93 1412.754150390625
94 1340.870849609375
95 1272.9654541015625
96 1208.8045654296875
97 1148.1708984375
98 1090.8656005859375
99 1036.671630859375
100 985.4346313476562
101 936.9486083984375
102 891.053955078125
103 847.6085815429688
104 806.493896484375
105 767.5343017578125
106 730.6255493164062
107 695.6507568359375
108 662.4972534179688
109 631.056884765625
110 601.23193359375
111 572.961669921875
112 546.1463012695312
113 520.684326171875
114 496.5182189941406
115 473.5635986328125
116 451.7565002441406
117 431.0503845214844
118 411.36627197265625
119 392.6615295410156
120 374.888671875
121 357.988525390625
122 341.9074401855469
123 326.61126708984375
124 312.055908203125
125 298.2068786621094
126 285.03143310546875
127 272.48046875
128 260.52679443359375
129 249.14720153808594
130 238.30284118652344
131 227.98141479492188
132 218.13040161132812
133 208.73988342285156
134 199.79110717773438
135 191.26290893554688
136 183.12332153320312
137 175.3556671142578
138 167.9490966796875
139 160.876220703125
140 154.1220245361328
141 147.67764282226562
142 141.52392578125
143 135.64617919921875
144 130.03079223632812
145 124.66637420654297
146 119.54234313964844
147 114.64482116699219
148 109.96110534667969
149 105.48971557617188
150 101.21070098876953
151 97.11981201171875
152 93.20227813720703
153 89.45639038085938
154 85.8707504272461
155 82.44464111328125
156 79.16072845458984
157 76.0174331665039
158 73.00721740722656
159 70.12435150146484
160 67.3656005859375
161 64.72277069091797
162 62.18870544433594
163 59.76327896118164
164 57.43769073486328
165 55.209991455078125
166 53.07682800292969
167 51.029754638671875
168 49.06551742553711
169 47.18373107910156
170 45.3774528503418
171 43.645606994628906
172 41.98408508300781
173 40.38923645019531
174 38.85969161987305
175 37.39216232299805
176 35.98359680175781
177 34.63069152832031
178 33.331607818603516
179 32.08498001098633
180 30.88756561279297
181 29.73792266845703
182 28.634796142578125
183 27.57413101196289
184 26.55473518371582
185 25.574743270874023
186 24.63345718383789
187 23.729412078857422
188 22.85962677001953
189 22.02401351928711
190 21.221220016479492
191 20.448314666748047
192 19.70565414428711
193 18.99169158935547
194 18.304136276245117
195 17.643733978271484
196 17.00810432434082
197 16.39716339111328
198 15.808847427368164
199 15.242757797241211
200 14.697977066040039
201 14.173712730407715
202 13.669013977050781
203 13.182851791381836
204 12.715142250061035
205 12.264898300170898
206 11.831254959106445
207 11.413905143737793
208 11.012290954589844
209 10.624683380126953
210 10.25228500366211
211 9.892794609069824
212 9.546954154968262
213 9.21341609954834
214 8.892077445983887
215 8.582769393920898
216 8.284662246704102
217 7.996941566467285
218 7.71997594833374
219 7.452714443206787
220 7.195420742034912
221 6.947061538696289
222 6.707915782928467
223 6.477261066436768
224 6.25474739074707
225 6.040435791015625
226 5.833762168884277
227 5.634443283081055
228 5.4419050216674805
229 5.2565016746521
230 5.077347278594971
231 4.904781341552734
232 4.738302230834961
233 4.577463150024414
234 4.422373294830322
235 4.273029327392578
236 4.128481864929199
237 3.989236831665039
238 3.8546643257141113
239 3.7253384590148926
240 3.59982967376709
241 3.4790711402893066
242 3.3623387813568115
243 3.2497498989105225
244 3.1409313678741455
245 3.0359601974487305
246 2.9346282482147217
247 2.8366596698760986
248 2.7421607971191406
249 2.6508469581604004
250 2.562695264816284
251 2.477661371231079
252 2.395388603210449
253 2.3159914016723633
254 2.2393572330474854
255 2.165213108062744
256 2.093616485595703
257 2.0244529247283936
258 1.9576847553253174
259 1.893256664276123
260 1.8308887481689453
261 1.77060067653656
262 1.712496280670166
263 1.6561808586120605
264 1.601933240890503
265 1.5492973327636719
266 1.4985600709915161
267 1.4496207237243652
268 1.402113437652588
269 1.356400966644287
270 1.312066674232483
271 1.269322395324707
272 1.2279443740844727
273 1.187847375869751
274 1.1491739749908447
275 1.1118865013122559
276 1.0758365392684937
277 1.040905237197876
278 1.0071351528167725
279 0.9745337963104248
280 0.9428754448890686
281 0.9123725295066833
282 0.8828024864196777
283 0.8542860746383667
284 0.826766312122345
285 0.7999786734580994
286 0.7741292119026184
287 0.7492832541465759
288 0.7250736951828003
289 0.7017436027526855
290 0.6791509985923767
291 0.6572655439376831
292 0.6361614465713501
293 0.6157830357551575
294 0.5959858298301697
295 0.5768381953239441
296 0.5583407878875732
297 0.540433406829834
298 0.5231001973152161
299 0.5063762664794922
300 0.4901966452598572
301 0.47454091906547546
302 0.45932844281196594
303 0.4446500241756439
304 0.43046051263809204
305 0.41681838035583496
306 0.4034736752510071
307 0.3905998468399048
308 0.37813836336135864
309 0.36606940627098083
310 0.35443535447120667
311 0.3431414067745209
312 0.3322266638278961
313 0.32166072726249695
314 0.31143835186958313
315 0.30154722929000854
316 0.2919841706752777
317 0.2827008068561554
318 0.2737547755241394
319 0.2651146650314331
320 0.2566356956958771
321 0.248525470495224
322 0.24072128534317017
323 0.2330290973186493
324 0.22569483518600464
325 0.2185531109571457
326 0.21167731285095215
327 0.20494616031646729
328 0.19847410917282104
329 0.19225287437438965
330 0.1861550360918045
331 0.18028999865055084
332 0.17460964620113373
333 0.1691409945487976
334 0.1637795865535736
335 0.15864278376102448
336 0.15363363921642303
337 0.1487993448972702
338 0.1441034972667694
339 0.13956747949123383
340 0.13518734276294708
341 0.1309289038181305
342 0.12682877480983734
343 0.12283414602279663
344 0.11901231855154037
345 0.11525852233171463
346 0.11166303604841232
347 0.10811809450387955
348 0.10473602265119553
349 0.10145066678524017
350 0.09827884286642075
351 0.09522039443254471
352 0.09225571900606155
353 0.08936706185340881
354 0.0865907073020935
355 0.08386905491352081
356 0.08124000579118729
357 0.07871467620134354
358 0.07626952975988388
359 0.0738663524389267
360 0.07154399901628494
361 0.06931310147047043
362 0.06715406477451324
363 0.06507933884859085
364 0.06305069476366043
365 0.061085835099220276
366 0.059167586266994476
367 0.057304710149765015
368 0.05554167553782463
369 0.053813301026821136
370 0.05215133726596832
371 0.05052737891674042
372 0.04896567016839981
373 0.047447193413972855
374 0.04598095640540123
375 0.0445450022816658
376 0.043159373104572296
377 0.041825197637081146
378 0.040540993213653564
379 0.03929363191127777
380 0.03807539492845535
381 0.03689442574977875
382 0.03575964272022247
383 0.03465254232287407
384 0.033597681671381
385 0.032543864101171494
386 0.03154915198683739
387 0.030576415359973907
388 0.029638908803462982
389 0.028719531372189522
390 0.027832962572574615
391 0.02697983756661415
392 0.02615041472017765
393 0.025348911061882973
394 0.02456277422606945
395 0.023813020437955856
396 0.023093685507774353
397 0.02237887680530548
398 0.021699219942092896
399 0.021029114723205566
400 0.020388856530189514
401 0.019758347421884537
402 0.01917180046439171
403 0.018578900024294853
404 0.018007956445217133
405 0.01747051067650318
406 0.016933707520365715
407 0.016424883157014847
408 0.01591930352151394
409 0.015438120812177658
410 0.014967124909162521
411 0.014518809504806995
412 0.01408022828400135
413 0.013661021366715431
414 0.013250891119241714
415 0.012856172397732735
416 0.012469397857785225
417 0.012093006633222103
418 0.011725702323019505
419 0.011375300586223602
420 0.011029819026589394
421 0.010706603527069092
422 0.010386413894593716
423 0.010078263469040394
424 0.009775394573807716
425 0.009487164206802845
426 0.00920187309384346
427 0.008926000446081161
428 0.008664819411933422
429 0.008413596078753471
430 0.008159523829817772
431 0.007919560186564922
432 0.007686922792345285
433 0.00746076088398695
434 0.007247535977512598
435 0.00703480513766408
436 0.006826439872384071
437 0.006629243493080139
438 0.006442691199481487
439 0.006254896055907011
440 0.00607301527634263
441 0.005901591386646032
442 0.0057271188125014305
443 0.005564812570810318
444 0.005402186885476112
445 0.005250112619251013
446 0.005099089350551367
447 0.004956612829118967
448 0.0048160189762711525
449 0.004681501537561417
450 0.004548225551843643
451 0.004416907671838999
452 0.004294888582080603
453 0.00417358847334981
454 0.004059929866343737
455 0.003947295248508453
456 0.0038358187302947044
457 0.0037314717192202806
458 0.003628714708611369
459 0.0035308802034705877
460 0.0034328163601458073
461 0.0033394114580005407
462 0.0032483614049851894
463 0.003162553533911705
464 0.0030734280589967966
465 0.0029936330392956734
466 0.0029146212618798018
467 0.002833412028849125
468 0.002760896924883127
469 0.0026881967205554247
470 0.002612858545035124
471 0.002546401694417
472 0.0024806938599795103
473 0.002415820024907589
474 0.002350389724597335
475 0.0022883249912410975
476 0.0022289252374321222
477 0.0021712465677410364
478 0.0021138694137334824
479 0.002061977982521057
480 0.0020073300693184137
481 0.0019576898775994778
482 0.0019066648092120886
483 0.0018596879672259092
484 0.0018120453460142016
485 0.0017658255528658628
486 0.001723453402519226
487 0.0016791753005236387
488 0.0016371491365134716
489 0.0015966123901307583
490 0.0015580977778881788
491 0.0015215465100482106
492 0.0014834373723715544
493 0.0014468058943748474
494 0.0014111835043877363
495 0.0013774464605376124
496 0.0013448935933411121
497 0.0013099341886118054
498 0.0012811265187337995
499 0.0012520885793492198
N, D_in, H, D_out = 64, 1000, 100, 10

# 随机创建一些训练数据
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

w1 = torch.randn(D_in, H)
w2 = torch.randn(H, D_out)

learning_rate = 1e-6
for it in range(500):
    # Forward pass
    h = x.mm(w1) # N * H
    h_relu = h.clamp(min=0) # N * H
    y_pred = h_relu.mm(w2) # N * D_out
    
    # compute loss
    loss = (y_pred - y).pow(2).sum().item()
    print(it, loss)
    
    # Backward pass
    # compute the gradient
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h<0] = 0
    grad_w1 = x.t().mm(grad_h)
    
    # update weights of w1 and w2
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
0 29762046.0
1 23066852.0
2 18711686.0
3 14801243.0
4 11081590.0
5 7897110.0
6 5449245.0
7 3736755.0
8 2600430.75
9 1863506.625
10 1384443.75
11 1066823.5
12 849056.375
13 694301.3125
14 579745.6875
15 491825.8125
16 422203.4375
17 365664.1875
18 318882.1875
19 279607.6875
20 246269.28125
21 217793.21875
22 193362.53125
23 172216.5
24 153804.609375
25 137681.765625
26 123520.8125
27 111042.7734375
28 100011.8671875
29 90226.9375
30 81530.6328125
31 73786.09375
32 66875.734375
33 60690.84375
34 55149.2578125
35 50171.6015625
36 45695.19140625
37 41664.26953125
38 38026.62890625
39 34739.9609375
40 31765.595703125
41 29070.78515625
42 26625.50390625
43 24404.0390625
44 22385.072265625
45 20547.837890625
46 18877.71875
47 17354.1875
48 15966.14453125
49 14698.5810546875
50 13539.5
51 12478.2666015625
52 11506.1162109375
53 10615.9287109375
54 9800.8994140625
55 9052.6728515625
56 8365.7060546875
57 7733.314453125
58 7152.142578125
59 6617.36279296875
60 6124.8115234375
61 5671.1171875
62 5252.8173828125
63 4866.875
64 4510.9111328125
65 4182.16357421875
66 3878.61083984375
67 3598.0048828125
68 3338.7373046875
69 3099.269775390625
70 2877.714599609375
71 2672.69140625
72 2482.994384765625
73 2307.268798828125
74 2144.48486328125
75 1993.695556640625
76 1853.860107421875
77 1724.239990234375
78 1604.04248046875
79 1492.542236328125
80 1389.138427734375
81 1293.1014404296875
82 1203.9498291015625
83 1121.193115234375
84 1044.33984375
85 972.9322509765625
86 906.5712280273438
87 844.8587646484375
88 787.5062255859375
89 734.164794921875
90 684.6029663085938
91 638.521484375
92 595.64892578125
93 555.75244140625
94 518.6096801757812
95 484.016845703125
96 451.8203125
97 421.8199157714844
98 393.85888671875
99 367.8256530761719
100 343.55609130859375
101 320.93060302734375
102 299.8389587402344
103 280.17230224609375
104 261.82965087890625
105 244.72543334960938
106 228.7706756591797
107 213.89105224609375
108 200.00607299804688
109 187.0375213623047
110 174.9371337890625
111 163.6450958251953
112 153.09640502929688
113 143.24517822265625
114 134.04623413085938
115 125.45079803466797
116 117.42243194580078
117 109.923828125
118 102.91181182861328
119 96.36072540283203
120 90.23654174804688
121 84.51056671142578
122 79.15670013427734
123 74.15044403076172
124 69.4683609008789
125 65.08931732177734
126 60.99584197998047
127 57.16313934326172
128 53.57727813720703
129 50.22196578979492
130 47.08306121826172
131 44.142723083496094
132 41.391536712646484
133 38.81666946411133
134 36.40475082397461
135 34.146507263183594
136 32.032447814941406
137 30.05124855041504
138 28.195178985595703
139 26.457429885864258
140 24.828100204467773
141 23.301620483398438
142 21.87250518798828
143 20.531452178955078
144 19.275754928588867
145 18.09778594970703
146 16.993410110473633
147 15.958050727844238
148 14.98732852935791
149 14.076690673828125
150 13.222716331481934
151 12.42151927947998
152 11.670351028442383
153 10.965628623962402
154 10.304265022277832
155 9.683117866516113
156 9.100723266601562
157 8.55402946472168
158 8.040616035461426
159 7.5590338706970215
160 7.107036590576172
161 6.682579040527344
162 6.284026145935059
163 5.909652233123779
164 5.557806491851807
165 5.227550506591797
166 4.917217254638672
167 4.625788688659668
168 4.352065563201904
169 4.094728946685791
170 3.8532333374023438
171 3.6257545948028564
172 3.4124934673309326
173 3.211496591567993
174 3.0228679180145264
175 2.845728874206543
176 2.679048776626587
177 2.522198438644409
178 2.374821662902832
179 2.2362537384033203
180 2.1057934761047363
181 1.9831745624542236
182 1.8678261041641235
183 1.7593063116073608
184 1.6573617458343506
185 1.5612578392028809
186 1.4708653688430786
187 1.3858541250228882
188 1.3058593273162842
189 1.230438232421875
190 1.1596105098724365
191 1.092834234237671
192 1.0301061868667603
193 0.9709413051605225
194 0.9153381586074829
195 0.8628535866737366
196 0.813481330871582
197 0.7669540047645569
198 0.7231323719024658
199 0.6818897128105164
200 0.6430266499519348
201 0.6064165234565735
202 0.5718790888786316
203 0.5394691824913025
204 0.5088962316513062
205 0.4800434708595276
206 0.45284461975097656
207 0.42723414301872253
208 0.4031059145927429
209 0.3803500831127167
210 0.35893672704696655
211 0.33870431780815125
212 0.31967484951019287
213 0.3017108738422394
214 0.28475069999694824
215 0.26878929138183594
216 0.2536822557449341
217 0.23949642479419708
218 0.22609341144561768
219 0.2134789377450943
220 0.20155112445354462
221 0.1902989149093628
222 0.1796754151582718
223 0.16970567405223846
224 0.1602526158094406
225 0.151326984167099
226 0.1429295688867569
227 0.13499712944030762
228 0.12754875421524048
229 0.12047047913074493
230 0.11378441751003265
231 0.10751700401306152
232 0.10158450156450272
233 0.09595350176095963
234 0.09067893773317337
235 0.08568556606769562
236 0.08095695078372955
237 0.07652147859334946
238 0.07231495529413223
239 0.06834664940834045
240 0.06459417939186096
241 0.06105583906173706
242 0.05770888924598694
243 0.054558444768190384
244 0.05159997195005417
245 0.048756301403045654
246 0.04609570652246475
247 0.043589841574430466
248 0.04121812433004379
249 0.038960859179496765
250 0.03684431314468384
251 0.03485583886504173
252 0.03295939043164253
253 0.0311793964356184
254 0.029478339478373528
255 0.027884500101208687
256 0.026384294033050537
257 0.02496730536222458
258 0.023619506508111954
259 0.022354189306497574
260 0.021145502105355263
261 0.02000790275633335
262 0.018936434760689735
263 0.017926964908838272
264 0.01696084626019001
265 0.016061466187238693
266 0.015196256339550018
267 0.014391188509762287
268 0.013624911196529865
269 0.012900453992187977
270 0.012218968942761421
271 0.011573925614356995
272 0.010961524210870266
273 0.010383285582065582
274 0.009837167337536812
275 0.009322320111095905
276 0.008837325498461723
277 0.008375686593353748
278 0.00793911051005125
279 0.0075232405215501785
280 0.0071358405984938145
281 0.006764193065464497
282 0.00641527259722352
283 0.00608763936907053
284 0.0057768551632761955
285 0.005478579085320234
286 0.005202075466513634
287 0.004938025958836079
288 0.004694466479122639
289 0.0044562676921486855
290 0.004234412685036659
291 0.004024997353553772
292 0.003824956715106964
293 0.003640178358182311
294 0.003460776060819626
295 0.003292034612968564
296 0.0031321286223828793
297 0.0029824054799973965
298 0.0028359838761389256
299 0.0027039656415581703
300 0.002574746496975422
301 0.0024520941078662872
302 0.0023386343382298946
303 0.0022269415203481913
304 0.002125771250575781
305 0.0020305360667407513
306 0.0019376248819753528
307 0.0018483225721865892
308 0.0017664702609181404
309 0.0016856689471751451
310 0.0016127214767038822
311 0.0015427290927618742
312 0.0014718484599143267
313 0.0014098582323640585
314 0.0013507503317669034
315 0.0012916673440486193
316 0.001236538402736187
317 0.0011855436023324728
318 0.0011359271593391895
319 0.0010888624237850308
320 0.0010436526499688625
321 0.001002494478598237
322 0.0009628265397623181
323 0.0009244600078091025
324 0.0008861556416377425
325 0.0008527687168680131
326 0.0008187140920199454
327 0.000786679214797914
328 0.0007560487138107419
329 0.0007279418059624732
330 0.0006995805888436735
331 0.0006734815542586148
332 0.000648274552077055
333 0.0006256361375562847
334 0.0006018236745148897
335 0.0005800087237730622
336 0.0005596462287940085
337 0.0005399221554398537
338 0.0005202004103921354
339 0.0005017591174691916
340 0.0004848327371291816
341 0.00046761537669226527
342 0.00045302699436433613
343 0.00043724518036469817
344 0.00042313701123930514
345 0.00040981153142638505
346 0.0003975038998760283
347 0.00038421081262640655
348 0.00037191042792983353
349 0.0003594216250348836
350 0.00034796350519172847
351 0.00033656222512945533
352 0.0003264406113885343
353 0.00031546747777611017
354 0.0003067828365601599
355 0.0002968181506730616
356 0.00028814078541472554
357 0.0002811422455124557
358 0.00027236260939389467
359 0.00026421257643960416
360 0.00025751403882168233
361 0.00024965431657619774
362 0.000241793692111969
363 0.00023624097229912877
364 0.00023044089903123677
365 0.00022408438962884247
366 0.00021723259123973548
367 0.0002115389797836542
368 0.0002055977238342166
369 0.00020040776871610433
370 0.00019543645612429827
371 0.00019004891510121524
372 0.00018505321349948645
373 0.0001806973887141794
374 0.00017648277571424842
375 0.00017197246779687703
376 0.00016713119111955166
377 0.0001630442129680887
378 0.00015953612455632538
379 0.0001560034288559109
380 0.00015216428437270224
381 0.00014784287486691028
382 0.00014408843708224595
383 0.00014078867388889194
384 0.00013772134843748063
385 0.00013467103417497128
386 0.00013189941819291562
387 0.0001292629021918401
388 0.00012685802357736975
389 0.00012337841326370835
390 0.00012066255294485018
391 0.00011857244680868462
392 0.00011528359755175188
393 0.000112665664346423
394 0.00011032742622774094
395 0.00010776333510875702
396 0.00010584241681499407
397 0.00010347444913350046
398 0.00010182162804994732
399 9.969556413125247e-05
400 9.77654053713195e-05
401 9.581765334587544e-05
402 9.418823174200952e-05
403 9.213548037223518e-05
404 9.012558439280838e-05
405 8.867125143297017e-05
406 8.695671567693353e-05
407 8.490754407830536e-05
408 8.312655700137839e-05
409 8.103609434328973e-05
410 7.975030894158408e-05
411 7.847649976611137e-05
412 7.696000102441758e-05
413 7.55072760512121e-05
414 7.409827230731025e-05
415 7.291947258636355e-05
416 7.175814243964851e-05
417 7.038572221063077e-05
418 6.912941171322018e-05
419 6.782382115488872e-05
420 6.693716568406671e-05
421 6.592227146029472e-05
422 6.481944001279771e-05
423 6.370364280883223e-05
424 6.258935172809288e-05
425 6.171658606035635e-05
426 6.061311432858929e-05
427 5.965070522506721e-05
428 5.834892363054678e-05
429 5.745037560700439e-05
430 5.65473637834657e-05
431 5.5780627008061856e-05
432 5.494891956914216e-05
433 5.410148878581822e-05
434 5.3175201173871756e-05
435 5.213084659771994e-05
436 5.159818465472199e-05
437 5.078062531538308e-05
438 4.9985017540166155e-05
439 4.9173006118508056e-05
440 4.8425557906739414e-05
441 4.765587436850183e-05
442 4.6965040382929146e-05
443 4.621029074769467e-05
444 4.5560092985397205e-05
445 4.501940566115081e-05
446 4.41442389274016e-05
447 4.375029675429687e-05
448 4.322260429034941e-05
449 4.238531255396083e-05
450 4.197190355625935e-05
451 4.1233062802348286e-05
452 4.061531217303127e-05
453 4.0216145862359554e-05
454 3.9599675801582634e-05
455 3.9136128179961815e-05
456 3.86391366191674e-05
457 3.810250564129092e-05
458 3.775205186684616e-05
459 3.682448004838079e-05
460 3.6464942240854725e-05
461 3.61634956789203e-05
462 3.56695891241543e-05
463 3.5050823498750106e-05
464 3.4349395718891174e-05
465 3.391788777662441e-05
466 3.353248393977992e-05
467 3.335763904033229e-05
468 3.2964871934382245e-05
469 3.2450650905957446e-05
470 3.203255982953124e-05
471 3.161902350257151e-05
472 3.117786400252953e-05
473 3.097375883953646e-05
474 3.052403553738259e-05
475 3.0225233786040917e-05
476 2.9853252272005193e-05
477 2.9447268389048986e-05
478 2.9190603527240455e-05
479 2.87192397081526e-05
480 2.8442493203328922e-05
481 2.807253804348875e-05
482 2.7960741135757416e-05
483 2.7481011784402654e-05
484 2.706970553845167e-05
485 2.6672611056710593e-05
486 2.66089009528514e-05
487 2.6334080757806078e-05
488 2.6062687538797036e-05
489 2.568760100984946e-05
490 2.5465848011663184e-05
491 2.506902819732204e-05
492 2.497553032299038e-05
493 2.477494126651436e-05
494 2.4436470994260162e-05
495 2.430865788483061e-05
496 2.4154409402399324e-05
497 2.4027987819863483e-05
498 2.394388684479054e-05
499 2.353984564251732e-05

简单的autograd

x = torch.tensor(1., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

y = w*x + b # y = 2*1+3

y.backward()

# dy / dw = x
print(w.grad)

print(x.grad)

print(b.grad)

tensor(1.)
tensor(2.)
tensor(1.)

PyTorch: Tensor和autograd

PyTorch的一个重要功能就是autograd,也就是说只要定义了forward pass(前向神经网络),计算了loss之后,PyTorch可以自动求导计算模型所有参数的梯度。

一个PyTorch的Tensor表示计算图中的一个节点。如果x是一个Tensor并且x.requires_grad=True那么x.grad是另一个储存着x当前梯度(相对于一个scalar,常常是loss)的向量。

N, D_in, H, D_out = 64, 1000, 100, 10

# 随机创建一些训练数据
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)

learning_rate = 1e-6
for it in range(500):
    # Forward pass
    y_pred = x.mm(w1).clamp(min=0).mm(w2)
#     h = x.mm(w1) # N * H
#     h_relu = h.clamp(min=0) # N * H
#     y_pred = h_relu.mm(w2) 
    
    
    # compute loss
    loss = (y_pred - y).pow(2).sum() # computation graph
    print(it, loss.item())
    
    # Backward pass
    loss.backward()
    
    # update weights of w1 and w2
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        w1.grad.zero_()
        w2.grad.zero_()
0 28603812.0
1 21509956.0
2 17139246.0
3 13368428.0
4 9967077.0
5 7080959.0
6 4908604.0
7 3384567.5
8 2376112.25
9 1717356.0
10 1287226.25
11 999041.8125
12 799757.0625
13 656432.875
14 549429.375
15 466880.71875
16 401173.71875
17 347640.1875
18 303300.5625
19 266035.71875
20 234399.546875
21 207355.28125
22 184061.296875
23 163880.171875
24 146348.625
25 131007.640625
26 117544.3984375
27 105702.03125
28 95252.3984375
29 86000.84375
30 77779.984375
31 70458.3046875
32 63923.94921875
33 58076.515625
34 52835.26953125
35 48126.80078125
36 43887.1640625
37 40064.92578125
38 36617.53125
39 33501.04296875
40 30679.677734375
41 28122.765625
42 25801.435546875
43 23692.06640625
44 21771.732421875
45 20022.263671875
46 18426.498046875
47 16969.8203125
48 15638.943359375
49 14422.6962890625
50 13309.759765625
51 12290.703125
52 11357.62109375
53 10501.95703125
54 9716.404296875
55 8995.1220703125
56 8332.451171875
57 7723.36279296875
58 7162.19384765625
59 6644.8291015625
60 6167.48974609375
61 5727.1201171875
62 5320.66845703125
63 4945.279296875
64 4597.9931640625
65 4276.81494140625
66 3979.608154296875
67 3704.691162109375
68 3449.92626953125
69 3213.78857421875
70 2994.93359375
71 2791.93359375
72 2603.681640625
73 2428.797607421875
74 2266.406005859375
75 2115.59228515625
76 1975.464599609375
77 1845.5513916015625
78 1724.6732177734375
79 1611.1732177734375
80 1505.619384765625
81 1407.3436279296875
82 1315.8502197265625
83 1230.65966796875
84 1151.3043212890625
85 1077.4136962890625
86 1008.4945678710938
87 944.1804809570312
88 884.1934814453125
89 828.2359619140625
90 775.979248046875
91 727.1993408203125
92 681.6557006835938
93 639.1015625
94 599.3277587890625
95 562.1483764648438
96 527.3839111328125
97 494.88507080078125
98 464.4801940917969
99 436.0336608886719
100 409.39739990234375
101 384.4566345214844
102 361.1117248535156
103 339.2373046875
104 318.7471923828125
105 299.56396484375
106 281.58721923828125
107 264.7126770019531
108 248.8983612060547
109 234.06085205078125
110 220.15518188476562
111 207.1001434326172
112 194.85372924804688
113 183.36294555664062
114 172.5733642578125
115 162.43841552734375
116 152.9230194091797
117 143.98126220703125
118 135.5908203125
119 127.70321655273438
120 120.28816223144531
121 113.3166732788086
122 106.76827239990234
123 100.60713195800781
124 94.81292724609375
125 89.36551666259766
126 84.24512481689453
127 79.42481994628906
128 74.88658905029297
129 70.61520385742188
130 66.59761810302734
131 62.81562042236328
132 59.25368881225586
133 55.90215301513672
134 52.74388885498047
135 49.768306732177734
136 46.96535110473633
137 44.325904846191406
138 41.84016799926758
139 39.49709701538086
140 37.28816223144531
141 35.20510482788086
142 33.243961334228516
143 31.39250373840332
144 29.647743225097656
145 28.001697540283203
146 26.451231002807617
147 24.988101959228516
148 23.607093811035156
149 22.304035186767578
150 21.075408935546875
151 19.915746688842773
152 18.820959091186523
153 17.78841209411621
154 16.814485549926758
155 15.893689155578613
156 15.025279998779297
157 14.204578399658203
158 13.430122375488281
159 12.698761940002441
160 12.0077543258667
161 11.355605125427246
162 10.73978042602539
163 10.157876968383789
164 9.607673645019531
165 9.08797836303711
166 8.597233772277832
167 8.13359260559082
168 7.695069313049316
169 7.280887603759766
170 6.88947057723999
171 6.519353866577148
172 6.169189453125
173 5.838376998901367
174 5.525732040405273
175 5.23012113571167
176 4.950403690338135
177 4.686232089996338
178 4.436119556427002
179 4.19944429397583
180 3.9758095741271973
181 3.7641453742980957
182 3.5642166137695312
183 3.375070095062256
184 3.195831060409546
185 3.0263314247131348
186 2.8660569190979004
187 2.714259147644043
188 2.5708301067352295
189 2.435189723968506
190 2.3065028190612793
191 2.1847314834594727
192 2.0695741176605225
193 1.9606516361236572
194 1.85748291015625
195 1.7598847150802612
196 1.6673800945281982
197 1.579771637916565
198 1.4970272779464722
199 1.4184705018997192
200 1.3440656661987305
201 1.2737298011779785
202 1.2071841955184937
203 1.1440441608428955
204 1.0843579769134521
205 1.0275739431381226
206 0.9739753603935242
207 0.9232385158538818
208 0.8751430511474609
209 0.8295402526855469
210 0.7863116264343262
211 0.7455581426620483
212 0.7067188620567322
213 0.6700259447097778
214 0.635224461555481
215 0.6022682189941406
216 0.5710885524749756
217 0.5414544343948364
218 0.5133770704269409
219 0.48688995838165283
220 0.46155205368995667
221 0.4376630485057831
222 0.4151599109172821
223 0.3936854898929596
224 0.3733852803707123
225 0.35407060384750366
226 0.3358007073402405
227 0.31849658489227295
228 0.3021582067012787
229 0.2865024209022522
230 0.27178332209587097
231 0.25776925683021545
232 0.24457751214504242
233 0.2320222556591034
234 0.22005896270275116
235 0.20875762403011322
236 0.1980513334274292
237 0.18790960311889648
238 0.17824923992156982
239 0.16914011538028717
240 0.16048070788383484
241 0.15228711068630219
242 0.14448371529579163
243 0.13711956143379211
244 0.13010428845882416
245 0.12341280281543732
246 0.11708761006593704
247 0.11110865324735641
248 0.1054525226354599
249 0.1000976487994194
250 0.09495141357183456
251 0.0901293009519577
252 0.0855506956577301
253 0.08119291067123413
254 0.07704810798168182
255 0.07311572134494781
256 0.06938790529966354
257 0.06589090079069138
258 0.06254205107688904
259 0.05936010181903839
260 0.05633501708507538
261 0.05348720774054527
262 0.05076350271701813
263 0.04818660765886307
264 0.04575677588582039
265 0.04344141110777855
266 0.04124864190816879
267 0.03914762660861015
268 0.03716236725449562
269 0.035286713391542435
270 0.03349978104233742
271 0.03181125223636627
272 0.03020293265581131
273 0.028678853064775467
274 0.027234811335802078
275 0.025878487154841423
276 0.02457866445183754
277 0.023332009091973305
278 0.022166945040225983
279 0.02103567123413086
280 0.01999017409980297
281 0.01898922026157379
282 0.01804371178150177
283 0.017139766365289688
284 0.01628217101097107
285 0.015467074699699879
286 0.014697002246975899
287 0.013974348083138466
288 0.013279332779347897
289 0.012619655579328537
290 0.01199110597372055
291 0.011399160139262676
292 0.010838690213859081
293 0.01029850821942091
294 0.009792177937924862
295 0.009305846877396107
296 0.008856789208948612
297 0.008420594036579132
298 0.008007455617189407
299 0.007622586563229561
300 0.00725289061665535
301 0.0068967887200415134
302 0.006563145201653242
303 0.006249995436519384
304 0.005944596603512764
305 0.005662217270582914
306 0.005391240119934082
307 0.005139282438904047
308 0.004890920128673315
309 0.004658904857933521
310 0.004439562559127808
311 0.004230792168527842
312 0.0040278807282447815
313 0.003844894003123045
314 0.003668016055598855
315 0.0034953823778778315
316 0.0033368412405252457
317 0.0031839325092732906
318 0.003038664348423481
319 0.002901293570175767
320 0.0027697195764631033
321 0.0026466629933565855
322 0.0025264383293688297
323 0.002412375994026661
324 0.0023051120806485415
325 0.002205394906923175
326 0.002107948763296008
327 0.002018790692090988
328 0.0019332945812493563
329 0.0018522653263062239
330 0.0017736934823915362
331 0.0016979016363620758
332 0.0016273841029033065
333 0.0015602399362251163
334 0.0014960700646042824
335 0.0014339329209178686
336 0.0013758850982412696
337 0.0013191846664994955
338 0.0012677692575380206
339 0.001218773890286684
340 0.0011696070432662964
341 0.0011240314925089478
342 0.0010790792293846607
343 0.0010395455174148083
344 0.0009993078419938684
345 0.0009605029481463134
346 0.0009250571019947529
347 0.0008913534693419933
348 0.0008590583456680179
349 0.000827247160486877
350 0.0007943366654217243
351 0.0007663737633265555
352 0.0007388820522464812
353 0.0007135853520594537
354 0.0006881675799377263
355 0.0006647685659117997
356 0.0006424027378670871
357 0.0006203672382980585
358 0.0005991467041894794
359 0.0005788102862425148
360 0.0005594320246018469
361 0.0005407712887972593
362 0.0005233221454545856
363 0.0005067690508440137
364 0.0004894177545793355
365 0.0004742353339679539
366 0.0004593220364768058
367 0.0004456668975763023
368 0.00043160858331248164
369 0.0004183050768915564
370 0.00040634439210407436
371 0.0003921492607332766
372 0.00038030429277569056
373 0.00036922251456417143
374 0.0003592130960896611
375 0.0003476277634035796
376 0.0003381051355972886
377 0.0003284666163381189
378 0.00031903706258162856
379 0.0003097275912296027
380 0.000301450549159199
381 0.0002933111391030252
382 0.00028482198831625283
383 0.00027778634103015065
384 0.00026992749189957976
385 0.0002622645697556436
386 0.00025584539980627596
387 0.0002498405228834599
388 0.00024335448688361794
389 0.00023660749138798565
390 0.00023054613848216832
391 0.00022477634774986655
392 0.00021959720470476896
393 0.0002133014058927074
394 0.0002079078694805503
395 0.00020280040916986763
396 0.00019803468603640795
397 0.00019349635113030672
398 0.00018889660714194179
399 0.00018401294073555619
400 0.0001796396099962294
401 0.0001745019108057022
402 0.00017023498367052525
403 0.00016654256614856422
404 0.00016264227451756597
405 0.00015898699348326772
406 0.00015564728528261185
407 0.00015177333261817694
408 0.00014877824287395924
409 0.00014555662346538156
410 0.00014265587378758937
411 0.00013916932221036404
412 0.00013571874296758324
413 0.00013283707085065544
414 0.00013008344103582203
415 0.00012777592928614467
416 0.00012533064000308514
417 0.00012301310198381543
418 0.00012068656360497698
419 0.00011819545761682093
420 0.00011553128570085391
421 0.00011352939327480271
422 0.00011061802069889382
423 0.00010867905075429007
424 0.00010649643081706017
425 0.000104456179542467
426 0.00010268781625200063
427 0.00010056342580355704
428 9.867809421848506e-05
429 9.665834659244865e-05
430 9.507365030003712e-05
431 9.329960448667407e-05
432 9.142476483248174e-05
433 8.99045480764471e-05
434 8.806039113551378e-05
435 8.679734310135245e-05
436 8.528588659828529e-05
437 8.38433115859516e-05
438 8.228021761169657e-05
439 8.066281588980928e-05
440 7.926095713628456e-05
441 7.805722270859405e-05
442 7.670130435144529e-05
443 7.498567720176652e-05
444 7.378555892501026e-05
445 7.273519440786913e-05
446 7.153732440201566e-05
447 7.026853563729674e-05
448 6.906060298206285e-05
449 6.806549936300144e-05
450 6.697177741443738e-05
451 6.578808824997395e-05
452 6.47310953354463e-05
453 6.382436549756676e-05
454 6.271247548284009e-05
455 6.182176002766937e-05
456 6.0791386204073206e-05
457 5.990738281980157e-05
458 5.9091311413794756e-05
459 5.82421307626646e-05
460 5.714652797905728e-05
461 5.614483598037623e-05
462 5.5391399655491114e-05
463 5.469625466503203e-05
464 5.3892130381427705e-05
465 5.3176667279331014e-05
466 5.2416893595363945e-05
467 5.133208105689846e-05
468 5.067092570243403e-05
469 5.016025170334615e-05
470 4.927466216031462e-05
471 4.869355689152144e-05
472 4.805625940207392e-05
473 4.746561171486974e-05
474 4.697928670793772e-05
475 4.628240276360884e-05
476 4.5762437366647646e-05
477 4.516990520642139e-05
478 4.469654959393665e-05
479 4.3825479224324226e-05
480 4.3354240915505216e-05
481 4.277497646398842e-05
482 4.224866279400885e-05
483 4.1636692913016304e-05
484 4.1184190195053816e-05
485 4.0546630771132186e-05
486 3.993419522885233e-05
487 3.948801168007776e-05
488 3.900128649547696e-05
489 3.855006434605457e-05
490 3.812120121438056e-05
491 3.78087570425123e-05
492 3.725534043041989e-05
493 3.6764533433597535e-05
494 3.6435758374864236e-05
495 3.5982033296022564e-05
496 3.5620309063233435e-05
497 3.5159100661985576e-05
498 3.473739343462512e-05
499 3.431739241932519e-05
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)

learning_rate = 1e-6

    # Forward pass
y_pred = x.mm(w1).clamp(min=0).mm(w2)
#     h = x.mm(w1) # N * H
#     h_relu = h.clamp(min=0) # N * H
#     y_pred = h_relu.mm(w2) 


# compute loss
loss = (y_pred - y).pow(2).sum() # computation graph
print(it, loss.item())
499 40225440.0

PyTorch: nn

这次我们使用PyTorch中nn这个库来构建网络。
用PyTorch autograd来构建计算图和计算gradients,
然后PyTorch会帮我们自动计算gradient。

import torch.nn as nn

N, D_in, H, D_out = 64, 1000, 100, 10

# 随机创建一些训练数据
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H, bias=False), # w_1 * x + b_1
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out, bias=False),
)

torch.nn.init.normal_(model[0].weight)
torch.nn.init.normal_(model[2].weight)

# model = model.cuda()

loss_fn = nn.MSELoss(reduction='sum')

learning_rate = 1e-6
for it in range(500):
    # Forward pass
    y_pred = model(x) # model.forward() 
    
    # compute loss
    loss = loss_fn(y_pred, y) # computation graph
    print(it, loss.item())
    
    # Backward pass
    loss.backward()
    
    # update weights of w1 and w2
    with torch.no_grad():
        for param in model.parameters(): # param (tensor, grad)
            param -= learning_rate * param.grad
            
    model.zero_grad()
0 29766616.0
1 23777656.0
2 21478428.0
3 19479030.0
4 16546593.0
5 12759395.0
6 9002785.0
7 5962393.5
8 3857000.0
9 2524503.75
10 1715654.125
11 1226126.875
12 922215.625
13 725122.0
14 590359.1875
15 493087.6875
16 419570.0625
17 361689.0
18 314677.03125
19 275769.59375
20 243146.203125
21 215326.390625
22 191323.1875
23 170553.375
24 152482.8125
25 136687.875
26 122837.8203125
27 110644.3359375
28 99876.796875
29 90339.4453125
30 81865.671875
31 74322.84375
32 67591.5546875
33 61566.359375
34 56164.59375
35 51312.8203125
36 46951.60546875
37 43018.44140625
38 39464.69921875
39 36250.8828125
40 33337.57421875
41 30694.43359375
42 28291.27734375
43 26104.666015625
44 24109.50390625
45 22287.9375
46 20621.345703125
47 19095.482421875
48 17697.484375
49 16415.283203125
50 15236.939453125
51 14153.2724609375
52 13155.939453125
53 12237.396484375
54 11390.154296875
55 10608.13671875
56 9885.791015625
57 9217.955078125
58 8600.14453125
59 8028.2314453125
60 7498.20068359375
61 7006.6513671875
62 6550.58642578125
63 6126.98974609375
64 5733.33447265625
65 5367.470703125
66 5027.2177734375
67 4710.51318359375
68 4415.5185546875
69 4140.70849609375
70 3884.505615234375
71 3645.63427734375
72 3422.61865234375
73 3214.387939453125
74 3019.879150390625
75 2837.98583984375
76 2667.9033203125
77 2508.7353515625
78 2359.810546875
79 2220.4052734375
80 2089.83349609375
81 1967.515380859375
82 1852.96875
83 1745.389892578125
84 1644.5032958984375
85 1549.7891845703125
86 1460.8857421875
87 1377.3944091796875
88 1298.957763671875
89 1225.267333984375
90 1155.9970703125
91 1090.8516845703125
92 1029.61474609375
93 971.9910888671875
94 917.7593383789062
95 866.7032470703125
96 818.632080078125
97 773.3716430664062
98 730.7655029296875
99 690.5889282226562
100 652.73095703125
101 617.0667114257812
102 583.4531860351562
103 551.7386474609375
104 521.83056640625
105 493.60662841796875
106 466.98382568359375
107 441.85614013671875
108 418.13775634765625
109 395.74127197265625
110 374.6009521484375
111 354.6639404296875
112 335.8072509765625
113 317.98974609375
114 301.15496826171875
115 285.24609375
116 270.205078125
117 255.9881134033203
118 242.5457763671875
119 229.84669494628906
120 217.86180114746094
121 206.54847717285156
122 195.84739685058594
123 185.7223358154297
124 176.1398162841797
125 167.06561279296875
126 158.47686767578125
127 150.3487548828125
128 142.65390014648438
129 135.36338806152344
130 128.45870971679688
131 121.91324615478516
132 115.71708679199219
133 109.84144592285156
134 104.27703857421875
135 99.00000762939453
136 94.0049057006836
137 89.26721954345703
138 84.77240753173828
139 80.50917053222656
140 76.46817779541016
141 72.63577270507812
142 69.00017547607422
143 65.55001831054688
144 62.28118896484375
145 59.179534912109375
146 56.23418045043945
147 53.438323974609375
148 50.78583526611328
149 48.268836975097656
150 45.880165100097656
151 43.61277770996094
152 41.45987319946289
153 39.41508865356445
154 37.473873138427734
155 35.6302490234375
156 33.8792724609375
157 32.216007232666016
158 30.636571884155273
159 29.13726806640625
160 27.711807250976562
161 26.35842514038086
162 25.073198318481445
163 23.850412368774414
164 22.688386917114258
165 21.58495330810547
166 20.536270141601562
167 19.53920555114746
168 18.59156608581543
169 17.690359115600586
170 16.83396339416504
171 16.02001953125
172 15.245658874511719
173 14.510287284851074
174 13.810375213623047
175 13.144716262817383
176 12.511774063110352
177 11.910293579101562
178 11.337478637695312
179 10.79334545135498
180 10.275464057922363
181 9.782782554626465
182 9.313438415527344
183 8.867620468139648
184 8.443604469299316
185 8.039955139160156
186 7.656317710876465
187 7.29049015045166
188 6.942845821380615
189 6.611854553222656
190 6.297031402587891
191 5.997185707092285
192 5.711885929107666
193 5.440595626831055
194 5.182295799255371
195 4.936399459838867
196 4.702075481414795
197 4.479272365570068
198 4.266942024230957
199 4.065062522888184
200 3.8726344108581543
201 3.6897974014282227
202 3.515371322631836
203 3.349323272705078
204 3.1913087368011475
205 3.040682792663574
206 2.8973116874694824
207 2.7609076499938965
208 2.6310231685638428
209 2.507091522216797
210 2.389131546020508
211 2.2768702507019043
212 2.1699297428131104
213 2.0681097507476807
214 1.970934271812439
215 1.878630518913269
216 1.7905535697937012
217 1.7067335844039917
218 1.6266751289367676
219 1.550572395324707
220 1.4780292510986328
221 1.4089950323104858
222 1.3430979251861572
223 1.2802886962890625
224 1.2206246852874756
225 1.1635589599609375
226 1.1092796325683594
227 1.0576143264770508
228 1.008253812789917
229 0.9614090919494629
230 0.9165420532226562
231 0.8738452792167664
232 0.8332104682922363
233 0.7944720983505249
234 0.7575938701629639
235 0.7223514914512634
236 0.688768744468689
237 0.6568116545677185
238 0.6262905597686768
239 0.5972128510475159
240 0.5694993138313293
241 0.5430572032928467
242 0.5179727673530579
243 0.49388280510902405
244 0.4710555076599121
245 0.44927406311035156
246 0.42847856879234314
247 0.40856847167015076
248 0.3897371292114258
249 0.37162142992019653
250 0.3544740676879883
251 0.3381309509277344
252 0.322529137134552
253 0.30765706300735474
254 0.29343822598457336
255 0.27983975410461426
256 0.2669477164745331
257 0.25460538268089294
258 0.24286289513111115
259 0.23165613412857056
260 0.22098445892333984
261 0.21081732213497162
262 0.20108091831207275
263 0.19184118509292603
264 0.18301290273666382
265 0.17458976805210114
266 0.16653916239738464
267 0.15889734029769897
268 0.1515941023826599
269 0.1446174830198288
270 0.13800375163555145
271 0.1316535472869873
272 0.12555857002735138
273 0.11981766670942307
274 0.11429396271705627
275 0.10903395712375641
276 0.10406746715307236
277 0.09929314255714417
278 0.0947614386677742
279 0.09039424359798431
280 0.08625229448080063
281 0.08230412006378174
282 0.07852151989936829
283 0.07493852823972702
284 0.071522556245327
285 0.06824541091918945
286 0.06513582915067673
287 0.06213860958814621
288 0.059307388961315155
289 0.05659571290016174
290 0.05399005860090256
291 0.05152208358049393
292 0.04915192723274231
293 0.0469355471432209
294 0.04480253905057907
295 0.04274328798055649
296 0.040802378207445145
297 0.038954056799411774
298 0.03717837482690811
299 0.035465385764837265
300 0.03384703770279884
301 0.032314691692590714
302 0.03085215575993061
303 0.02944507822394371
304 0.028107093647122383
305 0.02683217264711857
306 0.025605276226997375
307 0.02445397898554802
308 0.02335381880402565
309 0.022288480773568153
310 0.021276239305734634
311 0.020312387496232986
312 0.019397035241127014
313 0.018527137115597725
314 0.01770195923745632
315 0.016903115436434746
316 0.01614859327673912
317 0.01542503573000431
318 0.01472427323460579
319 0.014069728553295135
320 0.013434939086437225
321 0.012836556881666183
322 0.012258786708116531
323 0.011717756278812885
324 0.011197803542017937
325 0.010700796730816364
326 0.01022394560277462
327 0.00977302435785532
328 0.009333723224699497
329 0.008923370391130447
330 0.008526774123311043
331 0.008147627115249634
332 0.007792768068611622
333 0.007454231381416321
334 0.00712941400706768
335 0.00682031037285924
336 0.006521350238472223
337 0.006236766930669546
338 0.005969686433672905
339 0.005714518018066883
340 0.0054693021811544895
341 0.005235066171735525
342 0.005006518214941025
343 0.004793217405676842
344 0.004587443545460701
345 0.004396097734570503
346 0.004208866506814957
347 0.004032652825117111
348 0.0038641956634819508
349 0.0037003259640187025
350 0.0035438132472336292
351 0.0033953639212995768
352 0.003258656244724989
353 0.003122936002910137
354 0.0029965597204864025
355 0.0028731501661241055
356 0.0027524558827281
357 0.0026407609693706036
358 0.00253285258077085
359 0.002433427609503269
360 0.002336303936317563
361 0.0022421758621931076
362 0.0021542594768106937
363 0.002070060698315501
364 0.0019879345782101154
365 0.0019094282761216164
366 0.0018372323829680681
367 0.0017662843456491828
368 0.0016980170039460063
369 0.0016339552821591496
370 0.0015738147776573896
371 0.001514765084721148
372 0.0014577957335859537
373 0.0014065473806113005
374 0.0013545791152864695
375 0.0013044598745182157
376 0.0012587555684149265
377 0.0012123232008889318
378 0.001168436836451292
379 0.0011263287160545588
380 0.0010876681189984083
381 0.0010489921551197767
382 0.0010123583488166332
383 0.0009766623843461275
384 0.0009423411684110761
385 0.0009096176945604384
386 0.0008782913209870458
387 0.0008485364378429949
388 0.0008198323776014149
389 0.000791664351709187
390 0.0007662636344321072
391 0.0007405986543744802
392 0.0007158889202401042
393 0.0006938243750482798
394 0.0006707744323648512
395 0.0006491108797490597
396 0.0006281695677898824
397 0.0006089641246944666
398 0.0005892976187169552
399 0.0005709282122552395
400 0.0005544618470594287
401 0.0005376324406825006
402 0.0005212302785366774
403 0.0005050330655649304
404 0.0004897580365650356
405 0.00047591893235221505
406 0.00046078150626271963
407 0.00044790987158194184
408 0.00043429495417512953
409 0.00042069086339324713
410 0.0004090293077751994
411 0.0003977144951932132
412 0.000386145111406222
413 0.000375887262634933
414 0.0003656860499177128
415 0.0003549414104782045
416 0.0003455854894127697
417 0.0003364203148521483
418 0.00032716983696445823
419 0.00031906343065202236
420 0.00031057297019287944
421 0.0003017220296896994
422 0.000293860852252692
423 0.0002862894325517118
424 0.000279192317975685
425 0.000271642638836056
426 0.00026528388843871653
427 0.00025870627723634243
428 0.0002524942101445049
429 0.0002451710752211511
430 0.00023987199529074132
431 0.00023324164794757962
432 0.00022728127078153193
433 0.00022195448400452733
434 0.000217096705455333
435 0.00021176922018639743
436 0.00020652025705203414
437 0.00020189551287330687
438 0.0001968961878446862
439 0.0001927059784065932
440 0.0001883884979179129
441 0.0001844682265073061
442 0.00018013891531154513
443 0.00017592639778740704
444 0.00017223588656634092
445 0.00016857586160767823
446 0.00016467602108605206
447 0.00016065315867308527
448 0.0001573258632561192
449 0.00015407492173835635
450 0.00015059291035868227
451 0.00014780194032937288
452 0.00014449226728174835
453 0.0001413429417880252
454 0.00013875112927053124
455 0.00013627571752294898
456 0.0001334023691015318
457 0.00013062136713415384
458 0.00012823991710320115
459 0.00012523178884293884
460 0.00012290496670175344
461 0.00012050622899550945
462 0.00011851731687784195
463 0.00011616052506724373
464 0.00011405172699596733
465 0.00011175837425980717
466 0.00010943552479147911
467 0.0001071853912435472
468 0.00010494334128452465
469 0.00010316870611859486
470 0.00010090871364809573
471 9.946932550519705e-05
472 9.77175441221334e-05
473 9.580243204254657e-05
474 9.416375542059541e-05
475 9.234881144948304e-05
476 9.083648910745978e-05
477 8.922808046918362e-05
478 8.765413076616824e-05
479 8.61199659993872e-05
480 8.436515781795606e-05
481 8.308548422064632e-05
482 8.181951125152409e-05
483 8.00254347268492e-05
484 7.890770211815834e-05
485 7.761816959828138e-05
486 7.621549593750387e-05
487 7.477735198335722e-05
488 7.392222323687747e-05
489 7.285339961526915e-05
490 7.134801853680983e-05
491 7.043810182949528e-05
492 6.908640352776274e-05
493 6.808344187447801e-05
494 6.7174929426983e-05
495 6.617417238885537e-05
496 6.486748316092417e-05
497 6.411506910808384e-05
498 6.326192669803277e-05
499 6.2366365455091e-05
model[0].weight
Parameter containing:
tensor([[-0.0218,  0.0212,  0.0243,  ...,  0.0230,  0.0247,  0.0168],
        [-0.0144,  0.0177, -0.0221,  ...,  0.0161,  0.0098, -0.0172],
        [ 0.0086, -0.0122, -0.0298,  ..., -0.0236, -0.0187,  0.0295],
        ...,
        [ 0.0266, -0.0008, -0.0141,  ...,  0.0018,  0.0319, -0.0129],
        [ 0.0296, -0.0005,  0.0115,  ...,  0.0141, -0.0088, -0.0106],
        [ 0.0289, -0.0077,  0.0239,  ..., -0.0166, -0.0156, -0.0235]],
       requires_grad=True)

PyTorch: optim

这一次我们不再手动更新模型的weights,而是使用optim这个包来帮助我们更新参数。
optim这个package提供了各种不同的模型优化方法,包括SGD+momentum, RMSProp, Adam等等。

import torch.nn as nn

N, D_in, H, D_out = 64, 1000, 100, 10

# 随机创建一些训练数据
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H, bias=False), # w_1 * x + b_1
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out, bias=False),
)

torch.nn.init.normal_(model[0].weight)
torch.nn.init.normal_(model[2].weight)

# model = model.cuda()

loss_fn = nn.MSELoss(reduction='sum')
# learning_rate = 1e-4
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

learning_rate = 1e-6
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for it in range(500):
    # Forward pass
    y_pred = model(x) # model.forward() 
    
    # compute loss
    loss = loss_fn(y_pred, y) # computation graph
    print(it, loss.item())

    optimizer.zero_grad()
    # Backward pass
    loss.backward()
    
    # update model parameters
    optimizer.step()

0 24436214.0
1 20115276.0
2 18840850.0
3 18223790.0
4 17027580.0
5 14675071.0
6 11567663.0
7 8366201.5
8 5720385.0
9 3799774.75
10 2535152.0
11 1735162.5
12 1236944.25
13 922564.25
14 718680.1875
15 580795.125
16 483214.65625
17 410988.375
18 355345.84375
19 310880.21875
20 274325.625
21 243722.59375
22 217734.828125
23 195305.78125
24 175787.5625
25 158686.078125
26 143600.5
27 130217.0390625
28 118322.4765625
29 107720.890625
30 98256.671875
31 89756.2734375
32 82104.359375
33 75197.3125
34 68949.78125
35 63292.28515625
36 58161.140625
37 53495.71484375
38 49262.35546875
39 45406.51171875
40 41886.3671875
41 38671.0625
42 35729.078125
43 33036.390625
44 30567.708984375
45 28301.845703125
46 26222.076171875
47 24308.93359375
48 22548.6953125
49 20927.591796875
50 19433.642578125
51 18058.23046875
52 16788.662109375
53 15616.2177734375
54 14533.13671875
55 13530.798828125
56 12604.1884765625
57 11745.923828125
58 10950.625
59 10213.337890625
60 9529.8671875
61 8895.59375
62 8306.091796875
63 7758.91064453125
64 7250.0498046875
65 6776.876953125
66 6337.04443359375
67 5927.30078125
68 5546.20166015625
69 5191.66162109375
70 4860.919921875
71 4552.841796875
72 4265.46484375
73 3997.487060546875
74 3747.48681640625
75 3514.1376953125
76 3296.32763671875
77 3092.704833984375
78 2902.46435546875
79 2724.722412109375
80 2558.482666015625
81 2402.98779296875
82 2257.489013671875
83 2121.367919921875
84 1993.8583984375
85 1874.394775390625
86 1762.48974609375
87 1657.6212158203125
88 1559.3736572265625
89 1467.2152099609375
90 1380.824951171875
91 1299.764892578125
92 1223.675048828125
93 1152.2919921875
94 1085.2791748046875
95 1022.3812866210938
96 963.3178100585938
97 907.811279296875
98 855.660888671875
99 806.6342163085938
100 760.5975341796875
101 717.3490600585938
102 676.70654296875
103 638.4835815429688
104 602.4908447265625
105 568.6420288085938
106 536.7633056640625
107 506.75787353515625
108 478.52423095703125
109 451.9193115234375
110 426.85400390625
111 403.2418518066406
112 380.9964904785156
113 360.0345153808594
114 340.2715759277344
115 321.6309814453125
116 304.0492248535156
117 287.4762878417969
118 271.8389892578125
119 257.1002197265625
120 243.17950439453125
121 230.04359436035156
122 217.64369201660156
123 205.94189453125
124 194.8915252685547
125 184.45968627929688
126 174.60533142089844
127 165.29302978515625
128 156.4994659423828
129 148.18898010253906
130 140.34364318847656
131 132.92042541503906
132 125.90531921386719
133 119.2744369506836
134 113.00390625
135 107.079833984375
136 101.47257995605469
137 96.16826629638672
138 91.15074920654297
139 86.40707397460938
140 81.91826629638672
141 77.66998291015625
142 73.64669799804688
143 69.83943939208984
144 66.23517608642578
145 62.82527160644531
146 59.59428024291992
147 56.534263610839844
148 53.63661575317383
149 50.89312744140625
150 48.29410934448242
151 45.830631256103516
152 43.496971130371094
153 41.28459930419922
154 39.19132614135742
155 37.20603561401367
156 35.32252883911133
157 33.53795623779297
158 31.845619201660156
159 30.243682861328125
160 28.722116470336914
161 27.28006935119629
162 25.911415100097656
163 24.615415573120117
164 23.38458251953125
165 22.21629524230957
166 21.108041763305664
167 20.058032989501953
168 19.06101417541504
169 18.114952087402344
170 17.21552276611328
171 16.362825393676758
172 15.55347728729248
173 14.785439491271973
174 14.056044578552246
175 13.363285064697266
176 12.705596923828125
177 12.082019805908203
178 11.488263130187988
179 10.925143241882324
180 10.39013671875
181 9.88184928894043
182 9.399191856384277
183 8.940385818481445
184 8.504379272460938
185 8.09056568145752
186 7.697397708892822
187 7.323540210723877
188 6.967723369598389
189 6.630079746246338
190 6.30927038192749
191 6.003860950469971
192 5.7135701179504395
193 5.4377899169921875
194 5.175771713256836
195 4.926578998565674
196 4.6893110275268555
197 4.46390962600708
198 4.249563694000244
199 4.0457329750061035
200 3.851684808731079
201 3.6672728061676025
202 3.4917759895324707
203 3.3250765800476074
204 3.166198492050171
205 3.015110969543457
206 2.8714663982391357
207 2.7346131801605225
208 2.6045756340026855
209 2.480750799179077
210 2.362886667251587
211 2.250751256942749
212 2.1440927982330322
213 2.04252028465271
214 1.9457452297210693
215 1.8538568019866943
216 1.7661700248718262
217 1.6827582120895386
218 1.6033743619918823
219 1.5277283191680908
220 1.4558110237121582
221 1.387202262878418
222 1.322044014930725
223 1.2599530220031738
224 1.2007644176483154
225 1.144354224205017
226 1.0907174348831177
227 1.0396398305892944
228 0.990975022315979
229 0.9446046948432922
230 0.9003474116325378
231 0.8583122491836548
232 0.8182880282402039
233 0.7800549864768982
234 0.743695855140686
235 0.7090640068054199
236 0.6760954856872559
237 0.6445534229278564
238 0.6146584749221802
239 0.5860655903816223
240 0.5588774085044861
241 0.5329322218894958
242 0.5082358717918396
243 0.4846791625022888
244 0.4622225761413574
245 0.4408394694328308
246 0.42042845487594604
247 0.4009735584259033
248 0.38249772787094116
249 0.364850789308548
250 0.34804221987724304
251 0.3319990336894989
252 0.3166584372520447
253 0.30206847190856934
254 0.28815487027168274
255 0.27485451102256775
256 0.2622392177581787
257 0.2502085864543915
258 0.23871560394763947
259 0.2277631163597107
260 0.21732009947299957
261 0.20739634335041046
262 0.1978614628314972
263 0.18878591060638428
264 0.18012169003486633
265 0.17191345989704132
266 0.16405321657657623
267 0.15654942393302917
268 0.1493798941373825
269 0.14255140721797943
270 0.13606132566928864
271 0.12984351813793182
272 0.12392966449260712
273 0.11828220635652542
274 0.11291969567537308
275 0.10776922851800919
276 0.10287689417600632
277 0.09821106493473053
278 0.09375122934579849
279 0.08948099613189697
280 0.08539889752864838
281 0.08153844624757767
282 0.07785839587450027
283 0.07430348545312881
284 0.0709386020898819
285 0.06773824989795685
286 0.06466753780841827
287 0.06175475940108299
288 0.05895467475056648
289 0.056295156478881836
290 0.0537416897714138
291 0.051327235996723175
292 0.049020860344171524
293 0.04680536314845085
294 0.04470276087522507
295 0.042676717042922974
296 0.0407608300447464
297 0.03894304484128952
298 0.03718706592917442
299 0.035516317933797836
300 0.03393702581524849
301 0.03242314234375954
302 0.03096468560397625
303 0.029573975130915642
304 0.02825460024178028
305 0.026992876082658768
306 0.025776676833629608
307 0.02463972568511963
308 0.023537106812000275
309 0.022485176101326942
310 0.021485133096575737
311 0.02053215727210045
312 0.01961725950241089
313 0.018752608448266983
314 0.017924023792147636
315 0.017138047143816948
316 0.016380706802010536
317 0.015649406239390373
318 0.014961308799684048
319 0.014297058805823326
320 0.01366621907800436
321 0.013071554712951183
322 0.012496653012931347
323 0.011956160888075829
324 0.011431466788053513
325 0.010935298167169094
326 0.01046603824943304
327 0.010009510442614555
328 0.00957783404737711
329 0.009157683700323105
330 0.008763862773776054
331 0.00838443636894226
332 0.00802331417798996
333 0.007679164409637451
334 0.007354430388659239
335 0.007041578181087971
336 0.006744027603417635
337 0.0064595723524689674
338 0.006185965612530708
339 0.005921291187405586
340 0.005675483029335737
341 0.00543745793402195
342 0.005212889518588781
343 0.00499426294118166
344 0.004786796867847443
345 0.004592748824506998
346 0.004408272914588451
347 0.004227207042276859
348 0.004052576143294573
349 0.0038889518473297358
350 0.003731220494955778
351 0.0035824656952172518
352 0.0034373796079307795
353 0.0033003336284309626
354 0.003168500494211912
355 0.0030428837053477764
356 0.0029242313466966152
357 0.00280955177731812
358 0.0027018673717975616
359 0.002595777390524745
360 0.0024974728003144264
361 0.0023979872930794954
362 0.0023071218747645617
363 0.002222965005785227
364 0.002138041891157627
365 0.0020597113762050867
366 0.001983368769288063
367 0.001912725274451077
368 0.0018420533742755651
369 0.0017750875558704138
370 0.0017139844130724669
371 0.0016515005845576525
372 0.0015924760373309255
373 0.0015359288081526756
374 0.0014829676365479827
375 0.0014301817864179611
376 0.0013818496372550726
377 0.00133396009914577
378 0.0012875873362645507
379 0.0012446430046111345
380 0.00120221683755517
381 0.0011615381808951497
382 0.0011241419706493616
383 0.001086542964912951
384 0.00105152593459934
385 0.0010166612919420004
386 0.0009863207815214992
387 0.0009538594749756157
388 0.0009235472534783185
389 0.0008946286980062723
390 0.0008671405958011746
391 0.0008400777005590498
392 0.0008136503165587783
393 0.0007891628192737699
394 0.0007641659467481077
395 0.0007422937196679413
396 0.0007196390070021152
397 0.0006988736568018794
398 0.0006786492886021733
399 0.0006577383610419929
400 0.0006379594560712576
401 0.0006192006985656917
402 0.0006014448590576649
403 0.000583052053116262
404 0.0005676199798472226
405 0.000550447846762836
406 0.0005354208988137543
407 0.000520190573297441
408 0.0005064225988462567
409 0.0004918242339044809
410 0.00047766356146894395
411 0.00046613806625828147
412 0.0004535649495664984
413 0.00044190019252710044
414 0.00042959259008057415
415 0.0004186177102383226
416 0.00040694003109820187
417 0.0003977562300860882
418 0.000387015548767522
419 0.00037667807191610336
420 0.00036764898686669767
421 0.00035817461321130395
422 0.000349209934938699
423 0.00034028870868496597
424 0.0003315831418149173
425 0.0003239487705286592
426 0.00031686556758359075
427 0.00030898250406607985
428 0.00030165858333930373
429 0.0002940705744549632
430 0.0002872246550396085
431 0.00028022946207784116
432 0.00027337868232280016
433 0.0002681115292944014
434 0.00026153764338232577
435 0.0002554746170062572
436 0.0002504090080037713
437 0.00024434077204205096
438 0.0002393243630649522
439 0.00023363585933111608
440 0.00022893572167959064
441 0.00022382299357559532
442 0.00021802390983793885
443 0.0002135633840225637
444 0.00020877565839327872
445 0.00020438502542674541
446 0.000200132533791475
447 0.00019636568322312087
448 0.00019197550136595964
449 0.00018845757585950196
450 0.00018516821728553623
451 0.0001812220725696534
452 0.00017768006364349276
453 0.00017394236056134105
454 0.00017036692588590086
455 0.0001669702905928716
456 0.000163633594638668
457 0.0001608784223208204
458 0.00015729425649624318
459 0.00015425201854668558
460 0.0001512485760031268
461 0.00014837279741186649
462 0.00014571723295375705
463 0.00014315942826215178
464 0.0001404505455866456
465 0.00013795308768749237
466 0.00013533096353057772
467 0.00013275298988446593
468 0.0001303627504967153
469 0.00012791437620762736
470 0.00012587543460540473
471 0.00012379918189253658
472 0.000121756260341499
473 0.00011986290337517858
474 0.00011718282621586695
475 0.00011540651030372828
476 0.00011365834507159889
477 0.00011187761265318841
478 0.00011016721691703424
479 0.00010829235543496907
480 0.00010646654118318111
481 0.0001048781123245135
482 0.0001032217187457718
483 0.00010149369336431846
484 9.985972428694367e-05
485 9.835550736170262e-05
486 9.673195017967373e-05
487 9.538326412439346e-05
488 9.360058174934238e-05
489 9.203865192830563e-05
490 9.080882591661066e-05
491 8.959635306382552e-05
492 8.830626757116988e-05
493 8.669115777593106e-05
494 8.531992352800444e-05
495 8.447348955087364e-05
496 8.305440132971853e-05
497 8.147572225425392e-05
498 8.059616084210575e-05
499 7.961507071740925e-05

PyTorch: 自定义 nn Modules

我们可以定义一个模型,这个模型继承自nn.Module类。如果需要定义一个比Sequential模型更加复杂的模型,就需要定义nn.Module模型。

import torch.nn as nn

N, D_in, H, D_out = 64, 1000, 100, 10

# 随机创建一些训练数据
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

class TwoLayerNet(torch.nn.Module):
    def __init__(self,D_in,H,D_out):
        super(TwoLayerNet,self).__init__()
        self.linear1 = torch.nn.Linear(D_in,H,bias=False)
        self.linear2 = torch.nn.Linear(H,D_out,bias=False)
        
    def forward(self,x):
        y_pred = self.linear2(self.linear1(x).clamp(min=0))
        return y_pred


model = TwoLayerNet(D_in,H,D_out)
loss_fn = nn.MSELoss(reduction='sum')
learning_rate = 1e-4
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# learning_rate = 1e-6
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for it in range(500):
    # Forward pass
    y_pred = model(x) # model.forward() 
    
    # compute loss
    loss = loss_fn(y_pred, y) # computation graph
    print(it, loss.item())

    optimizer.zero_grad()
    # Backward pass
    loss.backward()
    
    # update model parameters
    optimizer.step()

0 643.50390625
1 596.1583251953125
2 554.8339233398438
3 518.2571411132812
4 485.13421630859375
5 455.211669921875
6 427.857177734375
7 402.695556640625
8 379.4078369140625
9 357.7391052246094
10 337.6387023925781
11 319.02978515625
12 301.5810852050781
13 285.24334716796875
14 269.781005859375
15 255.1282501220703
16 241.20184326171875
17 228.06509399414062
18 215.64251708984375
19 203.86334228515625
20 192.68496704101562
21 182.07632446289062
22 172.00338745117188
23 162.4481201171875
24 153.3866729736328
25 144.79428100585938
26 136.60987854003906
27 128.85963439941406
28 121.52810668945312
29 114.61528015136719
30 108.07255554199219
31 101.88998413085938
32 96.05018615722656
33 90.5425033569336
34 85.35028839111328
35 80.45690155029297
36 75.83234405517578
37 71.45622253417969
38 67.33213806152344
39 63.444091796875
40 59.78274154663086
41 56.32850646972656
42 53.080177307128906
43 50.02677536010742
44 47.1556282043457
45 44.454795837402344
46 41.912052154541016
47 39.52488327026367
48 37.28063201904297
49 35.1696662902832
50 33.17994689941406
51 31.306848526000977
52 29.544689178466797
53 27.885482788085938
54 26.324779510498047
55 24.856962203979492
56 23.4764347076416
57 22.176918029785156
58 20.95256805419922
59 19.800678253173828
60 18.715728759765625
61 17.69495391845703
62 16.733766555786133
63 15.825174331665039
64 14.968779563903809
65 14.161850929260254
66 13.400372505187988
67 12.682832717895508
68 12.005960464477539
69 11.3670015335083
70 10.763898849487305
71 10.195371627807617
72 9.657954216003418
73 9.151839256286621
74 8.673748016357422
75 8.222179412841797
76 7.795566558837891
77 7.392944812774658
78 7.012562274932861
79 6.652833938598633
80 6.313276290893555
81 5.991885185241699
82 5.688510894775391
83 5.402405738830566
84 5.131752014160156
85 4.875490665435791
86 4.633208274841309
87 4.403843879699707
88 4.186874866485596
89 3.9817159175872803
90 3.787443161010742
91 3.6034152507781982
92 3.4290924072265625
93 3.2637829780578613
94 3.106860876083374
95 2.958101272583008
96 2.817206621170044
97 2.6837315559387207
98 2.557152032852173
99 2.437047004699707
100 2.3230109214782715
101 2.2147605419158936
102 2.1119933128356934
103 2.014394998550415
104 1.9217182397842407
105 1.8335442543029785
106 1.7496813535690308
107 1.6700289249420166
108 1.594429850578308
109 1.5225064754486084
110 1.4541884660720825
111 1.3891369104385376
112 1.3271598815917969
113 1.268174648284912
114 1.2121176719665527
115 1.1586997509002686
116 1.1078498363494873
117 1.059446096420288
118 1.013326644897461
119 0.9693808555603027
120 0.9275099635124207
121 0.8876402974128723
122 0.8496061563491821
123 0.8133607506752014
124 0.7787914276123047
125 0.7458066344261169
126 0.7143651247024536
127 0.6843442916870117
128 0.6557046175003052
129 0.6283568143844604
130 0.6022655963897705
131 0.5773533582687378
132 0.5535581707954407
133 0.5308330655097961
134 0.5091062784194946
135 0.4883470833301544
136 0.4685206413269043
137 0.44956618547439575
138 0.43139609694480896
139 0.41401591897010803
140 0.3973923623561859
141 0.3814816176891327
142 0.3662518858909607
143 0.3516913652420044
144 0.33775147795677185
145 0.3244094252586365
146 0.31163325905799866
147 0.2993829548358917
148 0.28765684366226196
149 0.2764323353767395
150 0.265671968460083
151 0.2553692162036896
152 0.2454928308725357
153 0.23602063953876495
154 0.22694392502307892
155 0.21824125945568085
156 0.20991156995296478
157 0.20192688703536987
158 0.19426807761192322
159 0.1869065761566162
160 0.17984475195407867
161 0.1730717420578003
162 0.16656909883022308
163 0.16032937169075012
164 0.1543383002281189
165 0.14859001338481903
166 0.14306719601154327
167 0.13776302337646484
168 0.13266821205615997
169 0.1277797818183899
170 0.12307970225811005
171 0.11855921149253845
172 0.11421595513820648
173 0.11004259437322617
174 0.10603177547454834
175 0.10217587649822235
176 0.09846923500299454
177 0.09490332752466202
178 0.09147759526968002
179 0.08818051964044571
180 0.08500966429710388
181 0.08195913583040237
182 0.07902351766824722
183 0.07620026171207428
184 0.07348329573869705
185 0.07087048888206482
186 0.06835521012544632
187 0.06593383103609085
188 0.06360296159982681
189 0.06135967746376991
190 0.05919862911105156
191 0.05711675435304642
192 0.055113475769758224
193 0.053186118602752686
194 0.051327235996723175
195 0.04953685402870178
196 0.04781219735741615
197 0.04615061730146408
198 0.04456048831343651
199 0.043028417974710464
200 0.04155047982931137
201 0.040126096457242966
202 0.03875337541103363
203 0.03743097186088562
204 0.03615552559494972
205 0.03492562100291252
206 0.03373995050787926
207 0.0325964018702507
208 0.031494345515966415
209 0.03043070249259472
210 0.02940497174859047
211 0.028415612876415253
212 0.027461133897304535
213 0.026540271937847137
214 0.0256519578397274
215 0.024794844910502434
216 0.023967135697603226
217 0.023168642073869705
218 0.022398268803954124
219 0.021655231714248657
220 0.020937874913215637
221 0.020244888961315155
222 0.019575942307710648
223 0.01892971247434616
224 0.018306255340576172
225 0.0177040696144104
226 0.017122581601142883
227 0.0165608711540699
228 0.016018372029066086
229 0.01549447514116764
230 0.014988349750638008
231 0.014499458484351635
232 0.014027480967342854
233 0.013571377843618393
234 0.013130509294569492
235 0.012704497203230858
236 0.01229295413941145
237 0.011895306408405304
238 0.011511022225022316
239 0.011139780282974243
240 0.010780880227684975
241 0.010434587486088276
242 0.010099063627421856
243 0.009774786420166492
244 0.009461523965001106
245 0.009158661589026451
246 0.008865751326084137
247 0.00858248956501484
248 0.008308641612529755
249 0.008043909445405006
250 0.007787936367094517
251 0.007540302816778421
252 0.007300873752683401
253 0.007069278508424759
254 0.006845217198133469
255 0.006628581788390875
256 0.006419079843908548
257 0.006216380745172501
258 0.006020296830683947
259 0.005830589681863785
260 0.005647046025842428
261 0.005469473544508219
262 0.005297815427184105
263 0.0051317899487912655
264 0.004971045069396496
265 0.004815416410565376
266 0.004664790816605091
267 0.004519041161984205
268 0.004377990961074829
269 0.004241453018039465
270 0.004109332337975502
271 0.003981441259384155
272 0.003857683390378952
273 0.0037379274144768715
274 0.00362191628664732
275 0.003509600879624486
276 0.0034009318333119154
277 0.0032957338262349367
278 0.003193825948983431
279 0.0030951944645494223
280 0.002999735763296485
281 0.002907247981056571
282 0.002817726694047451
283 0.002731019863858819
284 0.0026472024619579315
285 0.002565888687968254
286 0.0024871407076716423
287 0.002410889370366931
288 0.0023370322305709124
289 0.002265491522848606
290 0.002196189481765032
291 0.00212907325476408
292 0.0020640871953219175
293 0.002001119777560234
294 0.0019401400350034237
295 0.001881021773442626
296 0.0018238313496112823
297 0.0017683777259662747
298 0.0017146460013464093
299 0.0016626017168164253
300 0.001612187596037984
301 0.001563329016789794
302 0.0015160118928179145
303 0.0014701565960422158
304 0.0014257251750677824
305 0.0013826463837176561
306 0.0013409872772172093
307 0.0013005408691242337
308 0.0012613482540473342
309 0.001223351457156241
310 0.0011865469859912992
311 0.0011508765164762735
312 0.0011162819573655725
313 0.0010827595833688974
314 0.0010502805234864354
315 0.0010187827283516526
316 0.000988290528766811
317 0.0009587243548594415
318 0.0009300425299443305
319 0.0009022317826747894
320 0.0008752950234338641
321 0.000849154603201896
322 0.0008238255395554006
323 0.00079927226761356
324 0.0007754720281809568
325 0.0007523798267357051
326 0.0007300009019672871
327 0.0007082807132974267
328 0.000687277119141072
329 0.0006668754504062235
330 0.0006470840889960527
331 0.0006278998916968703
332 0.0006093102274462581
333 0.0005912552587687969
334 0.0005737639730796218
335 0.0005567987682297826
336 0.0005403520772233605
337 0.0005244037602096796
338 0.000508927449118346
339 0.0004939152859151363
340 0.0004793632251676172
341 0.00046523933997377753
342 0.0004515491018537432
343 0.0004382577899377793
344 0.0004253771039657295
345 0.00041289228829555213
346 0.0004007562529295683
347 0.0003890114603564143
348 0.00037760313716717064
349 0.0003665318654384464
350 0.00035582270356826484
351 0.00034540961496531963
352 0.0003353079373482615
353 0.0003255090268794447
354 0.0003159997286275029
355 0.0003067783545702696
356 0.00029783070203848183
357 0.00028914102585986257
358 0.00028071439010091126
359 0.0002725384838413447
360 0.00026461269590072334
361 0.0002569258213043213
362 0.0002494502696208656
363 0.00024220230989158154
364 0.00023517692170571536
365 0.00022834861010778695
366 0.00022172437456902117
367 0.0002153010864276439
368 0.00020906204008497298
369 0.00020301638869568706
370 0.00019714212976396084
371 0.00019144076213706285
372 0.00018590723630040884
373 0.00018055386317428201
374 0.0001753446995280683
375 0.00017028837464749813
376 0.00016537835472263396
377 0.0001606080768397078
378 0.00015599098696839064
379 0.00015149489627219737
380 0.00014713894051965326
381 0.0001429070980520919
382 0.00013880711048841476
383 0.0001348184305243194
384 0.00013095102622173727
385 0.00012719776714220643
386 0.00012355335638858378
387 0.00012001392315141857
388 0.00011658400035230443
389 0.00011324407387292013
390 0.0001100109948311001
391 0.00010686026507755741
392 0.00010381064930697903
393 0.0001008469334919937
394 9.796746599022299e-05
395 9.517112630419433e-05
396 9.246824629371986e-05
397 8.982933650258929e-05
398 8.727205567993224e-05
399 8.478765084873885e-05
400 8.237401198130101e-05
401 8.003140101209283e-05
402 7.775943959131837e-05
403 7.554843614343554e-05
404 7.340432784985751e-05
405 7.132549944799393e-05
406 6.930233212187886e-05
407 6.733900227118284e-05
408 6.543210474774241e-05
409 6.357915117405355e-05
410 6.17831465206109e-05
411 6.003151065669954e-05
412 5.8339144743513316e-05
413 5.6691871577640995e-05
414 5.5090906244004145e-05
415 5.3535237384494394e-05
416 5.20265348313842e-05
417 5.055889050709084e-05
418 4.913530210615136e-05
419 4.7752972022863105e-05
420 4.6409502829192206e-05
421 4.5104152377462015e-05
422 4.383470877655782e-05
423 4.260554851498455e-05
424 4.1409039113204926e-05
425 4.024508598376997e-05
426 3.9116966945584863e-05
427 3.801933053182438e-05
428 3.6956334952265024e-05
429 3.5919998481404036e-05
430 3.491135430522263e-05
431 3.3935921237571165e-05
432 3.298764568171464e-05
433 3.20655781251844e-05
434 3.117137021035887e-05
435 3.029843355761841e-05
436 2.9452032322296873e-05
437 2.8631173336179927e-05
438 2.7831340048578568e-05
439 2.7056015824200585e-05
440 2.6304227503715083e-05
441 2.5570319849066436e-05
442 2.4856406525941566e-05
443 2.4166112780221738e-05
444 2.3494641936849803e-05
445 2.2840760721010156e-05
446 2.220711212430615e-05
447 2.1591475160676055e-05
448 2.0993134967284277e-05
449 2.0410347133292817e-05
450 1.9843460904667154e-05
451 1.929254904098343e-05
452 1.8759303202386945e-05
453 1.823955426516477e-05
454 1.7734291759552434e-05
455 1.724492722132709e-05
456 1.6768060959293507e-05
457 1.630611950531602e-05
458 1.5854229786782525e-05
459 1.5418048860738054e-05
460 1.4992020624049474e-05
461 1.4579009985027369e-05
462 1.4176099284668453e-05
463 1.3784307157038711e-05
464 1.340637118119048e-05
465 1.3036725249548908e-05
466 1.2677173799602315e-05
467 1.2328306183917448e-05
468 1.1992384315817617e-05
469 1.1661880307656247e-05
470 1.1341294339217711e-05
471 1.103102204069728e-05
472 1.0727787412179168e-05
473 1.0432360795675777e-05
474 1.0147992725251243e-05
475 9.86904797173338e-06
476 9.598277756595053e-06
477 9.336837138107512e-06
478 9.081155440071598e-06
479 8.833509127725847e-06
480 8.59129704622319e-06
481 8.355967111128848e-06
482 8.127360160870012e-06
483 7.905085112724919e-06
484 7.689598533033859e-06
485 7.480665317416424e-06
486 7.276807536982233e-06
487 7.078009730321355e-06
488 6.885013135615736e-06
489 6.699090590700507e-06
490 6.516103894682601e-06
491 6.338487764878664e-06
492 6.1669588831136934e-06
493 5.998598226142349e-06
494 5.836237050971249e-06
495 5.676768978446489e-06
496 5.523819709196687e-06
497 5.373812200559769e-06
498 5.227952897257637e-06
499 5.086752025817987e-06
import torch.nn as nn

N, D_in, H, D_out = 64, 1000, 100, 10

# 随机创建一些训练数据
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNet, self).__init__()
        # define the model architecture
        self.linear1 = torch.nn.Linear(D_in, H, bias=False)
        self.linear2 = torch.nn.Linear(H, D_out, bias=False)
    
    def forward(self, x):
        y_pred = self.linear2(self.linear1(x).clamp(min=0))
        return y_pred

model = TwoLayerNet(D_in, H, D_out)
loss_fn = nn.MSELoss(reduction='sum')
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for it in range(500):
    # Forward pass
    y_pred = model(x) # model.forward() 
    
    # compute loss
    loss = loss_fn(y_pred, y) # computation graph
    print(it, loss.item())

    optimizer.zero_grad()
    # Backward pass
    loss.backward()
    
    # update model parameters
    optimizer.step()

0 720.544921875
1 702.627685546875
2 685.1553955078125
3 668.2481079101562
4 651.8722534179688
5 635.8622436523438
6 620.3606567382812
7 605.335693359375
8 590.691162109375
9 576.394287109375
10 562.4141845703125
11 548.7941284179688
12 535.5223999023438
13 522.7237548828125
14 510.2599182128906
15 498.1318359375
16 486.36083984375
17 474.88922119140625
18 463.7231750488281
19 452.8804931640625
20 442.35601806640625
21 432.02728271484375
22 421.9671630859375
23 412.1201477050781
24 402.5291748046875
25 393.20074462890625
26 384.08380126953125
27 375.23956298828125
28 366.6481628417969
29 358.23779296875
30 350.0379638671875
31 342.0119934082031
32 334.17034912109375
33 326.49560546875
34 319.0382080078125
35 311.7474670410156
36 304.6119689941406
37 297.6454162597656
38 290.829345703125
39 284.1673889160156
40 277.6377258300781
41 271.27862548828125
42 265.0540771484375
43 258.98870849609375
44 253.02772521972656
45 247.19558715820312
46 241.47308349609375
47 235.84817504882812
48 230.3345947265625
49 224.93191528320312
50 219.63356018066406
51 214.43231201171875
52 209.3229217529297
53 204.3197479248047
54 199.41831970214844
55 194.62274169921875
56 189.91372680664062
57 185.30979919433594
58 180.78543090820312
59 176.35043334960938
60 172.03123474121094
61 167.80076599121094
62 163.65994262695312
63 159.6077117919922
64 155.635009765625
65 151.7338409423828
66 147.90628051757812
67 144.1611328125
68 140.49923706054688
69 136.90609741210938
70 133.38592529296875
71 129.92535400390625
72 126.53105163574219
73 123.20957946777344
74 119.96233367919922
75 116.7892074584961
76 113.67752075195312
77 110.62653350830078
78 107.63640594482422
79 104.70710754394531
80 101.84573364257812
81 99.05326843261719
82 96.31414794921875
83 93.63819122314453
84 91.02159118652344
85 88.4598388671875
86 85.95780944824219
87 83.51100158691406
88 81.12576293945312
89 78.79422760009766
90 76.51752471923828
91 74.29196166992188
92 72.11930084228516
93 70.0026626586914
94 67.928955078125
95 65.90653991699219
96 63.93334197998047
97 62.00388717651367
98 60.12252426147461
99 58.285614013671875
100 56.493125915527344
101 54.74164962768555
102 53.03565216064453
103 51.37256622314453
104 49.752525329589844
105 48.17483139038086
106 46.633705139160156
107 45.131378173828125
108 43.67048263549805
109 42.248836517333984
110 40.86577606201172
111 39.517311096191406
112 38.206024169921875
113 36.931732177734375
114 35.69224166870117
115 34.48631286621094
116 33.313079833984375
117 32.17280960083008
118 31.06452751159668
119 29.986743927001953
120 28.940500259399414
121 27.92294692993164
122 26.93610191345215
123 25.97654151916504
124 25.04584503173828
125 24.14299774169922
126 23.267927169799805
127 22.418174743652344
128 21.593399047851562
129 20.793794631958008
130 20.01937484741211
131 19.26983642578125
132 18.54322624206543
133 17.83873748779297
134 17.156911849975586
135 16.497770309448242
136 15.8604736328125
137 15.2435302734375
138 14.64791202545166
139 14.07233715057373
140 13.516242980957031
141 12.979171752929688
142 12.460262298583984
143 11.959256172180176
144 11.475519180297852
145 11.008925437927246
146 10.558966636657715
147 10.125717163085938
148 9.707698822021484
149 9.304882049560547
150 8.917524337768555
151 8.54373550415039
152 8.18483829498291
153 7.838366985321045
154 7.505320072174072
155 7.185374736785889
156 6.878193378448486
157 6.582548141479492
158 6.298079013824463
159 6.024651527404785
160 5.76194429397583
161 5.509881496429443
162 5.267538070678711
163 5.035026550292969
164 4.81190824508667
165 4.597996711730957
166 4.39227294921875
167 4.195413589477539
168 4.00651741027832
169 3.825408697128296
170 3.651704788208008
171 3.4849436283111572
172 3.3250298500061035
173 3.171847343444824
174 3.0250601768493652
175 2.884410858154297
176 2.749636650085449
177 2.620737314224243
178 2.4973068237304688
179 2.379302501678467
180 2.2665505409240723
181 2.158726215362549
182 2.0557680130004883
183 1.9572761058807373
184 1.8632687330245972
185 1.77349853515625
186 1.6877164840698242
187 1.6059437990188599
188 1.527989387512207
189 1.4536229372024536
190 1.3826979398727417
191 1.3150861263275146
192 1.2506059408187866
193 1.1891487836837769
194 1.130806565284729
195 1.0753490924835205
196 1.0225379467010498
197 0.9722328186035156
198 0.9243756532669067
199 0.8788415193557739
200 0.835479736328125
201 0.7942246198654175
202 0.7550785541534424
203 0.7177739143371582
204 0.6823312044143677
205 0.6486000418663025
206 0.6165372729301453
207 0.5860490202903748
208 0.5570669770240784
209 0.5294883847236633
210 0.503280758857727
211 0.47837594151496887
212 0.4547117352485657
213 0.43219828605651855
214 0.41081875562667847
215 0.3905050754547119
216 0.3711991012096405
217 0.3528625965118408
218 0.33542150259017944
219 0.3188553750514984
220 0.3031269311904907
221 0.2881830930709839
222 0.27400603890419006
223 0.26052147150039673
224 0.24771645665168762
225 0.23554985225200653
226 0.223990336060524
227 0.2130102664232254
228 0.20257283747196198
229 0.1926596611738205
230 0.1832451969385147
231 0.17429481446743011
232 0.1657957136631012
233 0.15770511329174042
234 0.1500290334224701
235 0.1427464634180069
236 0.135794535279274
237 0.12920279800891876
238 0.12293926626443863
239 0.11698602885007858
240 0.11132536083459854
241 0.10593777894973755
242 0.10081358253955841
243 0.09593836218118668
244 0.09130439162254333
245 0.08689787983894348
246 0.08270401507616043
247 0.07871619611978531
248 0.07492534071207047
249 0.07131915539503098
250 0.0678882747888565
251 0.06462489813566208
252 0.06152183562517166
253 0.058571480214595795
254 0.05576350539922714
255 0.053093574941158295
256 0.050553057342767715
257 0.048138853162527084
258 0.04583986476063728
259 0.043653454631567
260 0.041573505848646164
261 0.03959491848945618
262 0.03771457076072693
263 0.03592259809374809
264 0.03421742841601372
265 0.032595936208963394
266 0.031053274869918823
267 0.02958434633910656
268 0.028185972943902016
269 0.026855386793613434
270 0.025590816512703896
271 0.02439006045460701
272 0.023246297612786293
273 0.022157778963446617
274 0.021121777594089508
275 0.020135166123509407
276 0.019196175038814545
277 0.018301717936992645
278 0.017449872568249702
279 0.016638539731502533
280 0.015866167843341827
281 0.015130136162042618
282 0.014429199509322643
283 0.013761700130999088
284 0.013125191442668438
285 0.012518973089754581
286 0.011941485106945038
287 0.01139123272150755
288 0.01086719986051321
289 0.0103673180565238
290 0.009891511872410774
291 0.00943682063370943
292 0.009004391729831696
293 0.008591714315116405
294 0.008198557421565056
295 0.007823653519153595
296 0.007466151379048824
297 0.007125228643417358
298 0.006800246890634298
299 0.006490218453109264
300 0.006194571498781443
301 0.005912586115300655
302 0.005643600597977638
303 0.005387018900364637
304 0.00514244195073843
305 0.004908899776637554
306 0.00468617957085371
307 0.00447375513613224
308 0.00427114125341177
309 0.004077683202922344
310 0.003893042216077447
311 0.0037169677671045065
312 0.0035489369183778763
313 0.0033885298762470484
314 0.003235474694520235
315 0.0030893925577402115
316 0.002949965186417103
317 0.0028169089928269386
318 0.0026899331714957952
319 0.002568728756159544
320 0.0024529669899493456
321 0.00234250002540648
322 0.002237096196040511
323 0.002136391820386052
324 0.0020402902737259865
325 0.0019485403317958117
326 0.0018609495600685477
327 0.001777337514795363
328 0.0016974894097074866
329 0.0016212319023907185
330 0.001548543106764555
331 0.0014789194101467729
332 0.0014125668676570058
333 0.0013491882709786296
334 0.0012886538170278072
335 0.001230836845934391
336 0.0011756359599530697
337 0.0011229182127863169
338 0.0010725702159106731
339 0.0010244643781334162
340 0.0009785457514226437
341 0.0009346609003841877
342 0.0008927764720283449
343 0.0008527427562512457
344 0.0008145335013978183
345 0.0007780257146805525
346 0.0007431511767208576
347 0.0007098381174728274
348 0.0006780338590033352
349 0.0006476147682406008
350 0.0006186012760736048
351 0.0005908627063035965
352 0.0005643694312311709
353 0.0005390712758526206
354 0.0005148871568962932
355 0.0004917891346849501
356 0.0004697243857663125
357 0.00044865073869004846
358 0.00042851141188293695
359 0.00040927110239863396
360 0.0003908964863512665
361 0.0003733384655788541
362 0.0003565754450391978
363 0.0003405387979000807
364 0.0003252301539760083
365 0.0003106163057964295
366 0.0002966441388707608
367 0.000283299625152722
368 0.0002705401857383549
369 0.0002583622408565134
370 0.00024672725703567266
371 0.0002356083132326603
372 0.0002249831159133464
373 0.00021484990429598838
374 0.00020514836069196463
375 0.00019588429131545126
376 0.00018703722162172198
377 0.00017859306535683572
378 0.0001705146860331297
379 0.00016280770068988204
380 0.00015543507470283657
381 0.00014839736104477197
382 0.00014167308108881116
383 0.00013525327085517347
384 0.00012911579688079655
385 0.00012324984709266573
386 0.00011766132229240611
387 0.00011231174721615389
388 0.00010720016871346161
389 0.00010232516069663689
390 9.765935828909278e-05
391 9.321245306637138e-05
392 8.896314102457836e-05
393 8.489633182762191e-05
394 8.102042193058878e-05
395 7.732125231996179e-05
396 7.378427108051255e-05
397 7.040343916742131e-05
398 6.717292126268148e-05
399 6.409359775716439e-05
400 6.115188443800434e-05
401 5.833932664245367e-05
402 5.565791070694104e-05
403 5.309678090270609e-05
404 5.065144068794325e-05
405 4.8312889703083783e-05
406 4.608543895301409e-05
407 4.395186260808259e-05
408 4.192261985735968e-05
409 3.9982223825063556e-05
410 3.8130066968733445e-05
411 3.636054680100642e-05
412 3.467079295660369e-05
413 3.305972495581955e-05
414 3.152069984935224e-05
415 3.005408143508248e-05
416 2.865028363885358e-05
417 2.7312342353980057e-05
418 2.603633220132906e-05
419 2.4817110897856764e-05
420 2.3652552044950426e-05
421 2.2546058971784078e-05
422 2.1485664547071792e-05
423 2.0474553821259178e-05
424 1.95126067410456e-05
425 1.8592400010675192e-05
426 1.7714726709527895e-05
427 1.687816984485835e-05
428 1.6079298802651465e-05
429 1.5317518773372285e-05
430 1.4591370018024463e-05
431 1.3899963050789665e-05
432 1.3238307474239264e-05
433 1.26073991850717e-05
434 1.2005950338789262e-05
435 1.1433628060331102e-05
436 1.088749195332639e-05
437 1.0366490641899873e-05
438 9.869987479760312e-06
439 9.395583219884429e-06
440 8.943191460275557e-06
441 8.514545697835274e-06
442 8.102580068225507e-06
443 7.711866601312067e-06
444 7.339545391005231e-06
445 6.983953880990157e-06
446 6.645625035162084e-06
447 6.322632543742657e-06
448 6.015842700435314e-06
449 5.722788500861498e-06
450 5.443753252620809e-06
451 5.177714683668455e-06
452 4.924917902826564e-06
453 4.6825798563077115e-06
454 4.4534272092278115e-06
455 4.234493189869681e-06
456 4.025897396786604e-06
457 3.827206455753185e-06
458 3.6387639283930184e-06
459 3.4581094041641336e-06
460 3.2868751986825373e-06
461 3.1240106181940064e-06
462 2.9683715183637105e-06
463 2.821423777277232e-06
464 2.6804389108292526e-06
465 2.5469707907177508e-06
466 2.4194885099859675e-06
467 2.298792878718814e-06
468 2.183521019105683e-06
469 2.073098812616081e-06
470 1.9687890926434193e-06
471 1.8700706050367444e-06
472 1.775020791683346e-06
473 1.6854221485118615e-06
474 1.6007022622943623e-06
475 1.5190260000963463e-06
476 1.4418430964724394e-06
477 1.3687368891623919e-06
478 1.2988508615308092e-06
479 1.2323052942520007e-06
480 1.1692854968714528e-06
481 1.1097954484284855e-06
482 1.0527102176638437e-06
483 9.983506288335775e-07
484 9.470696795688127e-07
485 8.981135692920361e-07
486 8.516540219716262e-07
487 8.075884920799581e-07
488 7.655362423975021e-07
489 7.257012839545496e-07
490 6.879391776237753e-07
491 6.521823934235726e-07
492 6.180292757562711e-07
493 5.857539804310363e-07
494 5.552342372538988e-07
495 5.255469659459777e-07
496 4.982005066267448e-07
497 4.720054675999563e-07
498 4.4719493530465115e-07
499 4.232516062074865e-07
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值