A Simple Implementation of Mini-Batch Gradient Descent

The two previous posts covered stochastic gradient descent (SGD) and batch gradient descent (BGD). Combining the two ideas, we can update the parameters at each step using a batch of fixed size drawn from the training data; this method is called mini-batch gradient descent (MBGD). In this post we walk through a simple implementation of the algorithm.
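Concretely, for the one-parameter linear model y = θx with squared-error loss used throughout these posts, the per-batch loss and the MBGD update step (matching the gradient computed in the code below) can be written as:

L_B(\theta) = \frac{1}{2} \sum_{(x_i,\, y_i) \in B} (\theta x_i - y_i)^2

\theta \leftarrow \theta - \eta \cdot \frac{1}{|B|} \sum_{(x_i,\, y_i) \in B} (\theta x_i - y_i)\, x_i

where B is the current batch (|B| = BATCH_SIZE below) and η is the learning rate lr.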

Dataset Construction

## Author: xlwu
## Date:   2020/07/08

import numpy as np
import matplotlib.pyplot as plt

LEN = 20000
X = np.arange(0, LEN)

# noise uniformly distributed in (-10000, +10000)
np.random.seed(1)
rand = (np.random.random(LEN) * 2 - 1) * 10000
Y = X * 5 + rand

plt.scatter(X, Y)
plt.show()

# stack X and Y into one (LEN, 2) array so they can be shuffled together
X = X.reshape(LEN, 1)
Y = Y.reshape(LEN, 1)
allData = np.concatenate((X, Y), axis=1)

np.random.shuffle(allData)

# train : test = 4 : 1
ratio = 0.8
index = int(allData.shape[0] * ratio)
trainData = allData[:index]
testData = allData[index:]

[Figure: data distribution of the constructed dataset]
This time the dataset we construct is fairly large, with 20,000 points, which makes the comparison between methods more meaningful. The noise is also large, ranging over (-10 000, +10 000), as the figure shows.
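For reference, the same shuffled 4:1 split could also be done with scikit-learn's train_test_split (a sketch assuming scikit-learn is installed; the original code does the split by hand):

from sklearn.model_selection import train_test_split

# shuffle and split in one call; random_state is an arbitrary choice
trainData, testData = train_test_split(allData, train_size=0.8,
                                       shuffle=True, random_state=1)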

Algorithm Implementation

Parameter Settings


# hyperparameters
lr = 0.00000001          # learning rate; must be tiny here, see the note below
N = trainData.shape[0]   # number of training samples
epsilon = 1000           # loss threshold for early stopping
BATCH_SIZE = 1000

# parameter to be estimated
theta = np.random.rand()

n_iter = 1               # iteration counter (renamed to avoid shadowing the built-in iter)
loss_list = []           # loss per iteration, for plotting
iter_list = []
theta_list = []          # theta per iteration, for plotting
loss = np.inf

A special note: because the noise in the data is large, I found during training that the learning rate has to be set small enough, otherwise the gradient explodes very easily.
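If a larger learning rate were desired, one common safeguard (not used in the code here; a sketch only) is to clip the batch gradient inside the training loop before the parameter update:

GRAD_CLIP = 1e7  # hypothetical threshold, chosen for this data's scale

# inside the training loop, right after computing grad:
grad = np.clip(grad, -GRAD_CLIP, GRAD_CLIP)  # bound the size of each update
theta = theta - lr * grad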

Model Training

# model training
while True:
    # reshuffle so that each iteration trains on a fresh random batch
    np.random.shuffle(trainData)
    # gradient of 0.5 * (theta * x - y)^2 w.r.t. theta is (theta * x - y) * x,
    # averaged over the batch
    grad = np.sum((theta * trainData[:BATCH_SIZE, 0] - trainData[:BATCH_SIZE, 1]) * trainData[:BATCH_SIZE, 0]) / BATCH_SIZE
    theta = theta - lr * grad

    # squared-error loss, summed (not averaged) over the current batch
    loss = np.sum((trainData[:BATCH_SIZE, 0] * theta - trainData[:BATCH_SIZE, 1]) ** 2 / 2)
    theta_list.append(theta)
    loss_list.append(loss)
    iter_list.append(n_iter)
    print("No.%d:\t grad = %f\t theta: %f\tloss: %f" % (n_iter, grad, theta, loss))
    n_iter += 1
    if loss < epsilon:
        print("Training Completed!")
        break
    if n_iter > 100:
        break

The training batch size is 1000 and the maximum number of iterations is 100. (With noise this large, the summed batch loss stays around 1.6 × 10^10 and never drops below epsilon, so in practice the loop always stops at the iteration limit.)
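Incidentally, the testData split is never used by the loop above; a minimal sketch of evaluating the learned theta on it (my addition, not part of the original code):

# mean squared error of y = theta * x on the held-out test set
test_mse = np.mean((theta * testData[:, 0] - testData[:, 1]) ** 2)
print("test MSE: %f" % test_mse)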

Training Results

Parameter Changes over Iterations

Iteration  Gradient  theta  loss
1  -614437401.993430  6.532228  169557856264.238037
2  209724697.794626  4.434981  36192524403.832718
3  -81553154.499855  5.250513  19805449672.060402
4  31886785.930983  4.931645  17301330041.979183
5  -11285970.211199  5.044505  16047865490.878262
6  6795588.296486  4.976549  16071927521.518204
7  -4089631.980910  5.017445  16590962479.968994
8  -1985730.647060  5.037303  16176510423.516010
9  4384690.708843  4.993456  16589709189.668032
10  -3843769.048469  5.031893  16416730889.191032
11  8014685.797436  4.951747  16841161655.957632
12  -7623170.996524  5.027978  16756338787.901598
13  4781305.340600  4.980165  16813489268.907459
14  -4123972.747732  5.021405  16230714369.249050
15  2284718.282477  4.998558  17498400112.691254
16  -3251878.679545  5.031077  16030215212.989265
17  4552496.960414  4.985552  16675732033.139099
18  -3639754.575472  5.021949  16870539052.239897
19  3708732.405082  4.984862  15849372990.284237
20  173795.919309  4.983124  16985019416.401278
21  -2869769.809367  5.011822  16561306242.674110
22  50634.259180  5.011315  16169789117.130127
23  1264942.212759  4.998666  17225787702.815445
24  -1555409.482326  5.014220  16099076364.266872
25  2329368.936279  4.990926  17437424420.176945
26  26954.976601  4.990657  16591168575.651573
27  -2386042.983801  5.014517  15886956834.999683
28  -1714915.847631  5.031666  16466863965.690956
29  3375012.862933  4.997916  17319082837.197086
30  2034349.797292  4.977573  17134637762.985409
31  -2626768.896781  5.003840  16584720368.031294
32  -2048732.143321  5.024328  15939611771.529335
33  4695746.953454  4.977370  16906812770.422020
34  -4712643.585894  5.024497  16930819192.852804
35  1559090.713019  5.008906  16762537973.765923
36  -45893.866679  5.009365  16758676394.960321
37  -1464440.470177  5.024009  17602614319.908745
38  5034455.368664  4.973664  16437618087.631313
39  152380.525842  4.972141  15614236431.156599
40  -5215035.498793  5.024291  16425920628.893679
41  2147586.183961  5.002815  16882867610.899406
42  2247827.415742  4.980337  16469867474.817997
43  -5163295.579247  5.031970  16108088080.141354
44  3151150.179604  5.000458  16920150490.491007
45  939167.851211  4.991067  16292730367.428476
46  -113496.296109  4.992202  16774658596.425957
47  -2611812.642449  5.018320  16701788858.448139
48  1366588.507319  5.004654  17012824126.470089
49  3227773.203406  4.972376  15754034301.061296
50  -4278899.869878  5.015165  15473357689.792469
51  1465829.895266  5.000507  16164610762.756149
52  -3018112.612699  5.030688  17053806018.661465
53  -233114.001622  5.033019  16055729137.546135
54  6523783.308383  4.967781  17096154072.122627
55  -3390560.489978  5.001687  16457903807.274166
56  -5392840.688173  5.055615  16409698052.206661
57  10337000.173035  4.952245  16638752276.462385
58  -3588391.943031  4.988129  16375390259.822605
59  -3073482.812424  5.018864  16642001368.582855
60  726830.111205  5.011596  15984736672.733067
61  -804870.376233  5.019644  17062799905.361982
62  -103086.067677  5.020675  16757646691.556469
63  1668663.191357  5.003989  16567170422.718788
64  -1576616.177286  5.019755  16910911600.235109
65  3111979.903183  4.988635  16896192531.313564
66  -2775233.057141  5.016387  17197712624.060017
67  1696894.683016  4.999418  17142247482.271784
68  -1398863.825410  5.013407  16333471931.507477
69  -154469.217082  5.014952  16741721351.330023
70  2375204.246554  4.991200  16994651875.705441
71  -5158691.501060  5.042787  16759746579.809208
72  6086703.858385  4.981920  16322327410.619946
73  -3459806.454003  5.016518  16659598188.081657
74  1964286.835723  4.996875  16212456225.955757
75  -357446.816207  5.000449  17733536966.081696
76  1790140.424686  4.982548  16071199077.246952
77  -1204884.438560  4.994597  15682580271.927826
78  -210391.871228  4.996701  16246474846.031092
79  -1256845.139897  5.009269  16440315580.191532
80  2110696.705108  4.988162  17011888013.782473
81  -694178.752243  4.995104  16644089883.128731
82  -992069.854190  5.005025  17046681863.138529
83  1050853.414404  4.994516  17073817857.214806
84  -3888903.183147  5.033405  16171534776.551075
85  1520705.343209  5.018198  16436875242.582214
86  5433458.575271  4.963863  15924583704.911821
87  -5800593.792025  5.021869  16578294531.967175
88  747946.728740  5.014390  17177345981.141333
89  911307.488206  5.005277  16751266629.503719
90  54355.029846  5.004733  16184136725.305483
91  1192919.906975  4.992804  16236613497.467697
92  -1598501.323690  5.008789  17035745922.302227
93  3704272.058086  4.971746  16997214005.110441
94  -3888233.309747  5.010629  16175567029.206501
95  798884.767429  5.002640  17063181204.524981
96  1857413.849202  4.984066  16844997621.761753
97  -1576242.587484  4.999828  16462429502.875530
98  3408873.056972  4.965739  17136891861.577848
99  -3631328.551898  5.002053  16749053209.808189
100  -1134856.792916  5.013401  16068593487.097910

Result Plots

Trend of theta over the iterations:
[Figure: theta_list]

Trend of loss over the iterations:
[Figure: loss_list]
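The two curves can be reproduced from the recorded iter_list, theta_list, and loss_list; a minimal matplotlib sketch (not the original plotting code):

plt.plot(iter_list, theta_list)
plt.xlabel("iteration")
plt.ylabel("theta")
plt.show()

plt.plot(iter_list, loss_list)
plt.xlabel("iteration")
plt.ylabel("loss")
plt.show()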

Result Analysis

From the table and the figures we can see that the model still converges fairly quickly, but there is some jitter around the converged value. Since every iteration draws a different random batch, this jitter is unavoidable.
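One simple way to read off a stable estimate despite this jitter (my addition, not discussed in the original post) is to average theta over the later iterations, after the trajectory has settled:

# average theta over the last 50 iterations; the window size is an arbitrary choice
theta_avg = np.mean(theta_list[-50:])
print("averaged theta: %f" % theta_avg)  # should land close to the true slope of 5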
