In the previous two posts we covered stochastic gradient descent (SGD) and batch gradient descent (BGD). Combining the two ideas, we can use a fixed-size batch of samples for each parameter update; this is known as mini-batch gradient descent (MBGD, Mini-Batch Gradient Descent). In this post we walk through a simple implementation of this algorithm.
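For the single-parameter linear model $y \approx \theta x$ that the code below fits, the per-batch loss and the MBGD update (this is exactly the gradient the training loop below computes) are:

$$L_B(\theta) = \frac{1}{2}\sum_{i \in B}(\theta x_i - y_i)^2, \qquad \theta \leftarrow \theta - \eta \cdot \frac{1}{|B|}\sum_{i \in B}(\theta x_i - y_i)\,x_i$$

where $B$ is the current batch and $\eta$ is the learning rate.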
Dataset Construction
```python
## Author: xlwu
## Date: 2020/07/08
import numpy as np
import matplotlib.pyplot as plt

LEN = 20000
X = np.arange(0, LEN)
# noise uniformly distributed in (-10000, +10000)
np.random.seed(1)
rand = (np.random.random(LEN) * 2 - 1) * 10000
Y = X * 5 + rand
plt.scatter(X, Y)
plt.show()

X = X.reshape(LEN, 1)
Y = Y.reshape(LEN, 1)
allData = np.concatenate((X, Y), axis=1)
np.random.shuffle(allData)

# train : test = 4 : 1
ratio = 0.8
index = int(allData.shape[0] * ratio)
trainData = allData[:index]
testData = allData[index:]
```
This time the dataset is fairly large (20,000 points), which makes it easier to compare against the SGD and BGD runs from the earlier posts. The noise is also large, ranging over (-10,000, +10,000), as the scatter plot shows.
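As a quick sanity check (not part of the original post), we can verify the 4:1 split and the noise range:

```python
# Sanity check (sketch): confirm the split sizes and the noise range.
print(trainData.shape, testData.shape)  # (16000, 2) (4000, 2)
print(rand.min(), rand.max())           # roughly -10000 .. +10000
```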
Algorithm Implementation
Parameter Setup
```python
# Hyperparameters
# learning rate
lr = 0.00000001
N = trainData.shape[0]
epsilon = 1000
BATCH_SIZE = 1000

# parameter to be estimated
theta = np.random.rand()
n_iter = 1  # renamed from `iter` to avoid shadowing the built-in
loss_list = []
iter_list = []
theta_list = []
loss = np.inf
```
Note in particular that because the noise in the data is large, I found during training that the learning rate must be set small enough; otherwise the gradient explodes very easily.
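To see why: with $x$ up to 20,000, $x^2$ reaches about $4\times 10^8$, so even a small error in theta produces an enormous gradient. A minimal sketch (not in the original post; the learning rate 1e-4 is chosen purely for illustration) that demonstrates the explosion:

```python
# Demonstration (sketch): a learning rate of 1e-4 diverges on this data.
theta_bad = np.random.rand()
for _ in range(5):
    batch = trainData[:BATCH_SIZE]
    grad_bad = np.sum((theta_bad * batch[:, 0] - batch[:, 1]) * batch[:, 0]) / BATCH_SIZE
    theta_bad -= 1e-4 * grad_bad
    print(theta_bad)  # |theta_bad| blows up within a few steps
```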
Model Training
```python
# Model training
while True:
    # reshuffle, then take the first BATCH_SIZE rows as a fresh random batch
    np.random.shuffle(trainData)
    # grad = (theta * x - y) * x, averaged over the batch
    grad = np.sum((theta * trainData[:BATCH_SIZE, 0] - trainData[:BATCH_SIZE, 1]) * trainData[:BATCH_SIZE, 0]) / BATCH_SIZE
    theta = theta - lr * grad
    # squared-error loss, summed (not averaged) over the batch
    loss = np.sum((trainData[:BATCH_SIZE, 0] * theta - trainData[:BATCH_SIZE, 1]) ** 2 / 2)
    theta_list.append(theta)
    loss_list.append(loss)
    iter_list.append(n_iter)
    print("No.%d:\t grad = %f\t theta: %f\tloss: %f" % (n_iter, grad, theta, loss))
    n_iter += 1
    if loss < epsilon:
        print("Training Completed!")
        break
    if n_iter > 100:
        break
```
The batch size is 1000 and the maximum number of iterations is 100. Note that each iteration reshuffles the whole training set and takes its first 1000 rows, i.e., it draws an independent random batch rather than sweeping through the data epoch by epoch; an epoch-based variant is sketched below.
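For reference, a common epoch-based formulation (a sketch, not the author's code; EPOCHS is a hypothetical setting) shuffles once per epoch and then visits every batch exactly once:

```python
# Epoch-based mini-batching (sketch): each epoch visits all batches once.
EPOCHS = 10  # hypothetical value for illustration
theta_ep = np.random.rand()  # fresh parameter for this variant
for epoch in range(EPOCHS):
    np.random.shuffle(trainData)
    for start in range(0, N, BATCH_SIZE):
        batch = trainData[start:start + BATCH_SIZE]
        grad = np.sum((theta_ep * batch[:, 0] - batch[:, 1]) * batch[:, 0]) / len(batch)
        theta_ep = theta_ep - lr * grad
```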
Training Results
Parameter changes per iteration
Iteration | Gradient | theta | Loss |
---|---|---|---|
1 | -614437401.993430 | 6.532228 | 169557856264.238037 |
2 | 209724697.794626 | 4.434981 | 36192524403.832718 |
3 | -81553154.499855 | 5.250513 | 19805449672.060402 |
4 | 31886785.930983 | 4.931645 | 17301330041.979183 |
5 | -11285970.211199 | 5.044505 | 16047865490.878262 |
6 | 6795588.296486 | 4.976549 | 16071927521.518204 |
7 | -4089631.980910 | 5.017445 | 16590962479.968994 |
8 | -1985730.647060 | 5.037303 | 16176510423.516010 |
9 | 4384690.708843 | 4.993456 | 16589709189.668032 |
10 | -3843769.048469 | 5.031893 | 16416730889.191032 |
11 | 8014685.797436 | 4.951747 | 16841161655.957632 |
12 | -7623170.996524 | 5.027978 | 16756338787.901598 |
13 | 4781305.340600 | 4.980165 | 16813489268.907459 |
14 | -4123972.747732 | 5.021405 | 16230714369.249050 |
15 | 2284718.282477 | 4.998558 | 17498400112.691254 |
16 | -3251878.679545 | 5.031077 | 16030215212.989265 |
17 | 4552496.960414 | 4.985552 | 16675732033.139099 |
18 | -3639754.575472 | 5.021949 | 16870539052.239897 |
19 | 3708732.405082 | 4.984862 | 15849372990.284237 |
20 | 173795.919309 | 4.983124 | 16985019416.401278 |
21 | -2869769.809367 | 5.011822 | 16561306242.674110 |
22 | 50634.259180 | 5.011315 | 16169789117.130127 |
23 | 1264942.212759 | 4.998666 | 17225787702.815445 |
24 | -1555409.482326 | 5.014220 | 16099076364.266872 |
25 | 2329368.936279 | 4.990926 | 17437424420.176945 |
26 | 26954.976601 | 4.990657 | 16591168575.651573 |
27 | -2386042.983801 | 5.014517 | 15886956834.999683 |
28 | -1714915.847631 | 5.031666 | 16466863965.690956 |
29 | 3375012.862933 | 4.997916 | 17319082837.197086 |
30 | 2034349.797292 | 4.977573 | 17134637762.985409 |
31 | -2626768.896781 | 5.003840 | 16584720368.031294 |
32 | -2048732.143321 | 5.024328 | 15939611771.529335 |
33 | 4695746.953454 | 4.977370 | 16906812770.422020 |
34 | -4712643.585894 | 5.024497 | 16930819192.852804 |
35 | 1559090.713019 | 5.008906 | 16762537973.765923 |
36 | -45893.866679 | 5.009365 | 16758676394.960321 |
37 | -1464440.470177 | 5.024009 | 17602614319.908745 |
38 | 5034455.368664 | 4.973664 | 16437618087.631313 |
39 | 152380.525842 | 4.972141 | 15614236431.156599 |
40 | -5215035.498793 | 5.024291 | 16425920628.893679 |
41 | 2147586.183961 | 5.002815 | 16882867610.899406 |
42 | 2247827.415742 | 4.980337 | 16469867474.817997 |
43 | -5163295.579247 | 5.031970 | 16108088080.141354 |
44 | 3151150.179604 | 5.000458 | 16920150490.491007 |
45 | 939167.851211 | 4.991067 | 16292730367.428476 |
46 | -113496.296109 | 4.992202 | 16774658596.425957 |
47 | -2611812.642449 | 5.018320 | 16701788858.448139 |
48 | 1366588.507319 | 5.004654 | 17012824126.470089 |
49 | 3227773.203406 | 4.972376 | 15754034301.061296 |
50 | -4278899.869878 | 5.015165 | 15473357689.792469 |
51 | 1465829.895266 | 5.000507 | 16164610762.756149 |
52 | -3018112.612699 | 5.030688 | 17053806018.661465 |
53 | -233114.001622 | 5.033019 | 16055729137.546135 |
54 | 6523783.308383 | 4.967781 | 17096154072.122627 |
55 | -3390560.489978 | 5.001687 | 16457903807.274166 |
56 | -5392840.688173 | 5.055615 | 16409698052.206661 |
57 | 10337000.173035 | 4.952245 | 16638752276.462385 |
58 | -3588391.943031 | 4.988129 | 16375390259.822605 |
59 | -3073482.812424 | 5.018864 | 16642001368.582855 |
60 | 726830.111205 | 5.011596 | 15984736672.733067 |
61 | -804870.376233 | 5.019644 | 17062799905.361982 |
62 | -103086.067677 | 5.020675 | 16757646691.556469 |
63 | 1668663.191357 | 5.003989 | 16567170422.718788 |
64 | -1576616.177286 | 5.019755 | 16910911600.235109 |
65 | 3111979.903183 | 4.988635 | 16896192531.313564 |
66 | -2775233.057141 | 5.016387 | 17197712624.060017 |
67 | 1696894.683016 | 4.999418 | 17142247482.271784 |
68 | -1398863.825410 | 5.013407 | 16333471931.507477 |
69 | -154469.217082 | 5.014952 | 16741721351.330023 |
70 | 2375204.246554 | 4.991200 | 16994651875.705441 |
71 | -5158691.501060 | 5.042787 | 16759746579.809208 |
72 | 6086703.858385 | 4.981920 | 16322327410.619946 |
73 | -3459806.454003 | 5.016518 | 16659598188.081657 |
74 | 1964286.835723 | 4.996875 | 16212456225.955757 |
75 | -357446.816207 | 5.000449 | 17733536966.081696 |
76 | 1790140.424686 | 4.982548 | 16071199077.246952 |
77 | -1204884.438560 | 4.994597 | 15682580271.927826 |
78 | -210391.871228 | 4.996701 | 16246474846.031092 |
79 | -1256845.139897 | 5.009269 | 16440315580.191532 |
80 | 2110696.705108 | 4.988162 | 17011888013.782473 |
81 | -694178.752243 | 4.995104 | 16644089883.128731 |
82 | -992069.854190 | 5.005025 | 17046681863.138529 |
83 | 1050853.414404 | 4.994516 | 17073817857.214806 |
84 | -3888903.183147 | 5.033405 | 16171534776.551075 |
85 | 1520705.343209 | 5.018198 | 16436875242.582214 |
86 | 5433458.575271 | 4.963863 | 15924583704.911821 |
87 | -5800593.792025 | 5.021869 | 16578294531.967175 |
88 | 747946.728740 | 5.014390 | 17177345981.141333 |
89 | 911307.488206 | 5.005277 | 16751266629.503719 |
90 | 54355.029846 | 5.004733 | 16184136725.305483 |
91 | 1192919.906975 | 4.992804 | 16236613497.467697 |
92 | -1598501.323690 | 5.008789 | 17035745922.302227 |
93 | 3704272.058086 | 4.971746 | 16997214005.110441 |
94 | -3888233.309747 | 5.010629 | 16175567029.206501 |
95 | 798884.767429 | 5.002640 | 17063181204.524981 |
96 | 1857413.849202 | 4.984066 | 16844997621.761753 |
97 | -1576242.587484 | 4.999828 | 16462429502.875530 |
98 | 3408873.056972 | 4.965739 | 17136891861.577848 |
99 | -3631328.551898 | 5.002053 | 16749053209.808189 |
100 | -1134856.792916 | 5.013401 | 16068593487.097910 |
Result Visualization
Trend of theta:
Trend of loss:
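The figures appear in the original post; a minimal sketch to reproduce them from the lists recorded during training:

```python
# Reproduce the two figures from the recorded trajectories (sketch).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(iter_list, theta_list)
ax1.axhline(5, color="red", linestyle="--")  # true slope used to generate Y
ax1.set_xlabel("iteration")
ax1.set_ylabel("theta")
ax2.plot(iter_list, loss_list)
ax2.set_xlabel("iteration")
ax2.set_ylabel("loss")
plt.show()
```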
Result Analysis
From the table and the plots we can see that the model still converges fairly quickly: theta settles near the true slope of 5 within a handful of iterations. However, the converged value keeps jittering, and since each update sees a different random batch, this jitter is unavoidable for MBGD.
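The code above splits off testData but never uses it; a minimal sketch (assuming the same squared-error loss, this time averaged per sample) to evaluate the final theta on the held-out set:

```python
# Evaluate the learned theta on the held-out test set (sketch; this step
# is not in the original post).
test_x, test_y = testData[:, 0], testData[:, 1]
test_loss = np.mean((theta * test_x - test_y) ** 2 / 2)
print("theta = %f, mean test loss = %f" % (theta, test_loss))
```

A common trick to tame the jitter is tail averaging, e.g. `np.mean(theta_list[-20:])`, which averages out the per-batch noise in the final estimate.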