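In "TensorFlow Study Notes (1): Handwritten Digit Recognition with Softmax Regression",
I used softmax regression to recognize the handwritten digits of the MNIST dataset; the best result on the MNIST test set on my machine was 92.9%.
In this section we start from the code of that post and change only a few lines, turning softmax regression into a multilayer perceptron with one hidden layer.
Then we look at what accuracy it reaches on the MNIST test set.
I add one hidden layer with 500 neurons.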
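To get a feel for how much larger the model becomes, the parameter counts can be worked out from the layer shapes alone. The sketch below is only an illustration; the helper name `dense_params` is hypothetical and not part of the original code.

```python
# Parameter counts implied by the layer shapes used in this post.
# dense_params is a hypothetical helper introduced for illustration.
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Softmax regression: a single 784 -> 10 layer.
softmax_params = dense_params(784, 10)

# MLP: 784 -> 500 hidden layer, then 500 -> 10 output layer.
mlp_params = dense_params(784, 500) + dense_params(500, 10)

print("softmax regression:", softmax_params)   # 7850
print("one-hidden-layer MLP:", mlp_params)     # 397510
```

The hidden layer accounts for almost all of the extra parameters, since the 784 x 500 weight matrix alone has 392,000 entries.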
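Starting from the code in "TensorFlow Study Notes (1): Handwritten Digit Recognition with Softmax Regression", make the following change.
Code before the change:
# Define variables w and b
w = tf.Variable ( tf.zeros ( [ 784, 10 ] ) )
b = tf.Variable ( tf.zeros ( [ 10 ] ) )
# Implement the softmax model
y = tf.nn.softmax ( tf.matmul ( x, w ) + b )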
Code after the change:
# Define variables w1, b1, w2 and b2
w1 = tf.Variable ( tf.truncated_normal ( [ 784, 500 ] , stddev = 0.1 ) )
b1 = tf.Variable ( tf.zeros ( [ 500 ] ) )
w2 = tf.Variable ( tf.truncated_normal ( [ 500, 10 ], stddev = 0.1 ) )
b2 = tf.Variable ( tf.zeros ( [ 10 ] ) )
# Multilayer perceptron with one hidden layer
h1 = tf.nn.relu ( tf.matmul ( x, w1 ) + b1 )
y = tf.nn.softmax ( tf.matmul ( h1, w2 ) + b2 )
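The forward pass of this network can be traced by hand on a toy example. The sketch below uses plain Python and tiny layer sizes (4, 3, 2) standing in for 784, 500 and 10; the helper names are hypothetical, and `random.gauss` draws from an ordinary (untruncated) normal distribution, so it only approximates what `tf.truncated_normal` does.

```python
import math
import random

def dense(x, w, b):
    """Compute x @ w + b for one input vector x (w has len(x) rows)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j]
            for j in range(len(b))]

def relu(v):
    """Element-wise max(0, a)."""
    return [max(0.0, a) for a in v]

def softmax(v):
    """Softmax with the maximum subtracted first to avoid overflow."""
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]

def random_matrix(rows, cols, stddev, rng):
    """Small random weights (assumption: plain Gaussian, not truncated)."""
    return [[rng.gauss(0.0, stddev) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
n_in, n_hidden, n_out = 4, 3, 2   # toy sizes, not the real 784/500/10

w1 = random_matrix(n_in, n_hidden, 0.1, rng)
b1 = [0.0] * n_hidden
w2 = random_matrix(n_hidden, n_out, 0.1, rng)
b2 = [0.0] * n_out

x = [0.0, 0.5, 1.0, 0.25]          # one made-up input vector
h1 = relu(dense(x, w1, b1))        # hidden activations, all >= 0
y = softmax(dense(h1, w2, b2))     # class probabilities, sum to 1

print("hidden:", h1)
print("output:", y)
```

The only structural difference from softmax regression is the extra `dense` plus `relu` step between the input and the output layer.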
Then run the whole network. On my computer the accuracy is 98.14%; the output is as follows:
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
accuracy: 0.9814
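The accuracy figure is the fraction of test images whose predicted class (the arg max of the network output) matches the label; the MNIST test set holds 10,000 images, so 0.9814 corresponds to 9,814 correct predictions. A minimal sketch of that computation on made-up data, with hypothetical helper names, might look like this:

```python
def argmax(values):
    """Index of the largest entry (first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(predictions, labels):
    """Fraction of rows where predicted class equals the one-hot label."""
    hits = [argmax(p) == argmax(t) for p, t in zip(predictions, labels)]
    return sum(hits) / len(hits)

# Four made-up softmax outputs over three classes, with one-hot labels.
predictions = [
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.3, 0.3, 0.4],   # predicted class 2, label says class 1: a miss
    [0.6, 0.3, 0.1],
]
labels = [
    [0, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
    [1, 0, 0],
]

toy_accuracy = accuracy(predictions, labels)
print("accuracy:", toy_accuracy)   # 0.75
```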
The full code (written against the TensorFlow 1.x API) is as follows:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets ( "MNIST_data/", one_hot = True )
x = tf.placeholder ( tf.float32, [ None, 784 ] )
y_ = tf.placeholder ( tf.float32, [ None, 10 ] )
w1 = tf.Variable ( tf.truncated_normal ( [ 784, 500 ] , stddev = 0.1 ) )
b1 = tf.Variable ( tf.zeros ( [ 500 ] ) )
w2 = tf.Variable ( tf.truncated_normal ( [ 500, 10 ] , stddev = 0.1 ) )
b2 = tf.Variable ( tf.zeros ( [ 10 ] ) )
h1 = tf.nn.relu ( tf.matmul ( x, w1 ) + b1 )
y = tf.nn.softmax ( tf.matmul ( h1, w2 ) + b2 )
cross_entropy = tf.reduce_mean ( -tf.reduce_sum ( y_ * tf.log ( y ), 1 ) )
train_step = tf.train.GradientDescentOptimizer ( 0.03 ) .minimize ( cross_entropy )
correct_prediction = tf.equal ( tf.argmax ( y, 1 ) ,tf.argmax ( y_, 1 ) )
accuracy = tf.reduce_mean ( tf.cast ( correct_prediction, "float" ) )
init = tf.global_variables_initializer ( )
with tf.Session() as sess:
    sess.run ( init )
    for i in range ( 200000 ) :
        batch_xs, batch_ys = mnist.train.next_batch ( 100 )
        result = sess.run ( [ accuracy, train_step ], feed_dict = { x:batch_xs, y_:batch_ys } )
        if ( i % 100 == 0 ):
            print ( "accuracy:", sess.run ( accuracy, feed_dict = { x:mnist.test.images, y_:mnist.test.labels } )," step:", i )
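One thing worth knowing about the loss in the code above: it takes the logarithm of the softmax output, and if a probability underflows to exactly 0 the logarithm is no longer finite, which can turn the loss into NaN. A common alternative is to compute the log-softmax directly from the logits with the log-sum-exp trick (TensorFlow 1.x provides `tf.nn.softmax_cross_entropy_with_logits` for this purpose). The plain-Python sketch below only illustrates the idea on made-up logits; it is not the code that produced the results in this post.

```python
import math

def softmax(logits):
    """Softmax with the maximum subtracted first to avoid overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits):
    """log(softmax(logits)) computed via log-sum-exp, never taking log(0)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def cross_entropy(logits, one_hot):
    """Cross-entropy of one example, computed from logits."""
    return -sum(t * lp for t, lp in zip(one_hot, log_softmax(logits)))

# Made-up, deliberately extreme logits: the true class (index 1) is
# assigned a probability so small that it underflows to 0.0.
logits = [1000.0, 0.0]
one_hot = [0, 1]

probs = softmax(logits)
print("probabilities:", probs)   # second entry is exactly 0.0
# math.log(probs[1]) would raise ValueError here, so the
# "softmax first, log second" route breaks down for this input.

stable_loss = cross_entropy(logits, one_hot)
print("stable loss:", stable_loss)   # 1000.0
```

For ordinary logits both routes give the same value; the difference only shows up in extreme cases like this one.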