1. Import libraries and the dataset:
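2. Model design:
The model is a single layer with no nonlinear transformation (no activation function), and it predicts the digit value of the input image. Its input is a 784-dimensional vector (a flattened 28×28 image), and its output is 1-dimensional.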
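To make the design concrete, a single layer without an activation reduces to one weighted sum plus a bias. The following is a framework-free sketch in plain Python, not the original code: the names `LinearModel` and `forward`, and the small random initialization, are assumptions made for illustration.

```python
import random

INPUT_DIM = 28 * 28  # 784 pixel values per flattened image


class LinearModel:
    """Single linear layer: 784 inputs -> 1 output, no activation."""

    def __init__(self, input_dim=INPUT_DIM, seed=0):
        rng = random.Random(seed)
        # Small random weights; this init scheme is an assumption.
        self.weights = [rng.uniform(-0.01, 0.01) for _ in range(input_dim)]
        self.bias = 0.0

    def forward(self, pixels):
        """Return the predicted digit value as a single float."""
        if len(pixels) != len(self.weights):
            raise ValueError(
                "expected %d values, got %d" % (len(self.weights), len(pixels))
            )
        return sum(w * x for w, x in zip(self.weights, pixels)) + self.bias


model = LinearModel()
blank_image = [0.0] * INPUT_DIM
prediction = model.forward(blank_image)  # an all-zero input yields just the bias
```

Because the output is a single real number rather than ten class scores, this setup treats digit recognition as regression, which is why the prediction is a float and not a class label.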
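3. Training configuration:
Training configuration first creates a model instance (set to "train" mode), then sets the optimization algorithm and learning rate (stochastic gradient descent, SGD, with a learning rate of 0.001).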
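For reference, the SGD update rule itself is short: each parameter moves against its gradient, scaled by the learning rate. The sketch below is plain Python and only illustrates the rule; the `SGD` class and its `step` method are hypothetical names, not a framework API, and the "train" mode switch is not modeled here.

```python
LEARNING_RATE = 0.001  # value stated in the text


class SGD:
    """Plain stochastic gradient descent: param <- param - lr * grad."""

    def __init__(self, learning_rate=LEARNING_RATE):
        self.learning_rate = learning_rate

    def step(self, params, grads):
        """Return updated parameters; both inputs are equal-length lists of floats."""
        if len(params) != len(grads):
            raise ValueError("params and grads must have the same length")
        return [p - self.learning_rate * g for p, g in zip(params, grads)]


optimizer = SGD()
# A gradient of 10.0 moves a parameter by 0.001 * 10.0 = 0.01.
updated = optimizer.step([1.0, -2.0], [10.0, -10.0])
```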
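4. Training process:
The training process uses a two-level nested loop:
- Inner loop: performs one pass over the entire dataset, iterating in batches.
- Outer loop: sets the number of passes over the dataset; this run uses 10 passes, set by the parameter EPOCH_NUM.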
epoch_id: 0, batch_id: 0, loss is: [37.40072]
epoch_id: 0, batch_id: 1000, loss is: [6.511747]
epoch_id: 0, batch_id: 2000, loss is: [2.732864]
epoch_id: 0, batch_id: 3000, loss is: [3.4710536]
epoch_id: 1, batch_id: 0, loss is: [3.096552]
epoch_id: 1, batch_id: 1000, loss is: [2.7121267]
epoch_id: 1, batch_id: 2000, loss is: [3.9004908]
epoch_id: 1, batch_id: 3000, loss is: [5.8262606]
epoch_id: 2, batch_id: 0, loss is: [4.3189907]
epoch_id: 2, batch_id: 1000, loss is: [1.8050106]
epoch_id: 2, batch_id: 2000, loss is: [5.5903983]
epoch_id: 2, batch_id: 3000, loss is: [5.1830134]
epoch_id: 3, batch_id: 0, loss is: [2.2760863]
epoch_id: 3, batch_id: 1000, loss is: [5.3123417]
epoch_id: 3, batch_id: 2000, loss is: [3.207981]
epoch_id: 3, batch_id: 3000, loss is: [3.355655]
epoch_id: 4, batch_id: 0, loss is: [2.409387]
epoch_id: 4, batch_id: 1000, loss is: [4.8867397]
epoch_id: 4, batch_id: 2000, loss is: [1.474983]
epoch_id: 4, batch_id: 3000, loss is: [4.4998617]
epoch_id: 5, batch_id: 0, loss is: [2.2124193]
epoch_id: 5, batch_id: 1000, loss is: [5.57606]
epoch_id: 5, batch_id: 2000, loss is: [3.913291]
epoch_id: 5, batch_id: 3000, loss is: [2.5047076]
epoch_id: 6, batch_id: 0, loss is: [3.2275934]
epoch_id: 6, batch_id: 1000, loss is: [6.0397263]
epoch_id: 6, batch_id: 2000, loss is: [4.522304]
epoch_id: 6, batch_id: 3000, loss is: [5.8284745]
epoch_id: 7, batch_id: 0, loss is: [2.982819]
epoch_id: 7, batch_id: 1000, loss is: [2.6216135]
epoch_id: 7, batch_id: 2000, loss is: [6.423638]
epoch_id: 7, batch_id: 3000, loss is: [1.9560221]
epoch_id: 8, batch_id: 0, loss is: [2.6066465]
epoch_id: 8, batch_id: 1000, loss is: [4.659786]
epoch_id: 8, batch_id: 2000, loss is: [2.6956224]
epoch_id: 8, batch_id: 3000, loss is: [1.8081566]
epoch_id: 9, batch_id: 0, loss is: [1.3943908]
epoch_id: 9, batch_id: 1000, loss is: [6.2347713]
epoch_id: 9, batch_id: 2000, loss is: [3.3525836]
epoch_id: 9, batch_id: 3000, loss is: [2.8217497]
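The two-level loop structure described above can be sketched without any framework. This is an illustration under stated assumptions, not the original training code: a tiny synthetic regression set stands in for MNIST, the loss is assumed to be mean squared error, and `BATCH_SIZE`, `LOG_EVERY`, and the `batches` helper are hypothetical. Only `EPOCH_NUM = 10` and the learning rate of 0.001 come from the text, so the printed values will not match the log above.

```python
import random

EPOCH_NUM = 10         # outer loop count, as in the text
LEARNING_RATE = 0.001  # SGD learning rate, as in the text
BATCH_SIZE = 8         # assumed for this sketch
LOG_EVERY = 4          # assumed logging interval (in batches)

# Tiny synthetic regression set standing in for MNIST (assumption).
rng = random.Random(0)
features = [[rng.random() for _ in range(3)] for _ in range(64)]
targets = [2.0 * x[0] - x[1] + 0.5 * x[2] + 1.0 for x in features]


def batches(xs, ys, batch_size):
    """Yield (batch_id, x_batch, y_batch), covering the dataset exactly once."""
    for batch_id, start in enumerate(range(0, len(xs), batch_size)):
        yield batch_id, xs[start:start + batch_size], ys[start:start + batch_size]


weights, bias = [0.0, 0.0, 0.0], 0.0
epoch_losses = []

for epoch_id in range(EPOCH_NUM):  # outer loop: number of passes
    batch_losses = []
    for batch_id, xb, yb in batches(features, targets, BATCH_SIZE):  # inner loop: one pass
        # Forward pass: linear prediction minus target, per sample.
        errors = [sum(w * v for w, v in zip(weights, x)) + bias - y
                  for x, y in zip(xb, yb)]
        loss = sum(e * e for e in errors) / len(errors)  # mean squared error
        # Gradients of the mean squared error w.r.t. weights and bias.
        grad_w = [2.0 * sum(e * x[j] for e, x in zip(errors, xb)) / len(errors)
                  for j in range(len(weights))]
        grad_b = 2.0 * sum(errors) / len(errors)
        # SGD update.
        weights = [w - LEARNING_RATE * g for w, g in zip(weights, grad_w)]
        bias -= LEARNING_RATE * grad_b
        batch_losses.append(loss)
        if batch_id % LOG_EVERY == 0:
            print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch_id, batch_id, loss))
    epoch_losses.append(sum(batch_losses) / len(batch_losses))
```

Averaging the batch losses per epoch, as `epoch_losses` does here, gives a steadier signal than individual batch losses, which can jump around from one batch to the next.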
After finishing the code, I drew a 28×28 handwritten digit myself and ran it through the model: