import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist
import tensorflow.contrib.slim as slim
# Load MNIST; labels are one-hot encoded (one_hot=True), matching y_ below.
mnist = input_data.read_data_sets('../share/MNIST_DATA', one_hot=True)

# Graph inputs: flattened 28*28 = 784-pixel images and 10-class one-hot labels.
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

# Restore the 2-D image layout (height, width, 1 channel) expected by conv2d.
x_image = tf.reshape(x, [-1, 28, 28, 1])  # [N, 28, 28, 1]

# Conv layer 1 followed by 2x2 max-pooling.
net = slim.conv2d(x_image, 32, [5, 5], scope='conv1')  # [N, 28, 28, 32]
net = slim.max_pool2d(net, [2, 2], scope='pool1')      # [N, 14, 14, 32]

# Conv layer 2 followed by 2x2 max-pooling.
net = slim.conv2d(net, 64, [5, 5], scope='conv2')      # [N, 14, 14, 64]
net = slim.max_pool2d(net, [2, 2], scope='pool2')      # [N, 7, 7, 64]

# Flatten the feature maps for the fully connected layers.
net = tf.reshape(net, [-1, 7 * 7 * 64])                # [N, 3136]

# Fully connected layer 1.
net = slim.fully_connected(net, 1024, scope='fc1')     # [N, 1024]

# Dropout: feed keep_prob < 1 while training and 1.0 when evaluating.
keep_prob = tf.placeholder('float')
net = tf.nn.dropout(net, keep_prob)

# Output layer producing raw class scores (logits). slim.fully_connected
# applies a ReLU by default, which would clamp every negative logit to 0
# before the softmax, so the activation is disabled explicitly.
net = slim.fully_connected(net, 10, activation_fn=None, scope='fc2')  # [N, 10]

# Class probabilities; used below only for prediction / accuracy.
y = tf.nn.softmax(net)  # [N, 10]

# Cross-entropy computed directly from the logits. Writing it naively as
# -sum(y_ * tf.log(y)) yields -inf/NaN as soon as any softmax probability
# underflows to 0; this op evaluates the same quantity stably. y_ is a
# placeholder, so no gradient has to flow into the labels. The per-example
# losses are averaged so the loss scale does not depend on the batch size.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=net))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# Fraction of examples whose highest-scoring class matches the label.
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_, axis=1))  # [N]
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
init = tf.global_variables_initializer()

# Train for 10000 steps on mini-batches of 50 examples, then report accuracy
# on the test set. Everything that uses `sess` must stay inside the `with`
# block, because the session is closed when the block exits.
with tf.Session() as sess:
    sess.run(init)
    for i in range(10000):
        batch = mnist.train.next_batch(50)  # (images, labels) fed to x and y_
        if i % 100 == 0:
            # Periodically report accuracy on the current batch, dropout off.
            train_accuracy = sess.run(
                accuracy,
                feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
            print('step %d,training accuracy %g !!!!!!!' % (i, train_accuracy))
        # One optimization step with dropout active.
        sess.run(train_step,
                 feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
    # Final evaluation on the test set with dropout disabled.
    # NOTE(review): this feeds every test image in a single run; on a
    # memory-constrained GPU it may need to be split into batches.
    total_accuracy = sess.run(
        accuracy,
        feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
    print('test_accuracy %s!!!!!!!' % (total_accuracy))
直接贴代码，代码本身没什么好多说的，关键处我都加了注释。代码主要参考了TensorFlow官网上的教程（Deep MNIST for Experts），但改用了slim模块（slim的介绍可参考我的另一篇博客），从而大大缩减了代码量，也提高了代码的可读性，强烈推荐slim模块。如果对上述代码中的函数不熟悉，可以直接去TensorFlow官网查阅API手册，里面的介绍非常详尽。调试代码时最重要的是关注tensor的shape，所以我在每个tensor变量后面都用注释标出了它的shape，既方便调试，也能提高程序的可读性。
下面展示运行后得到的结果：