# Ran successfully.
# Softmax regression on MNIST: a single 784 -> 10 layer trained with mini-batch
# gradient descent; test-set accuracy is printed after every epoch.
# NOTE: uses TensorFlow 1.x graph-mode APIs (tf.placeholder, tf.Session,
# tf.train.GradientDescentOptimizer, tensorflow.examples.tutorials.mnist).
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

BATCH_SIZE = 100       # samples per mini-batch
N_EPOCHS = 21          # number of epochs to train (epochs 0..20 are printed)
LEARNING_RATE = 0.2    # gradient-descent step size
LOSS = "quadratic"     # "quadratic" (mean squared error) or "cross_entropy"


def main():
    """Load MNIST, train the softmax model, and print test accuracy per epoch.

    Raises:
        ValueError: if LOSS is not "quadratic" or "cross_entropy".
    """
    # Load the dataset into ./MNIST_data with one-hot encoded labels.
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
    # Number of mini-batches run per epoch (integer division).
    n_batch = mnist.train.num_examples // BATCH_SIZE

    # Inputs: 784 = 28 * 28 flattened pixels; labels: 10-way one-hot vectors.
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])

    # Single-layer network: softmax(x @ W + b), parameters zero-initialized.
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(x, W) + b
    prediction = tf.nn.softmax(logits)

    if LOSS == "quadratic":
        # Quadratic cost: mean squared error between one-hot labels and the
        # softmax probabilities.
        loss = tf.reduce_mean(tf.square(y - prediction))
    elif LOSS == "cross_entropy":
        # Cross-entropy cost computed from the logits (not from `prediction`)
        # so TensorFlow can use its fused, numerically stable softmax + log.
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    else:
        raise ValueError("unknown LOSS: {!r}".format(LOSS))

    # Plain gradient descent.
    train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

    # Per-sample booleans: does the predicted class equal the labelled class?
    # tf.argmax(t, 1) returns the index of the largest entry along axis 1.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
    # Cast True -> 1.0 and False -> 0.0; the mean is the fraction classified
    # correctly.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Created after the whole graph is built so it covers every variable.
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        for epoch in range(N_EPOCHS):
            for _ in range(n_batch):
                batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
                sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
            # Evaluate on the full test set once per epoch.
            acc = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                y: mnist.test.labels})
            print("Iter" + str(epoch) + ",Test Accuracy" + str(acc))


if __name__ == "__main__":
    main()
# 二次代价函数 (quadratic / MSE loss) —— 每个 epoch 结束后的测试集准确率:
# Iter0,Test Accuracy0.8304
# Iter1,Test Accuracy0.8704
# Iter2,Test Accuracy0.8813
# Iter3,Test Accuracy0.8883
# Iter4,Test Accuracy0.895
# Iter5,Test Accuracy0.8968
# Iter6,Test Accuracy0.8992
# Iter7,Test Accuracy0.9022
# Iter8,Test Accuracy0.9037
# Iter9,Test Accuracy0.9054
# Iter10,Test Accuracy0.9068
# Iter11,Test Accuracy0.9071
# Iter12,Test Accuracy0.9074
# Iter13,Test Accuracy0.909
# Iter14,Test Accuracy0.9093
# Iter15,Test Accuracy0.9107
# Iter16,Test Accuracy0.9122
# Iter17,Test Accuracy0.9128
# Iter18,Test Accuracy0.9127
# Iter19,Test Accuracy0.9143
# Iter20,Test Accuracy0.9136
# 交叉熵代价函数 (cross-entropy loss) —— 每个 epoch 结束后的测试集准确率,效果比二次代价函数更好一点:
# Iter0,Test Accuracy0.8242
# Iter1,Test Accuracy0.8831
# Iter2,Test Accuracy0.8993
# Iter3,Test Accuracy0.9048
# Iter4,Test Accuracy0.9088
# Iter5,Test Accuracy0.9093
# Iter6,Test Accuracy0.9122
# Iter7,Test Accuracy0.9131
# Iter8,Test Accuracy0.9146
# Iter9,Test Accuracy0.9164
# Iter10,Test Accuracy0.9172
# Iter11,Test Accuracy0.918
# Iter12,Test Accuracy0.9201
# Iter13,Test Accuracy0.9192
# Iter14,Test Accuracy0.9196
# Iter15,Test Accuracy0.921
# Iter16,Test Accuracy0.9211
# Iter17,Test Accuracy0.92
# Iter18,Test Accuracy0.9209
# Iter19,Test Accuracy0.9214
# Iter20,Test Accuracy0.922