深度学习课程 DAY 6 - 图像分类问题:手写数字识别案例(五)
Chapter 3 图像分类问题
3.6 模型优化之优化算法
(1)概述
上一节明确了分类任务的损失函数(优化目标)的相关概念和实现方法,本节我们依旧横向展开"横纵式"教学法,如图所示,本节主要探讨在手写数字识别任务中,使得损失达到最小的参数取值的实现方法。
前提条件:
在优化算法之前,需要进行数据处理、设计神经网络结构,代码与上一节保持一致。
(2)设置学习率
在深度学习神经网络模型中,通常使用标准的随机梯度下降算法更新参数,学习率代表参数更新幅度的大小,即步长。当学习率最优时,模型的有效容量最大,最终能达到的效果最好。学习率和深度学习任务类型有关,合适的学习率往往需要大量的实验和调参经验。探索学习率最优值时需要注意如下两点:
- 学习率不是越小越好。学习率越小,损失函数的变化速度越慢,意味着我们需要花费更长的时间进行收敛,如 图2 左图所示。
- 学习率不是越大越好。只根据总样本集中的一个批次计算梯度,抽样误差会导致计算出的梯度不是全局最优的方向,且存在波动。在接近最优解时,过大的学习率会导致参数在最优解附近震荡,损失难以收敛,如图所示。
在训练前,我们往往不清楚一个特定问题设置成怎样的学习率是合理的,因此在训练时可以尝试调小或调大,通过观察Loss下降的情况判断合理的学习率,设置学习率的代码如下所示。
#仅优化算法的设置有所差别
with fluid.dygraph.guard():
model = MNIST()
model.train()
#调用加载数据的函数
train_loader = load_data('train')
#设置不同初始学习率,模型所有参数都通过优化器优化
optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.01, parameter_list=model.parameters())
# optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters())
# optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, parameter_list=model.parameters())
EPOCH_NUM = 5
for epoch_id in range(EPOCH_NUM):
for batch_id, data in enumerate(train_loader()):
#准备数据,变得更加简洁
image_data, label_data = data
image = fluid.dygraph.to_variable(image_data)
label = fluid.dygraph.to_variable(label_data)
#前向计算的过程
predict = model(image)
#计算损失,取一个批次样本损失的平均值
loss = fluid.layers.cross_entropy(predict, label)
avg_loss = fluid.layers.mean(loss)
#每训练了200批次的数据,打印下当前Loss的情况
if batch_id % 200 == 0:
print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))
#后向传播,更新参数的过程
avg_loss.backward()
optimizer.minimize(avg_loss)
model.clear_gradients()
#保存模型参数
fluid.save_dygraph(model.state_dict(), 'mnist')
当learning_rate=0.01的结果
loading mnist dataset from ./work/mnist.json.gz ......
epoch: 0, batch: 0, loss is: [2.5479224]
epoch: 0, batch: 200, loss is: [0.51583934]
epoch: 0, batch: 400, loss is: [0.3206303]
epoch: 1, batch: 0, loss is: [0.2574886]
epoch: 1, batch: 200, loss is: [0.33207777]
epoch: 1, batch: 400, loss is: [0.2047088]
epoch: 2, batch: 0, loss is: [0.10459759]
epoch: 2, batch: 200, loss is: [0.14488357]
epoch: 2, batch: 400, loss is: [0.16438884]
epoch: 3, batch: 0, loss is: [0.22483312]
epoch: 3, batch: 200, loss is: [0.11302722]
epoch: 3, batch: 400, loss is: [0.11553524]
epoch: 4, batch: 0, loss is: [0.12319741]
epoch: 4, batch: 200, loss is: [0.07355891]
epoch: 4, batch: 400, loss is: [0.08584274]
当learning_rate=0.001的结果
loading mnist dataset from ./work/mnist.json.gz ......
epoch: 0, batch: 0, loss is: [2.575643]
epoch: 0, batch: 200, loss is: [1.7954494]
epoch: 0, batch: 400, loss is: [1.3650863]
epoch: 1, batch: 0, loss is: [1.1866897]
epoch: