Optimization objective: find the weight values that minimize the objective (cost) function.
Gradient descent updates the weights by stepping in the direction of the negative gradient.
Gradient descent is only guaranteed to find a local optimum; it cannot guarantee the global optimum.
Stochastic gradient descent (SGD) does not differentiate over all samples; it updates the weights using the gradient of a single randomly chosen sample, and this noise can help it step past saddle points.
Full-batch gradient descent parallelizes well across samples, while SGD often reaches better solutions but has higher time complexity because its per-sample updates are sequential. A common compromise is mini-batch SGD ("Batch"), i.e. stochastic gradient descent on small batches of samples, as sketched below.
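The following is a minimal mini-batch SGD sketch, not part of the original notes: it reuses the same linear model y = x * w and toy data as the code below, and the lr and batch_size values are assumptions chosen purely for illustration.

# Mini-batch SGD sketch (illustrative): update w on small random batches.
import random

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0
lr = 0.02        # learning rate (assumed)
batch_size = 2   # mini-batch size (assumed)

samples = list(zip(x_data, y_data))
for epoch in range(100):
    random.shuffle(samples)                      # reshuffle samples each epoch
    for i in range(0, len(samples), batch_size):
        batch = samples[i:i + batch_size]
        # average gradient of (x*w - y)**2 over the batch
        grad = sum(2 * x * (x * w - y) for x, y in batch) / len(batch)
        w -= lr * grad                           # one update per mini-batch
print('w after mini-batch training:', w)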
Gradient descent implementation:
# Batch gradient descent
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0

# Forward pass of the linear model y = x * w
def forward(x):
    return x * w

# MSE cost over the whole training set
def cost(xs, ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost / len(xs)

# Gradient of the cost: d(cost)/dw = (2/N) * sum(x * (x*w - y))
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)

# Record epoch indices and losses for plotting
epoch_list = []
cost_list = []

print('predict (before training)', 4, forward(4))
for epoch in range(100):
    cost_val = cost(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= 0.05 * grad_val  # 0.05 is the learning rate (matches the output below)
    print('epoch:', epoch, 'w=', w, 'loss=', cost_val)
    epoch_list.append(epoch)
    cost_list.append(cost_val)
print('predict (after training)', 4, forward(4))

plt.plot(epoch_list, cost_list)
plt.ylabel('cost')
plt.xlabel('epoch')
plt.show()
The output is as follows:
predict (before training) 4 4.0
epoch: 0 w= 1.4666666666666668 loss= 4.666666666666667
epoch: 1 w= 1.7155555555555557 loss= 1.3274074074074067
epoch: 2 w= 1.8482962962962963 loss= 0.3775736625514398
epoch: 3 w= 1.9190913580246913 loss= 0.10739873068129853
epoch: 4 w= 1.9568487242798354 loss= 0.030548972282680543
epoch: 5 w= 1.976985986282579 loss= 0.008689485449295776
epoch: 6 w= 1.9877258593507088 loss= 0.0024716758611330204
epoch: 7 w= 1.9934537916537114 loss= 0.0007030544671667267
epoch: 8 w= 1.9965086888819794 loss= 0.0001999799373274173
epoch: 9 w= 1.9981379674037223 loss= 5.688318217313286e-05
epoch: 10 w= 1.9990069159486519 loss= 1.618010515146993e-05
epoch: 11 w= 1.9994703551726143 loss= 4.602341020862254e-06
epoch: 12 w= 1.9997175227587276 loss= 1.3091103348236213e-06
epoch: 13 w= 1.9998493454713213 loss= 3.723691619052323e-07
epoch: 14 w= 1.999919650918038 loss= 1.0591833938645004e-07
epoch: 15 w= 1.999957147156287 loss= 3.0127883203364745e-08
epoch: 16 w= 1.9999771451500197 loss= 8.569709000094413e-09
epoch: 17 w= 1.999987810746677 loss= 2.4376061155645925e-09
epoch: 18 w= 1.9999934990648944 loss= 6.933635173220451e-10
epoch: 19 w= 1.9999965328346103 loss= 1.9722340048226596e-10
epoch: 20 w= 1.9999981508451254 loss= 5.609910058542032e-11
epoch: 21 w= 1.9999990137840669 loss= 1.5957077500436847e-11
epoch: 22 w= 1.999999474018169 loss= 4.538902045415441e-12
epoch: 23 w= 1.9999997194763568 loss= 1.2910654708931954e-12
epoch: 24 w= 1.9999998503873904 loss= 3.6723640039091726e-13
epoch: 25 w= 1.999999920206608 loss= 1.0445835388749555e-13
epoch: 26 w= 1.9999999574435243 loss= 2.971259855722779e-14
epoch: 27 w= 1.9999999773032129 loss= 8.451583569452665e-15
epoch: 28 w= 1.999999987895047 loss= 2.4040060320624316e-15
epoch: 29 w= 1.999999993544025 loss= 6.838061337110749e-16
epoch: 30 w= 1.9999999965568134 loss= 1.9450486402996616e-16
epoch: 31 w= 1.9999999981636338 loss= 5.532582289380589e-17
epoch: 32 w= 1.9999999990206048 loss= 1.573712494993027e-17
epoch: 33 w= 1.9999999994776558 loss= 4.4763367583436385e-18
epoch: 34 w= 1.9999999997214164 loss= 1.2732695708436325e-18
epoch: 35 w= 1.9999999998514222 loss= 3.6217462890900584e-19
epoch: 36 w= 1.9999999999207585 loss= 1.0301831624602159e-19
epoch: 37 w= 1.9999999999577378 loss= 2.9303006500354953e-20
epoch: 38 w= 1.9999999999774603 loss= 8.335100760491641e-21
epoch: 39 w= 1.9999999999879787 loss= 2.3708344012213664e-21
epoch: 40 w= 1.9999999999935887 loss= 6.743793343241322e-22
epoch: 41 w= 1.9999999999965807 loss= 1.9182889646896314e-22
epoch: 42 w= 1.9999999999981763 loss= 5.455821876973161e-23
epoch: 43 w= 1.9999999999990274 loss= 1.5520779885212614e-23
epoch: 44 w= 1.9999999999994813 loss= 4.4140317521189105e-24
epoch: 45 w= 1.9999999999997233 loss= 1.2555468094916013e-24
epoch: 46 w= 1.9999999999998523 loss= 3.5696409556271286e-25
epoch: 47 w= 1.9999999999999212 loss= 1.0181467785899592e-25
epoch: 48 w= 1.999999999999958 loss= 2.896140110957243e-26
epoch: 49 w= 1.9999999999999776 loss= 8.237499222146303e-27
epoch: 50 w= 1.999999999999988 loss= 2.357067080993807e-27
epoch: 51 w= 1.9999999999999936 loss= 6.603423160787553e-28
epoch: 52 w= 1.9999999999999967 loss= 1.9637706159345563e-28
epoch: 53 w= 1.9999999999999982 loss= 5.030631731003161e-29
epoch: 54 w= 1.9999999999999991 loss= 1.4725403564125555e-29
epoch: 55 w= 1.9999999999999996 loss= 3.681350891031389e-30
epoch: 56 w= 1.9999999999999998 loss= 1.3805065841367707e-30
epoch: 57 w= 2.0 loss= 3.4512664603419266e-31
epoch: 58 w= 2.0 loss= 0.0
epoch: 59 w= 2.0 loss= 0.0
epoch: 60 w= 2.0 loss= 0.0
epoch: 61 w= 2.0 loss= 0.0
epoch: 62 w= 2.0 loss= 0.0
epoch: 63 w= 2.0 loss= 0.0
epoch: 64 w= 2.0 loss= 0.0
epoch: 65 w= 2.0 loss= 0.0
epoch: 66 w= 2.0 loss= 0.0
epoch: 67 w= 2.0 loss= 0.0
epoch: 68 w= 2.0 loss= 0.0
epoch: 69 w= 2.0 loss= 0.0
epoch: 70 w= 2.0 loss= 0.0
epoch: 71 w= 2.0 loss= 0.0
epoch: 72 w= 2.0 loss= 0.0
epoch: 73 w= 2.0 loss= 0.0
epoch: 74 w= 2.0 loss= 0.0
epoch: 75 w= 2.0 loss= 0.0
epoch: 76 w= 2.0 loss= 0.0
epoch: 77 w= 2.0 loss= 0.0
epoch: 78 w= 2.0 loss= 0.0
epoch: 79 w= 2.0 loss= 0.0
epoch: 80 w= 2.0 loss= 0.0
epoch: 81 w= 2.0 loss= 0.0
epoch: 82 w= 2.0 loss= 0.0
epoch: 83 w= 2.0 loss= 0.0
epoch: 84 w= 2.0 loss= 0.0
epoch: 85 w= 2.0 loss= 0.0
epoch: 86 w= 2.0 loss= 0.0
epoch: 87 w= 2.0 loss= 0.0
epoch: 88 w= 2.0 loss= 0.0
epoch: 89 w= 2.0 loss= 0.0
epoch: 90 w= 2.0 loss= 0.0
epoch: 91 w= 2.0 loss= 0.0
epoch: 92 w= 2.0 loss= 0.0
epoch: 93 w= 2.0 loss= 0.0
epoch: 94 w= 2.0 loss= 0.0
epoch: 95 w= 2.0 loss= 0.0
epoch: 96 w= 2.0 loss= 0.0
epoch: 97 w= 2.0 loss= 0.0
epoch: 98 w= 2.0 loss= 0.0
epoch: 99 w= 2.0 loss= 0.0
predict (after training) 4 8.0
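The run converges because the learning rate is small enough; gradient descent is sensitive to this choice. Below is a minimal sketch, not from the original notes, with lr values assumed purely for illustration: for this dataset the update satisfies w_new - 2 = (1 - lr * 28/3) * (w - 2), so it converges only when 0 < lr < 6/28 ≈ 0.214.

# Learning-rate sensitivity demo (illustrative; the lr values are assumed).
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def run_gd(lr, epochs=20):
    w = 1.0
    for _ in range(epochs):
        grad = sum(2 * x * (x * w - y) for x, y in zip(x_data, y_data)) / len(x_data)
        w -= lr * grad
    return w

print('lr=0.05 :', run_gd(0.05))   # approaches 2.0
print('lr=0.25 :', run_gd(0.25))   # oscillates and diverges: |1 - 0.25*28/3| > 1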
Stochastic gradient descent implementation:
# Stochastic gradient descent
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0

# Forward pass of the linear model y = x * w
def forward(x):
    return x * w

# Squared error of a single sample (not averaged over the dataset)
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# Gradient of the single-sample loss: d(loss)/dw = 2x(xw - y)
def gradient(x, y):
    return 2 * x * (x * w - y)

# Record epoch indices and losses for plotting
epoch_list = []
cost_list = []

print('predict (before training)', 4, forward(4))
for epoch in range(100):
    for x, y in zip(x_data, y_data):   # update w once per sample
        grad = gradient(x, y)
        w -= 0.02 * grad               # 0.02 is the learning rate
        l = loss(x, y)
    # log the last sample's loss of this epoch
    print('epoch:', epoch, 'w=', w, 'loss=', l)
    epoch_list.append(epoch)
    cost_list.append(l)
print('predict (after training)', 4, forward(4))

plt.plot(epoch_list, cost_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
The output is as follows:
predict (before training) 4 4.0
epoch: 0 w= 1.483904 loss= 2.397195730944001
epoch: 1 w= 1.733644918784 loss= 0.6385052636062378
epoch: 2 w= 1.8625352080047473 loss= 0.17006912134468274
epoch: 3 w= 1.9290549707104179 loss= 0.0452987746280969
epoch: 4 w= 1.9633855541637637 loss= 0.012065558795052085
epoch: 5 w= 1.981103430961702 loss= 0.003213722892772484
epoch: 6 w= 1.9902475563056106 loss= 0.0008559914221101261
epoch: 7 w= 1.9949668028191003 loss= 0.00022799766475633436
epoch: 8 w= 1.9974023870677264 loss= 6.072833651323373e-05
epoch: 9 w= 1.9986593823561054 loss= 1.6175301004096973e-05
epoch: 10 w= 1.9993081125964567 loss= 4.308373612637153e-06
epoch: 11 w= 1.999642919678581 loss= 1.147557203502345e-06
epoch: 12 w= 1.9998157122744369 loss= 3.056576921387877e-07
epoch: 13 w= 1.9999048898419878 loss= 8.141347941386486e-08
epoch: 14 w= 1.9999509140278904 loss= 2.1684893921386116e-08
epoch: 15 w= 1.999974666926138 loss= 5.775881681706368e-09
epoch: 16 w= 1.999986925701912 loss= 1.5384354344494047e-09
epoch: 17 w= 1.999993252407054 loss= 4.097700950863367e-10
epoch: 18 w= 1.9999965175942709 loss= 1.0914434695962365e-10
epoch: 19 w= 1.999998202744333 loss= 2.9071151391702595e-11
epoch: 20 w= 1.9999990724435395 loss= 7.743248889680307e-12
epoch: 21 w= 1.999999521291821 loss= 2.0624536872461804e-12
epoch: 22 w= 1.9999997529406235 loss= 5.493450197046261e-13
epoch: 23 w= 1.999999872493644 loss= 1.4632083736542766e-13
epoch: 24 w= 1.9999999341944796 loss= 3.8973298683631035e-14
epoch: 25 w= 1.9999999660380343 loss= 1.0380736018194994e-14
epoch: 26 w= 1.9999999824723653 loss= 2.7649618196832182e-15
epoch: 27 w= 1.9999999909540578 loss= 7.364616113197862e-16
epoch: 28 w= 1.9999999953314254 loss= 1.9616029613450225e-16
epoch: 29 w= 1.9999999975905673 loss= 5.224829140117585e-17
epoch: 30 w= 1.9999999987565014 loss= 1.3916599750790023e-17
epoch: 31 w= 1.9999999993582354 loss= 3.706755839896269e-18
epoch: 32 w= 1.999999999668788 loss= 9.873113005513243e-19
epoch: 33 w= 1.999999999829063 loss= 2.62975251868284e-19
epoch: 34 w= 1.9999999999117801 loss= 7.004460092080086e-20
epoch: 35 w= 1.99999999995447 loss= 1.8656908587225806e-20
epoch: 36 w= 1.9999999999765021 loss= 4.969380490070246e-21
epoch: 37 w= 1.9999999999878728 loss= 1.3236182302109993e-21
epoch: 38 w= 1.9999999999937412 loss= 3.525416229989081e-22
epoch: 39 w= 1.99999999999677 loss= 9.389661471273712e-23
epoch: 40 w= 1.999999999998333 loss= 2.5013328589353583e-23
epoch: 41 w= 1.9999999999991396 loss= 6.661800971402988e-24
epoch: 42 w= 1.999999999999556 loss= 1.7749370367472766e-24
epoch: 43 w= 1.9999999999997708 loss= 4.725876356561829e-25
epoch: 44 w= 1.9999999999998819 loss= 1.255874449720903e-25
epoch: 45 w= 1.9999999999999392 loss= 3.3476101373958857e-26
epoch: 46 w= 1.9999999999999687 loss= 8.863641131063289e-27
epoch: 47 w= 1.999999999999984 loss= 2.3003183996244704e-27
epoch: 48 w= 1.9999999999999918 loss= 6.184669496932733e-28
epoch: 49 w= 1.9999999999999958 loss= 1.5461673742331831e-28
epoch: 50 w= 1.9999999999999978 loss= 5.048709793414476e-29
epoch: 51 w= 1.999999999999999 loss= 1.262177448353619e-29
epoch: 52 w= 1.9999999999999993 loss= 3.1554436208840472e-30
epoch: 53 w= 1.9999999999999996 loss= 3.1554436208840472e-30
epoch: 54 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 55 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 56 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 57 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 58 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 59 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 60 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 61 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 62 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 63 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 64 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 65 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 66 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 67 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 68 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 69 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 70 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 71 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 72 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 73 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 74 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 75 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 76 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 77 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 78 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 79 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 80 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 81 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 82 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 83 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 84 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 85 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 86 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 87 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 88 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 89 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 90 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 91 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 92 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 93 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 94 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 95 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 96 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 97 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 98 w= 1.9999999999999998 loss= 7.888609052210118e-31
epoch: 99 w= 1.9999999999999998 loss= 7.888609052210118e-31
predict (after training) 4 7.999999999999999
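Note that both runs recover w ≈ 2, but the SGD run plateaus one floating-point ulp below 2.0 (w = 1.9999999999999998), which appears to be why its loss stalls at about 7.89e-31 rather than reaching exactly 0 as in the full-batch run.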