先上代码:
#!/usr/bin/python
#-*-coding:utf-8-*-
import input_data
import tensorflow as tf
# Load MNIST with one-hot labels (classic TF tutorial helper module).
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Graph inputs: flattened 784-pixel images and 10-class one-hot labels.
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
# Reshape to NHWC image batches for the convolutional layers.
x_image = tf.reshape(x, [-1, 28, 28, 1])

################ network #################

def _weight(shape):
    # Small random values break symmetry between units at init time.
    return tf.Variable(tf.truncated_normal(shape, mean=0, stddev=0.1))

def _bias(shape):
    # Slightly positive bias keeps ReLU units initially active.
    return tf.Variable(tf.constant(0.1, shape=shape))

def _conv_relu(inputs, kernel_shape):
    # 5x5 SAME convolution with stride 1, followed by ReLU.
    w = _weight(kernel_shape)
    b = _bias(kernel_shape[-1:])
    conv = tf.nn.conv2d(inputs, w, strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(conv + b)

def _max_pool_2x2(inputs):
    # 2x2 max pooling, halving the spatial resolution.
    return tf.nn.max_pool(inputs, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')

# conv1 + pool1: 28x28x1 -> 28x28x20 -> 14x14x20
h_pool1 = _max_pool_2x2(_conv_relu(x_image, [5, 5, 1, 20]))
# conv2 + pool2: 14x14x20 -> 14x14x50 -> 7x7x50
h_pool2 = _max_pool_2x2(_conv_relu(h_pool1, [5, 5, 20, 50]))

# fc1: flatten the feature maps, then a 500-unit ReLU layer.
h_flat = tf.reshape(h_pool2, shape=[-1, 7 * 7 * 50])
relu_fc1 = tf.nn.relu(tf.matmul(h_flat, _weight([7 * 7 * 50, 500])) + _bias([500]))

# Dropout; keep_prob is fed at run time (0.5 while training, 1.0 for eval).
keep_prob = tf.placeholder(tf.float32)
fc1_dropout = tf.nn.dropout(relu_fc1, keep_prob)

# fc2: final 10-way linear layer (logits), then softmax probabilities.
fc2 = tf.matmul(fc1_dropout, _weight([500, 10])) + _bias([10])
y = tf.nn.softmax(fc2)
##########train-predict############
# Loss: softmax cross-entropy computed from the *logits* (fc2) by the fused,
# numerically stable op.  The original -tf.reduce_sum(y_ * tf.log(y)) form
# produces NaN as soon as the softmax saturates and log(0) is evaluated;
# summed over the batch, the fused op is mathematically identical to it.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=fc2))
train_step = tf.train.GradientDescentOptimizer(0.0001).minimize(cross_entropy)
#train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# Fraction of samples whose arg-max prediction matches the one-hot label.
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))

# `with` guarantees the session is released even if training raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(20000):
        train_x, train_y = mnist.train.next_batch(100)
        # keep_prob=0.5 enables dropout during training only.
        sess.run(train_step, feed_dict={x: train_x, y_: train_y, keep_prob: 0.5})
        if i % 100 == 0:
            # Periodic progress report on the current batch, dropout disabled.
            acc = sess.run(accuracy,
                           feed_dict={x: train_x, y_: train_y, keep_prob: 1.0})
            print('step %d, train_accuracy %g' % (i, acc))
    # Final evaluation on the held-out test set (dropout disabled).
    print("test accuracy %g" % sess.run(
        accuracy,
        feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
经过多次调参数测试,做如下总结,欢迎拍砖:
- 收敛的决定性因素是学习率,刚开始设置为0.01,怎么都不行,其实采用哪种优化算法,影响不是特别大。
- 卷积核的个数可以做适当的调整,卷积层1、卷积层2、全连接层1的核分别设置为(20,50,500)和(32,64,1024)对结果几乎没有什么影响,主要是因为输入的图片长宽比较小。
- 是否有RELU层,可能会稍微有点影响收敛的速度,但不明显,就这个例子,结果无妨。
- dropout层可以防止过拟合,加上会比较好。
- 参数的初始化,对本例的最终结果影响不大;但需要注意:若把所有权重都初始化为 0,同一层内各神经元的梯度完全相同,对称性无法打破,网络实际学不到有区分度的特征,一般仍建议用小随机数(如 truncated_normal)初始化权重。