TensorFlow-3: 用 feed-forward neural network 识别数字

原创 2017年04月26日 10:38:47

今天继续看 TensorFlow Mechanics 101:
https://www.tensorflow.org/get_started/mnist/mechanics

完整版教程可以看中文版tutorial:
http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_tf.html

这一节讲了使用 MNIST 数据集训练并评估一个简易前馈神经网络(feed-forward neural network)

input,output 和前两节是一样的:即划分数据集并预测图片的 label

data_sets.train 55000个图像和标签(labels),作为主要训练集。
data_sets.validation    5000个图像和标签,用于迭代验证训练准确度。
data_sets.test  10000个图像和标签,用于最终测试训练准确度(trained accuracy)。

主要有两个代码:

mnist.py
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py

  • 构建一个全连接网络,由 2 个隐藏层,1 个 `softmax_linearv 输出构成
  • 定义损失函数,用 `cross entropyv
  • 定义训练时的优化器,用 GradientDescentOptimizer
  • 定义评价函数,用 tf.nn.in_top_k

fully_connected_feed.py
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/fully_connected_feed.py

  • placeholder_inputs 传入 batch size,得到 image 和 label 的两个placeholder
  • 定义生成 feed_dict 的函数,key 是 placeholders,value 是 data
  • 定义 do_eval 函数,每隔 1000 个训练步骤,就对模型进行以下评估,分别作用于训练集、验证集和测试集
  • 训练时:
    • 导入数据
    • 得到 image 和 label 两个 placeholder
    • 传入 mnist.inference 定义的 NN, 得到 predictions
    • 将 predictions 传给 mnist.loss 计算 loss
    • loss 传给 mnist.training 进行优化训练
    • 再用 mnist.evaluation 评价预测值和实际值

代码中涉及到下面几个函数:

with tf.Graph().as_default():
即所有已经构建的操作都要与默认的 tf.Graph 全局实例关联起来,tf.Graph 实例是一系列可以作为整体执行的操作

summary = tf.summary.merge_all():
为了释放 TensorBoard 所使用的 events file,所有的即时数据都要在图表构建时合并至一个操作 op 中,每次运行 summary 时,都会向 events file 中写入最新的即时数据

summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph):
用于写入包含了图表本身和即时数据具体值的 events file。

saver = tf.train.Saver():
就是向训练文件夹中写入包含了当前所有可训练变量值 checkpoint file

with tf.name_scope('hidden1'):
主要用于管理一个图里面的各种 op,返回的是一个以 scope_name 命名的 context manager,一个 graph 会维护一个 name_space 的堆,实现一种层次化的管理,避免各个 op 之间命名冲突。例如,如果额外使用 tf.get_variable() 定义的变量是不会被 tf.name_scope() 当中的名字所影响的

tf.nn.in_top_k(logits, labels, 1):
意思是在 K 个最有可能的预测中如果可以发现 true,就将输出标记为 correct。本文 K 为 1,也就是只有在预测是 true 时,才判定它是 correct。


推荐阅读
历史技术博文链接汇总
也许可以找到你想要的

版权声明:本文为博主原创文章,未经博主允许不得转载。

一步一步分析讲解神经网络基础-Feedforward Neural Network

A feedforward neural network is an artificial neural network wherein connections between the units d...
  • jk981811667
  • jk981811667
  • 2017年12月25日 13:39
  • 219

Feedforward Neural Network Language Model(NNLM)原理及数学推导

本文来自CSDN博客,转载请注明出处:http://blog.csdn.net/a635661820/article/details/44130285         这一篇是Bengio大牛用神经网...
  • a635661820
  • a635661820
  • 2015年03月08日 06:32
  • 9730

DeepLearning--Part2--Chapter6:Feedforward-Deep-Networks(1)

Part 2 : Deep Networks: Modern Practices本书的这部分内容主要介绍一些已经有实际应用的深度学习方法。深度学习拥有很长的历史,也有宏大的愿景。一些深度学习方法尚未成...
  • meanme
  • meanme
  • 2016年03月06日 15:57
  • 1254

Feedforward Deep Networks(要点)

Feedforward Deep Networks(要点)
  • u011762313
  • u011762313
  • 2015年10月16日 20:25
  • 1418

Matlab 简单使用 Neural Network Toolbox 的 GUI 之 nnstart

Matlab中的Neural Network Toolbox我也是第一次使用,之前在coursera上上Machine Learning 的课完全是写代码,并没有使用封装好的库。 在命令行窗口中写入...
  • u013429988
  • u013429988
  • 2017年12月10日 11:31
  • 266

《Understanding the difficulty of training deep feedforward neural networks》笔记

Sigmod为什么不适合深度学习,交叉熵代价函数和平方差代价函数,神经网络权重初始化...
  • KangRoger
  • KangRoger
  • 2017年03月11日 11:13
  • 1239

Paper Note - Learning to Hash with Binary Deep Neural Network

本文来自ECCV2016,这里主要记录一下自己读完论文的收获。原文链接:Learning to Hash with Binary Deep Neural Network - ECCV2016 自制p...
  • lcx543576178
  • lcx543576178
  • 2016年12月09日 21:23
  • 479

卷积神经网络Convolutional Neural Network (CNN)

卷积神经网络 转载请注明:http://blog.csdn.net/stdcoutzyx/article/details/41596663 自今年七月份以来,一直在实验室负责卷积神经网络(Conv...
  • GarfieldEr007
  • GarfieldEr007
  • 2016年03月31日 12:56
  • 3724

recursive neural network梳理(CS224D lecture9)

1.概念层级的讨论 把词语映射成向量,我们可以用距离远近来评价词语语义的相似度 。同样,对于短语、句子,也可以把它们映射到向量空间中来表示它们的语义。对于词级别的向量表示,常用的有分布式表示,布朗聚...
  • mengmengz07
  • mengmengz07
  • 2016年05月08日 23:29
  • 1141

人工神经网络(ANN, artificial neural network)

人工神经网络是由具有适应性的简单单元组成的广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所作出的交互反应。 人工神经网络研究的局限性: l  研究受到脑科学研究成果的限...
  • whycold
  • whycold
  • 2012年06月23日 15:13
  • 7409
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:TensorFlow-3: 用 feed-forward neural network 识别数字
举报原因:
原因补充:

(最多只允许输入30个字)