3. TensorFlow 单层神经网络 MNIST 数字识别:训练、保存模型、加载模型、预测图像

#coding:utf-8
"""
MNIST classification with a single-layer neural network.

Saves the model, loads it back, and predicts images.
"""
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data set from a local directory. one_hot=True makes every
# label a one-hot vector, which is the format the label placeholder `y`
# ([None, 10]) expects.
# NOTE(review): the absolute path is specific to the author's machine; adjust
# it before running elsewhere.
mnist = input_data.read_data_sets("/Users/ming/Downloads/zhangming/pytorch_demo/data/mnist", one_hot=True)
# pylab is used at the end of the script to display the predicted test images.
import pylab

# Start from an empty default graph so that re-running this script in the same
# interpreter does not pile up duplicate ops and variables.
tf.reset_default_graph()

# Inputs: flattened images (784 values each, i.e. 28 x 28) and one-hot labels
# over the 10 digit classes.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# Parameters of the single linear layer.
w = tf.Variable(initial_value=tf.random_normal([784, 10]), name="weight")
b = tf.Variable(initial_value=tf.zeros([10]), name="bias")

# `model` holds the raw logits, `predict` the per-class probabilities.
model = tf.matmul(x, w) + b
predict = tf.nn.softmax(model)

# Cross-entropy, averaged over the batch. The log-probabilities are computed
# from the logits with log_softmax instead of tf.log(softmax(...)): when a
# probability underflows to 0 the latter yields -inf, and 0 * -inf turns the
# whole loss (and its gradients) into NaN.
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.nn.log_softmax(model), axis=1))

# Plain gradient descent on the cross-entropy loss.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(cost)

# Keep only the two most recent checkpoints on disk.
saver = tf.train.Saver(max_to_keep=2)

# Training hyperparameters.
batch_size = 32
epochs = 20
display_step = 2  # print the loss every `display_step` epochs
# Checkpoint path prefix; the training loop passes global_step to the saver,
# so the files actually written carry a "-<epoch>" suffix.
save_model = "mnist_model/mnist.cpkt"
# Training: run mini-batch gradient descent and write a checkpoint per epoch.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Make sure the checkpoint directory exists before the first save;
    # depending on the TensorFlow version, Saver.save may refuse to write
    # into a missing parent directory.
    checkpoint_dir = os.path.dirname(save_model)
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Number of full batches per epoch (a trailing partial batch is ignored).
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(epochs):
        epoch_loss = 0
        for _ in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            _, loss = sess.run([optimizer, cost], feed_dict={x: batch_x, y: batch_y})
            epoch_loss += loss
        # Mean loss over all batches of this epoch.
        epoch_loss = epoch_loss / total_batch
        # global_step appends "-<epoch>" to the checkpoint file names.
        saver.save(sess, save_model, global_step=epoch)
        if epoch % display_step == 0:
            print("epoch %d , loss %.2f" % (epoch, epoch_loss))
    print("done...")

# Prediction: restore the trained weights and classify two test images.
with tf.Session() as sess2:
    # Checkpoints are written with global_step, so the files are named
    # "<save_model>-<epoch>"; the bare prefix does not name any of them.
    # Look up the newest checkpoint recorded in the checkpoint directory.
    checkpoint_dir = os.path.dirname(save_model)
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        raise IOError("no checkpoint found in %r" % checkpoint_dir)
    # restore() assigns every saved variable, so no initializer run is needed.
    saver.restore(sess2, checkpoint_path)

    # Use fresh names for the fetched arrays so that the graph nodes `y` and
    # `predict` are not overwritten by numpy values.
    x_test, y_test = mnist.test.next_batch(2)
    correct_pred = tf.equal(tf.argmax(predict, 1), tf.argmax(y, 1))
    acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    probabilities, accuracy = sess2.run([predict, acc], feed_dict={x: x_test, y: y_test})
    print("predict:", probabilities)
    print("acc: %.4f" % accuracy)

    # Show each test image; the flat 784-vector is reshaped to 28 x 28.
    for flat_image in x_test:
        pylab.imshow(flat_image.reshape([-1, 28]))
        pylab.show()

输出:

 

 

/usr/local/bin/python2.7 /Users/ming/Downloads/zhangming/tf_demo/3.tf_mnist_1_layer.py
WARNING:tensorflow:From /Users/ming/Downloads/zhangming/tf_demo/3.tf_mnist_1_layer.py:7: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/train-images-idx3-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/t10k-images-idx3-ubyte.gz
Extracting /Users/ming/Downloads/zhangming/pytorch_demo/data/mnist/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: __init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
2018-11-17 23:44:02.858203: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
epoch 0 , loss 5.37
epoch 2 , loss 1.45
epoch 4 , loss 1.07
epoch 6 , loss 0.91
epoch 8 , loss 0.82
epoch 10 , loss 0.76
epoch 12 , loss 0.71
epoch 14 , loss 0.67
epoch 16 , loss 0.64
epoch 18 , loss 0.62
done...
('predict:', array([[6.7180810e-12, 6.8759458e-14, 8.5211534e-12, 2.8013984e-09,
        9.9993575e-01, 1.3801176e-08, 9.9942827e-09, 8.1797918e-10,
        2.9035888e-05, 3.5194662e-05],
       [3.0729464e-06, 4.6555795e-09, 1.1277327e-07, 1.6401351e-06,
        1.5443068e-06, 1.3794046e-07, 9.7556288e-08, 6.1844635e-01,
        1.1378842e-04, 3.8143331e-01]], dtype=float32))
acc: 1.0000

Process finished with exit code 0
 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值