用tensorflow的slim模块快速实现mnist手写体识别分类

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist
import tensorflow.contrib.slim as slim

mnist=input_data.read_data_sets('../share/MNIST_DATA',one_hot=True)
x=tf.placeholder("float",shape=[None,784])
y_=tf.placeholder("float",shape=[None,10])

#cast x to 3D
x_image=tf.reshape(x,[-1,28,28,1])#shape of x is [N,28,28,1]

#conv layer1
net=slim.conv2d(x_image,32,[5,5],scope='conv1')#shape of net is [N,28,28,32]
net=slim.max_pool2d(net,[2,2],scope='pool1')#shape of net is [N,14,14,32]

#conv layer2
net=slim.conv2d(net,64,[5,5],scope='conv2')#shape of net is [N,14,14,64]
net=slim.max_pool2d(net,[2,2],scope='pool2')#shape of net is [N,7,7,64]

#reshape for full connection
net=tf.reshape(net,[-1,7*7*64])#[N,7*7*64]

#fc1
net=slim.fully_connected(net,1024,scope='fc1')#shape of net is [N,1024]

#dropout layer
keep_prob=tf.placeholder('float')
net=tf.nn.dropout(net,keep_prob)
#fc2
net=slim.fully_connected(net,10,scope='fc2')#[N,10]
#softmax
y=tf.nn.softmax(net)#[N,10]

cross_entropy=-tf.reduce_sum(tf.multiply(y_,tf.log(y)))#y and _y have same shape.
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction=tf.equal(tf.argmax(y,axis=1),tf.argmax(y_,axis=1))#shape of correct_prediction is [N]
accuracy=tf.reduce_mean(tf.cast(correct_prediction,'float'))

init=tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(10000):
        batch=mnist.train.next_batch(50)
        if i%100==0:
            train_accuracy=sess.run(accuracy,feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})
            print('step %d,training accuracy  %g !!!!!!!'%(i,train_accuracy))
        sess.run(train_step,feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})

    total_accuracy=sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0})
    print('test_accuracy  %s!!!!!!!'%(total_accuracy))

直接贴代码,代码没什么好说的了,我都做了注释了。主要是参考TensorFlow官网上的教程点击打开链接,但是使用了slim模块(slim介绍参考我的这个博客点击打开链接),于是大大缩小了代码量,也提高了代码的可读性,强烈推荐slim模块。当然如果对上述代码中的函数不熟悉的可直接去TensorFlow官网查看API手册,里面介绍得非常详尽。当然在调试代码时最重要的还是关注tensor的shape,于是我在每个tensor变量后都注释了shape,方便调试,也能提高程序的可读性。

后面得到的结果展示


  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值