一、前言
本文将创建一个784x256x128的全连接网络,用来识别手写数字,不依赖与keras,手动求梯度更新参数,训练10个epoch后测试几张手写数字。
二、功能函数
1、设置 gpu 显存按需申请
def
2、设置图片预处理回调
def
3、生成数据集
def
4、前向计算
def
5、计算梯度并更新参数
def
6、计算验证准确率
def
7、训练样本
def
8、绘制训练结果
def
9、获取一张本地待识别图片
def
10、识别多张图片
def
三、训练主逻辑
"""训练主逻辑"""
四、训练结果
准确率一直在上升,最终准确率为0.9210662939297125,但是可以看出15000step后loss有点反弹
五、测试
# 测试
上图为OpenCV自带digits.png里切割出来的,大小为20x20
测试结果为:[1 2 4 7]