2.5mnist手写数字识别之优化算法精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列)
目录
2.5mnist手写数字识别之优化算法精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列)
设置学习率
(学习率大小的选择严重影响着模型的效果,人为选择困难很大,目前已经有很好的优化学习率算法可以拿来直接用)
在深度学习神经网络模型中,通常使用标准的随机梯度下降算法更新参数,学习率代表参数更新幅度的大小,即步长。当学习率最优时,模型的有效容量最大,最终能达到的效果最好。学习率和深度学习任务类型有关,合适的学习率往往需要大量的实验和调参经验。探索学习率最优值时需要注意如下两点:
- 学习率不是越小越好。学习率越小,损失函数的变化速度越慢,意味着我们需要花费更长的时间进行收敛,如 图2 左图所示。
- 学习率不是越大越好。只根据总样本集中的一个批次计算梯度,抽样误差会导致计算出的梯度不是全局最优的方向,且存在波动。在接近最优解时,过大的学习率会导致参数在最优解附近震荡,损失难以收敛,如 图2 右图所示。
图2: 不同学习率(步长过小/过大)的示意图
在训练前,我们往往不清楚一个特定问题设置成怎样的学习率是合理的,因此在训练时可以尝试调小或调大,通过观察Loss下降的情况判断合理的学习率,设置学习率的代码如下所示。
(下面使用学习率(learning rate,lr)为0.01作为基线程序)
#仅优化算法的设置有所差别
with fluid.dygraph.guard():
model = MNIST()
model.train()
#调用加载数据的函数
train_loader = load_data('train')
#设置不同初始学习率
optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.01, parameter_list=model.parameters())
# optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters())
# optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, parameter_list=model.parameters())
EPOCH_NUM = 5
for epoch_id in range(EPOCH_NUM):
for batch_id, data in enumerate(train_loader()):
#准备数据,变得更加简洁
image_data, label_data = data
image = fluid.dygraph.to_variable(image_data)
label = fluid.dygraph.to_variable(label_data)
#前向计算的过程
predict = model(image)
#计算损失,取一个批次样本损失的平均值
loss = fluid.layers.cross_entropy(predict, label)
avg_loss = fluid.layers.mean(loss)
#每训练了200批次的数据,打印下当前Loss的情况
if batch_id % 200 == 0:
print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))
#后向传播,更新参数的过程
avg_loss.backward()
optimizer.minimize(avg_loss)
model.clear_gradients()
#保存模型参数
fluid.save_dygraph(model.state_dict(), 'mnist')
loading mnist dataset from ./work/mnist.json.gz ...... epoch: 0, batch: 0, loss is: [2.565506] epoch: 0, batch: 200, loss is: [0.51754624] epoch: 0, batch: 400, loss is: [0.28635043] epoch: 1, batch: 0, loss is: [0.20556127] epoch: 1, batch: 200, loss is: [0.21960375] epoch: 1, batch: 400, loss is: [0.2661132] epoch: 2, batch: 0, loss is: [0.14973752] epoch: 2, batch: 200, loss is: [0.1251861