两层FC层做分类:MNIST
refer: http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_beginners.html
@author: huangyongye
@date: 2017-02-24
在本教程中,我们来实现一个非常简单的两层全连接网络来完成MNIST数据的分类问题。
输入[-1,28*28], FC1 有 1024 个neurons, FC2 有 10 个neurons。这么简单的一个全连接网络,结果测试准确率达到了 0.98。还是非常棒的!!!
import numpy as np
import tensorflow as tf
# Session configuration: let TensorFlow grow GPU memory on demand instead of
# reserving the whole device up front.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# InteractiveSession installs itself as the default session, which is what
# lets later code call `.run()` / `.eval()` without passing a session.
sess = tf.InteractiveSession(config=config)
1. 导入数据
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from / into the local 'MNIST_data' directory; labels are returned
# as one-hot vectors (one_hot=True).
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
# Report the shapes of the training arrays. Single-argument print calls with
# %-formatting behave the same under Python 2 and Python 3 (the original
# print statements were a SyntaxError on Python 3).
print('training data shape %s' % (mnist.train.images.shape,))
print('training label shape %s' % (mnist.train.labels.shape,))
training data shape (55000, 784)
training label shape (55000, 10)
2. 构建网络
def weight_variable(shape, stddev=0.1):
    """Create a trainable weight variable with truncated-normal initial values.

    Args:
        shape: Shape of the weight tensor, e.g. ``[in_dim, out_dim]``.
        stddev: Standard deviation of the truncated normal used for the
            initial values. Defaults to 0.1, the value previously hard-coded.

    Returns:
        A ``tf.Variable`` initialised from ``tf.truncated_normal``.
    """
    # Small random values break the symmetry between units of the same layer.
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial)
def bias_variable(shape, value=0.1):
    """Create a trainable bias variable filled with a constant.

    Args:
        shape: Shape of the bias tensor, e.g. ``[out_dim]``.
        value: Constant initial value for every element. Defaults to 0.1,
            the value previously hard-coded (presumably a small positive bias
            so ReLU units start out active).

    Returns:
        A ``tf.Variable`` initialised from ``tf.constant``.
    """
    initial = tf.constant(value, shape=shape)
    return tf.Variable(initial)
# Inputs: a batch of flattened 28*28 = 784-pixel images and their 10-way
# one-hot labels; the batch dimension is left open (None).
X_ = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

# FC1: 784 -> 1024 fully connected layer with ReLU activation.
W_fc1 = weight_variable([784, 1024])
b_fc1 = bias_variable([1024])
h_fc1 = tf.nn.relu(tf.nn.xw_plus_b(X_, W_fc1, b_fc1))

# FC2: 1024 -> 10 output layer; softmax turns the scores into per-class
# probabilities.
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_pre = tf.nn.softmax(tf.nn.xw_plus_b(h_fc1, W_fc2, b_fc2))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
3. 训练和评估
# Cross-entropy loss summed over the batch. Probabilities are clipped away
# from 0 because softmax can underflow to exactly 0, and log(0) = -inf would
# turn the loss (0 * -inf = NaN) and every gradient into NaN.
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y_pre, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

# A prediction counts as correct when its arg-max class equals the arg-max of
# the one-hot label; accuracy is the mean over the fed examples.
correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess.run(tf.global_variables_initializer())
for i in range(5000):
    x_batch, y_batch = mnist.train.next_batch(batch_size=100)
    # Feed the input placeholder X_ (the graph defines no `x_`).
    train_step.run(feed_dict={X_: x_batch, y_: y_batch})
    if (i + 1) % 200 == 0:
        # Accuracy over the entire training set, not just the current batch.
        train_accuracy = accuracy.eval(feed_dict={X_: mnist.train.images, y_: mnist.train.labels})
        print('step %d,training acc %g' % (i + 1, train_accuracy))
    if (i + 1) % 1000 == 0:
        test_accuracy = accuracy.eval(feed_dict={X_: mnist.test.images, y_: mnist.test.labels})
        # One string argument so Python 2 does not print a tuple.
        print("= " * 10 + "step %d, testing acc %g" % (i + 1, test_accuracy))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22