样本(特征)形状:[None, 784],标记(one-hot 标签)形状:[None, 10]
https://blog.csdn.net/qq_33144323/article/details/81393769
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Command-line flag read by conv_fc(): 1 (default) trains the model and saves
# a checkpoint; 0 restores the saved checkpoint and predicts on test images.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('is_train', 1, '1 表示训练并保存模型, 0 表示加载模型并测试')
def weight_init(shape, mean=0.0, stddev=1.0):
    """Create a trainable weight variable drawn from a normal distribution.

    Args:
        shape: shape of the weight tensor, e.g. ``[5, 5, 1, 32]`` for a conv
            filter or ``[7 * 7 * 64, 10]`` for the dense layer.
        mean: mean of the normal distribution (default 0.0, unchanged).
        stddev: standard deviation of the normal distribution (default 1.0,
            unchanged). NOTE(review): 1.0 is large for this network; smaller
            values such as 0.1 are commonly used and may train better with
            the 0.001 learning rate used in conv_fc -- confirm empirically
            before changing the default.

    Returns:
        A ``tf.Variable`` initialised with ``tf.random_normal``.
    """
    return tf.Variable(tf.random_normal(shape=shape, mean=mean, stddev=stddev))
def biasd_init(shape):
    """Return a trainable float32 bias variable of ``shape``, filled with zeros."""
    zeros = tf.constant(0, shape=shape, dtype=tf.float32)
    return tf.Variable(zeros)
def model():
    """Build the two-layer convolutional network graph for MNIST.

    Layout: input -> conv1 (5x5, 32 filters) + ReLU + 2x2 max-pool
    -> conv2 (5x5, 64 filters) + ReLU + 2x2 max-pool -> fully connected (10).

    Returns:
        x: float32 placeholder of shape ``[None, 784]`` (flattened 28x28 images).
        y_true: float32 placeholder of shape ``[None, 10]`` (one-hot labels).
        y_predict: unnormalised logits of shape ``[None, 10]``.
    """
    with tf.variable_scope("data"):
        x = tf.placeholder(tf.float32, [None, 784])
        # Labels are float32 (not int32) so they match the logits dtype used
        # by softmax_cross_entropy_with_logits; an int32 placeholder would
        # also silently truncate any non-0/1 (soft) label values.
        y_true = tf.placeholder(tf.float32, [None, 10])

    # Convolution layer 1: convolution -> ReLU -> max-pool
    with tf.variable_scope("conv1"):
        # [None, 784] -> [None, 28, 28, 1]
        x_reshape = tf.reshape(x, [-1, 28, 28, 1])
        # 32 filters of 5x5x1, stride 1, SAME padding
        w_conv1 = weight_init([5, 5, 1, 32])
        b_conv1 = biasd_init([32])
        # -> [None, 28, 28, 32]
        x_relu1 = tf.nn.relu(
            tf.nn.conv2d(x_reshape, w_conv1, strides=[1, 1, 1, 1], padding="SAME")
            + b_conv1)
        # 2x2 max-pool, stride 2 -> [None, 14, 14, 32]
        x_pool1 = tf.nn.max_pool(
            x_relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    # Convolution layer 2: convolution -> ReLU -> max-pool
    with tf.variable_scope("conv2"):
        # 64 filters of 5x5x32
        w_conv2 = weight_init([5, 5, 32, 64])
        b_conv2 = biasd_init([64])
        # -> [None, 14, 14, 64]
        x_relu2 = tf.nn.relu(
            tf.nn.conv2d(x_pool1, w_conv2, strides=[1, 1, 1, 1], padding="SAME")
            + b_conv2)
        # 2x2 max-pool, stride 2 -> [None, 7, 7, 64]
        x_pool2 = tf.nn.max_pool(
            x_relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    # Fully connected output layer
    with tf.variable_scope("fc"):
        # [None, 7*7*64] x [7*7*64, 10] -> [None, 10]
        w_fc = weight_init([7 * 7 * 64, 10])
        b_fc = biasd_init([10])
        x_fc_reshape = tf.reshape(x_pool2, [-1, 7 * 7 * 64])
        y_predict = tf.matmul(x_fc_reshape, w_fc) + b_fc

    return x, y_true, y_predict
def conv_fc():
    """Train the CNN on MNIST, or restore it and predict test digits.

    The ``is_train`` flag selects the mode:

    - Non-zero: train for 2000 steps with batch size 50, printing the batch
      accuracy each step, then save a checkpoint to
      ``./temp/ckpt/fc_model.ckpt``.
    - Zero: restore that checkpoint and print the target and predicted digit
      for 100 test images.

    Returns:
        None
    """
    import os

    ckpt_path = "./temp/ckpt/fc_model.ckpt"
    mnist = input_data.read_data_sets("./temp/data", one_hot=True)
    x, y_true, y_predict = model()

    with tf.variable_scope("soft_cross"):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))
    with tf.variable_scope("optimizer"):
        train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
    with tf.variable_scope("acc"):
        equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
        accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
    # Build the prediction op once; creating tf.argmax ops inside the test
    # loop would add new nodes to the graph on every iteration.
    predict_label = tf.argmax(y_predict, 1)

    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)
        if FLAGS.is_train:
            for i in range(2000):
                x_train, y_train = mnist.train.next_batch(50)
                sess.run(train_op, feed_dict={x: x_train, y_true: y_train})
                print("训练第%d次,准确率为:%f" % (
                    i, sess.run(accuracy, feed_dict={x: x_train, y_true: y_train})))
            # Make sure the checkpoint directory exists; TF 1.x Saver.save can
            # fail with "Parent directory ... doesn't exist" otherwise.
            os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
            saver.save(sess, ckpt_path)
        else:
            saver.restore(sess, ckpt_path)
            for i in range(100):
                x_test, y_test = mnist.test.next_batch(1)
                # y_test is a one-hot numpy batch of size 1; take its class index.
                target = int(y_test.argmax(axis=1)[0])
                predicted = int(sess.run(predict_label, feed_dict={x: x_test})[0])
                print("第%d张图,数字目标是:%d,预测结果是%d" % (i, target, predicted))
    return None
if __name__ == "__main__":
    # Script entry point: conv_fc() reads the --is_train flag to decide
    # whether to train or to restore the model and test it.
    conv_fc()
运行结果:
打开终端,运行以下命令以测试模式(--is_train=0)加载模型进行预测:
C:\Users\Steven\Desktop\pythonLearn\tensorflow>python tenFw_mnist_02.py --is_train=0