深度篇——神经网络(七) 细说 DNN神经网络手写数字代码演示

本文通过一个手写数字识别项目,详细解析如何将深度神经网络理论转化为实际代码。项目需求是识别0-9的阿拉伯数字图片,使用的数据集已提供下载链接。文章介绍了项目结构、依赖环境配置、训练与测试代码,并提到了利用Tensorboard进行日志可视化的方法。
摘要由CSDN通过智能技术生成

返回主目录

返回神经网络目录

上一章:深度篇——神经网络(六)  细说 数据增强与fine-tuning

 

本小节,细说 神经网络手写数字代码演示

 

github 上项目下载:mnist_pro 项目

 

6. 代码项目演示

 

(1). 前言

虽然前面我们学了很多 深度神经网络的理论,大概,也知道,是那样训练和测试,但是,项目要怎么做呢?估计,还是有不少人一脸懵逼的。懵逼于如何将理论转成代码,去构建项目。下面,我将用一个简单的 手写数字项目,去为大家简单说明一下。

(2). 首先,要明确项目需求

项目需求,就是想把 手写 的 0 ~ 9 的阿拉伯数字图片识别出来。比方说,发票上面的数字(第一个做手写数字识别的,是1989年,美国的一家银行 聘请大佬写的,当时是用卷积神经网络技术 LeNet-5 写的,它的识别率比深度神经DNN 要好些。后面,会为大家讲解卷积神经网络的。当时这个项目是用来识别支票上面签约的数字 )。训练网络,当然离不开数据,所以,我们先下载数据,数据为已经为大家上传到百度云盘:链接:https://pan.baidu.com/s/13OokGc0h3F5rGrxuSLYj9Q   提取码:qfj6  。
 

(3). 构建项目

项目结构如下:

上面的模型,是我随意训练 10 个 epoch 得到的精度:0.9614。我之前跑100 个 epoch,精度上到 0.98 多。

 

(4). 依赖环境和 README.md

依赖环境:

pip install numpy==1.16
pip install easydict
conda install tensorflow-gpu==1.13.1 # 建议不要用 2.0 版本的 tf,坑多

tensorflow 的安装,我前面的博客有详细解说:碎点篇——tensorflow gpu 版本安装  如果不会安装的,可以查看如何安装。

 

README.md 文件

# mnist_pro
DNN 手写数字预测 2020-02-06
- 项目下载地址:https://github.com/wandaoyi/mnist_pro
- 请到百度云盘下载项目所需要的训练数据:
- 链接:https://pan.baidu.com/s/13OokGc0h3F5rGrxuSLYj9Q   提取码:qfj6 

## 参数设置
- 在训练或预测之前,我们要先进行参数设置
- 打开 config.py 文件,对其中的参数或路径进行设置。

## 训练模型
- 运行 mnist_train.py ,简单操作,右键直接 run
- 训练效果如下:
- acc_train: 0.90625
- y_perd: [7 2 1 0 4]
- y_true: [7 2 1 0 4]
- epoch: 10, acc_test: 0.9613999724388123
- epoch: 10, acc_test_2: 0.9606000185012817
- 下面是随意训练的效果,如果想效果好,可以多训练多点epoch
- 也可以自己添加 early-stopping 进去,不麻烦的

## 预测
- 运行 mnist_test.py ,简单操作,右键直接 run
- 运行后,部分预测结果会打印在控制台上
- 预测效果如下:
- 预测值: [7 2 1 0 4]
- 真实值: [7 2 1 0 4]

## tensorboard 日志
- 使用 tensorboard 的好处是,这个日志是实时的,可以一边训练一边看效果图。
- 在 cmd 命令窗口,输入下面命令:
- tensorboard --logdir=G:\work_space\python_space\pro2018_space\wandao\mnist_pro\logs\mnist_log_train --host=localhost
![image](./docs/images/open_tensorboard.png)
- 在 --logdir= 后面是日志的文件夹路径,
- 在 --host= 是用来指定 ip 的,如果不写,则只能电脑的地址,而不能使用 localhost
- 在 谷歌浏览器 上打开 tensorboard 日志: http://localhost:6006/
![image](./docs/images/tensorboard_acc.png)
- 
![image](./docs/images/tensorboard_image.png)
- 
![image](./docs/images/tensorboard_graph.png)
- 
![image](./docs/images/tensorboard_param.png)
- 
![image](./docs/images/tensorboard_histograms.png)
- 
- 测试日志也是这样打开来看,就不详细去说了。
- 
- 关于其他的 ROC 曲线 或 mAP 等,这里就没做这些操作。以后的项目,再操作一番就是了。

 

下面的文件或代码,里面,都有注释

(5). 配置文件 config.py

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
# ============================================
# @Time     : 2020/02/05 13:51
# @Author   : WanDaoYi
# @FileName : config.py
# ============================================

from easydict import EasyDict as edict
import os


__C = edict()

cfg = __C

# common options 公共配置文件
__C.COMMON = edict()
# windows 获
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值