上一章:深度篇——神经网络(六) 细说 数据增强与fine-tuning
本小节,细说 神经网络手写数字代码演示
github 上项目下载:mnist_pro 项目
6. 代码项目演示
(1). 前言
虽然前面我们学了很多 深度神经网络的理论,大概,也知道,是那样训练和测试,但是,项目要怎么做呢?估计,还是有不少人一脸懵逼的。懵逼于如何将理论转成代码,去构建项目。下面,我将用一个简单的 手写数字项目,去为大家简单说明一下。
(2). 首先,要明确项目需求
项目需求,就是想把 手写 的 0 ~ 9 的阿拉伯数字图片识别出来。比方说,发票上面的数字(第一个做手写数字识别的,是1989年,美国的一家银行 聘请大佬写的,当时是用卷积神经网络技术 LeNet-5 写的,它的识别率比深度神经DNN 要好些。后面,会为大家讲解卷积神经网络的。当时这个项目是用来识别支票上面签约的数字 )。训练网络,当然离不开数据,所以,我们先下载数据,数据为已经为大家上传到百度云盘:链接:https://pan.baidu.com/s/13OokGc0h3F5rGrxuSLYj9Q 提取码:qfj6 。
(3). 构建项目
项目结构如下:
上面的模型,是我随意训练 10 个 epoch 得到的精度:0.9614。我之前跑100 个 epoch,精度上到 0.98 多。
(4). 依赖环境和 README.md
依赖环境:
pip install numpy==1.16
pip install easydict
conda install tensorflow-gpu==1.13.1 # 建议不要用 2.0 版本的 tf,坑多
tensorflow 的安装,我前面的博客有详细解说:碎点篇——tensorflow gpu 版本安装 如果不会安装的,可以查看如何安装。
README.md 文件
# mnist_pro
DNN 手写数字预测 2020-02-06
- 项目下载地址:https://github.com/wandaoyi/mnist_pro
- 请到百度云盘下载项目所需要的训练数据:
- 链接:https://pan.baidu.com/s/13OokGc0h3F5rGrxuSLYj9Q 提取码:qfj6
## 参数设置
- 在训练或预测之前,我们要先进行参数设置
- 打开 config.py 文件,对其中的参数或路径进行设置。
## 训练模型
- 运行 mnist_train.py ,简单操作,右键直接 run
- 训练效果如下:
- acc_train: 0.90625
- y_perd: [7 2 1 0 4]
- y_true: [7 2 1 0 4]
- epoch: 10, acc_test: 0.9613999724388123
- epoch: 10, acc_test_2: 0.9606000185012817
- 下面是随意训练的效果,如果想效果好,可以多训练多点epoch
- 也可以自己添加 early-stopping 进去,不麻烦的
## 预测
- 运行 mnist_test.py ,简单操作,右键直接 run
- 运行后,部分预测结果会打印在控制台上
- 预测效果如下:
- 预测值: [7 2 1 0 4]
- 真实值: [7 2 1 0 4]
## tensorboard 日志
- 使用 tensorboard 的好处是,这个日志是实时的,可以一边训练一边看效果图。
- 在 cmd 命令窗口,输入下面命令:
- tensorboard --logdir=G:\work_space\python_space\pro2018_space\wandao\mnist_pro\logs\mnist_log_train --host=localhost
![image](./docs/images/open_tensorboard.png)
- 在 --logdir= 后面是日志的文件夹路径,
- 在 --host= 是用来指定 ip 的,如果不写,则只能电脑的地址,而不能使用 localhost
- 在 谷歌浏览器 上打开 tensorboard 日志: http://localhost:6006/
![image](./docs/images/tensorboard_acc.png)
-
![image](./docs/images/tensorboard_image.png)
-
![image](./docs/images/tensorboard_graph.png)
-
![image](./docs/images/tensorboard_param.png)
-
![image](./docs/images/tensorboard_histograms.png)
-
- 测试日志也是这样打开来看,就不详细去说了。
-
- 关于其他的 ROC 曲线 或 mAP 等,这里就没做这些操作。以后的项目,再操作一番就是了。
下面的文件或代码,里面,都有注释
(5). 配置文件 config.py
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
# ============================================
# @Time : 2020/02/05 13:51
# @Author : WanDaoYi
# @FileName : config.py
# ============================================
from easydict import EasyDict as edict
import os
__C = edict()
cfg = __C
# common options 公共配置文件
__C.COMMON = edict()
# windows 获