本人是2016年9月份开始接触到TensorFlow,主要是通过TensorFlow网站上的教程来学习。记博客主要被当做自己的学习笔记,因为对深度学习及TensorFlow平台知之甚少,如果错误,请帮忙指出,笔者不甚感激。
MNIST是一个用于手写体识别的字符集。在TensorFlow教程中,第一个例程便讲解了如何通过TensorFlow来实现对该字符集的分类。由于该例子比较简单,因此,它也被称之为深度学习中的hello world程序。本文第1步是mnist实现的源码。2是对源码中的关键函数进行了解析。
mnist源码
#coding:utf-8
#第一句代码的作用是,使Python代码支持中文注释
#!/usr/bin/python2.7
#input_data是一个脚本,用于导入MNIST数据集,并返回一个mnist类
import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
sess.run(tf.initialize_all_variables())
y = tf.nn.softmax(tf.matmul(x,W) + b)
#cross_entropy是交叉熵,用于衡量预测值与真实值之间的差距,如果两者完全一致,其交叉熵为0
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
for i in range(1000):
batch = mnist.train.next_batch(50)
train_step.run(feed_dict={x: batch[0], y_: batch[1]})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
该段源码的运行结果如下图所示:
其检测率为0.9092。
关键函数解析
y = tf.nn.softmax(tf.matmul(x,W) + b)
该函数中有两个重点:1是通过该函数可知网络为单层神经网络。2是对输出层的分类使用softmax。
(1)tf.matmul(x,W)表示将样本x直接与参数W相乘的值作为特征值,且将该值直接输出。这说明该例程的神经网络中为单层神经网络,即输入层直接与输出层相连。单隐层神经网络的结构如下图所示:
(2)tf.nn.softmax()说明采用softmax来对特征值进行分类。softmax分类特别适合与多类别分类的情况。MNIST数据集中的分类任务是将每一个样本判断为1-10中的某个值,属于多类别分类的情况。因此,softmax适用于MNIST的分类。
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
GradientDescentOptimizer(0.01)说明该例程使用梯度下降算法通过最小化cross_entropy来进行参数W的学习,且梯度下降算法的学习率为0.01。