上一章:深度篇—— CNN 卷积神经网络(三) 关于 ROI pooling 和 ROI Align 与 插值
本小节,细说 使用 tf cnn 进行 mnist 手写数字 代码演示项目
本项目 github 代码:https://github.com/wandaoyi/tf_cnn_mnist_pro
五. TF_CNN_MNIST 手写数字代码演示
(1). 前言
之前在 深度篇——神经网络 我们学了 ANN 和 DNN,现在,我们又学了 CNN,对于学以致用来说,我们将使用 CNN 来构建 卷积神经网络。
(2). 明确需求
项目需求,就是想把 手写 的 0 ~ 9 的阿拉伯数字图片识别出来。比方说,发票上面的数字(第一个做手写数字识别的,是1989年,美国的一家银行 聘请大佬写的,当时是用卷积神经网络技术 LeNet-5 写的。前面 深度篇——神经网络(七) 细说 DNN神经网络手写数字代码演示 已经使用 DNN 案例讲解了;现在,我们就使用 LeNet-5 做一个案例讲解。当时这个项目是用来识别支票上面签约的数字 )。训练网络,当然离不开数据,所以,我们先下载数据,数据为已经为大家上传到百度云盘:链接:https://pan.baidu.com/s/13OokGc0h3F5rGrxuSLYj9Q 提取码:qfj6 。
(3). 构建项目
项目结构如下:
上面的模型,是我随意训练 10 个 epoch 得到的精度:0.984000。之前我们使用 DNN 的时候,10 个 epoch 才 0.96+ 的精度。这样,我们看到,精度上升了 2 个百分点,这样说,也许有人会不以为意,觉得并不怎么样。但是,如果我们反过来看,可以看成,错误率减少了一半,这样,效果就会非常可观了。在公司做项目的时候,会做事很重要,但是,语言表达能力,也很重要。
(4). 环境依赖
环境依赖:
pip install numpy==1.16
pip install easydict
conda install tensorflow-gpu==1.13.1 # 建议不要用 2.0 版本的 tf,坑多
tensorflow 的安装,我前面的博客有详细解说:碎点篇——tensorflow gpu 版本安装 如果不会安装的,可以查看如何安装。
README.md 文件:
# tf_cnn_mnist_pro
tf_cnn 手写数字预测 2020-02-09
- 项目下载地址:https://github.com/wandaoyi/tf_cnn_mnist_pro
- 请到百度云盘下载项目所需要的训练数据:
- 链接:https://pan.baidu.com/s/13OokGc0h3F5rGrxuSLYj9Q 提取码:qfj6
## 参数设置
- 在训练或预测之前,我们要先进行参数设置
- 打开 config.py 文件,对其中的参数或路径进行设置。
## 模型
- 模型代码 model_net.py
- 在这里,使用了 lenet-5 网络模型来提取特征
## 训练模型
- 运行 cnn_mnist_train.py ,简单操作,右键直接 run
- 训练效果如下:
- acc_train: 1.0
- epoch: 10, acc_test: 0.984000
- 下面是随意训练的效果,如果想效果好,可以多训练多点epoch
- 也可以自己添加 early-stopping 进去,不麻烦的
## 预测
- 运行 cnn_mnist_test.py ,简单操作,右键直接 run
- 运行后,部分预测结果会打印在控制台上
- 预测效果如下:
- 预测值: [7 2 1 0 4]
- 真实值: [7 2 1 0 4]
## tensorboard 日志
- 使用 tensorboard 的好处是,这个日志是实时的,可以一边训练一边看效果图。
- 在 cmd 命令窗口,输入下面命令:
- tensorboard --logdir=G:\work_space\python_space\pro2018_space\wandao\mnist_pro\logs\mnist_log_train --host=localhost
- 在 --logdir= 后面是日志的文件夹路径,
- 在 --host= 是用来指定 ip 的,如果不写,则只能电脑的地址,而不能使用 localhost
- 在 谷歌浏览器 上打开 tensorboard 日志: http://localhost:6006/
- 模型 acc
![image](./docs/images/acc.png)
- 模型结构
![image](./docs/images/graphs.png)
下面的文件或代码,里面,都有注释
(5). 配置文件 config.py
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
# ============================================
# @Time : 2020/02/08 19:23
# @Author : WanDaoYi
# @FileName : config.py
# ============================================
from easydict import EasyDict as edict
import os
__C = edict()
cfg = __C
# common options 公共配置文件
__C.COMMON = edict()
# windows 获取文件绝对路径, 方便 windows 在黑窗口 运行项目
__C.COMMON.BASE_PATH = os.path.abspath(os.path.dirname(__file__))
# # 获取当前窗口的路径, 当用 Linux 的时候切用这个,不然会报错。(windows也可以用这个)
# __C.COMMON.BASE_PATH = os.getcwd()
__C.COMMON.DATA_PATH = os.path.join(__C.COMMON.BASE_PATH, "dataset")
# 图像的形状
__C.COMMON.DATA_RESHAPE = [-1, 28, 28, 1]
# 图像 rezise 的形状
__C.COMMON.DATA_RESIZE = (32, 32)
# 训练配置
__C.TRAIN =