深度篇—— CNN 卷积神经网络(四) 使用 tf cnn 进行 mnist 手写数字 代码演示项目

返回主目录

返回 CNN 卷积神经网络目录

上一章:深度篇—— CNN 卷积神经网络(三) 关于 ROI pooling 和 ROI Align 与 插值

 

本小节,细说 使用 tf cnn 进行 mnist 手写数字 代码演示项目

 

本项目 github 代码:https://github.com/wandaoyi/tf_cnn_mnist_pro

 

五. TF_CNN_MNIST 手写数字代码演示

 

(1). 前言

之前在 深度篇——神经网络 我们学了 ANN 和 DNN,现在,我们又学了 CNN,对于学以致用来说,我们将使用 CNN 来构建 卷积神经网络。

 

(2). 明确需求

项目需求,就是想把 手写 的 0 ~ 9 的阿拉伯数字图片识别出来。比方说,发票上面的数字(第一个做手写数字识别的,是1989年,美国的一家银行 聘请大佬写的,当时是用卷积神经网络技术 LeNet-5 写的。前面 深度篇——神经网络(七) 细说 DNN神经网络手写数字代码演示 已经使用 DNN 案例讲解了;现在,我们就使用 LeNet-5 做一个案例讲解。当时这个项目是用来识别支票上面签约的数字 )。训练网络,当然离不开数据,所以,我们先下载数据,数据为已经为大家上传到百度云盘:链接:https://pan.baidu.com/s/13OokGc0h3F5rGrxuSLYj9Q   提取码:qfj6  。
 

 

(3). 构建项目

项目结构如下:

上面的模型,是我随意训练 10 个 epoch 得到的精度:0.984000。之前我们使用 DNN 的时候,10 个 epoch 才 0.96+ 的精度。这样,我们看到,精度上升了 2 个百分点,这样说,也许有人会不以为意,觉得并不怎么样。但是,如果我们反过来看,可以看成,错误率减少了一半,这样,效果就会非常可观了。在公司做项目的时候,会做事很重要,但是,语言表达能力,也很重要。

 

(4). 环境依赖

环境依赖:

pip install numpy==1.16
pip install easydict
conda install tensorflow-gpu==1.13.1 # 建议不要用 2.0 版本的 tf,坑多

tensorflow 的安装,我前面的博客有详细解说:碎点篇——tensorflow gpu 版本安装  如果不会安装的,可以查看如何安装。

README.md 文件:

# tf_cnn_mnist_pro
tf_cnn 手写数字预测 2020-02-09
- 项目下载地址:https://github.com/wandaoyi/tf_cnn_mnist_pro
- 请到百度云盘下载项目所需要的训练数据:
- 链接:https://pan.baidu.com/s/13OokGc0h3F5rGrxuSLYj9Q   提取码:qfj6 

## 参数设置
- 在训练或预测之前,我们要先进行参数设置
- 打开 config.py 文件,对其中的参数或路径进行设置。

## 模型
- 模型代码 model_net.py
- 在这里,使用了 lenet-5 网络模型来提取特征

## 训练模型
- 运行 cnn_mnist_train.py ,简单操作,右键直接 run
- 训练效果如下:
- acc_train: 1.0
- epoch: 10, acc_test: 0.984000
- 下面是随意训练的效果,如果想效果好,可以多训练多点epoch
- 也可以自己添加 early-stopping 进去,不麻烦的

## 预测
- 运行 cnn_mnist_test.py ,简单操作,右键直接 run
- 运行后,部分预测结果会打印在控制台上
- 预测效果如下:
- 预测值: [7 2 1 0 4]
- 真实值: [7 2 1 0 4]

## tensorboard 日志
- 使用 tensorboard 的好处是,这个日志是实时的,可以一边训练一边看效果图。
- 在 cmd 命令窗口,输入下面命令:
- tensorboard --logdir=G:\work_space\python_space\pro2018_space\wandao\mnist_pro\logs\mnist_log_train --host=localhost
- 在 --logdir= 后面是日志的文件夹路径,
- 在 --host= 是用来指定 ip 的,如果不写,则只能电脑的地址,而不能使用 localhost
- 在 谷歌浏览器 上打开 tensorboard 日志: http://localhost:6006/

- 模型 acc
![image](./docs/images/acc.png)
- 模型结构
![image](./docs/images/graphs.png)

 

下面的文件或代码,里面,都有注释

(5). 配置文件 config.py

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
# ============================================
# @Time     : 2020/02/08 19:23
# @Author   : WanDaoYi
# @FileName : config.py
# ============================================


from easydict import EasyDict as edict
import os


__C = edict()

cfg = __C

# common options 公共配置文件
__C.COMMON = edict()
# windows 获取文件绝对路径, 方便 windows 在黑窗口 运行项目
__C.COMMON.BASE_PATH = os.path.abspath(os.path.dirname(__file__))
# # 获取当前窗口的路径, 当用 Linux 的时候切用这个,不然会报错。(windows也可以用这个)
# __C.COMMON.BASE_PATH = os.getcwd()

__C.COMMON.DATA_PATH = os.path.join(__C.COMMON.BASE_PATH, "dataset")

# 图像的形状
__C.COMMON.DATA_RESHAPE = [-1, 28, 28, 1]
# 图像 rezise 的形状
__C.COMMON.DATA_RESIZE = (32, 32)


# 训练配置
__C.TRAIN =
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值