# MNIST 复杂版本 (complex MNIST example)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST dataset (one-hot encoded labels).
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# Size of each mini-batch.
batch_size = 100
# Number of batches per epoch.
n_batch = mnist.train.num_examples // batch_size

# Placeholders: flattened 28x28 images, one-hot labels, dropout keep prob.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
# Learning rate stored as a variable so it can be decayed per epoch.
lr = tf.Variable(0.001, dtype=tf.float32)

# A simple 3-layer fully connected network: 784 -> 500 -> 300 -> 10.
W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
b1 = tf.Variable(tf.zeros(500) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1))
b2 = tf.Variable(tf.zeros(300) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros(10) + 0.1)
# BUG FIX: keep raw logits separate from the softmax output.
# softmax_cross_entropy_with_logits expects UNNORMALIZED logits and applies
# softmax internally; feeding it softmax(prediction) applies softmax twice,
# which flattens gradients and slows/hurts training.
logits = tf.matmul(L2_drop, W3) + b3
prediction = tf.nn.softmax(logits)

# Cross-entropy loss on the raw logits.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# Adam optimizer with the decaying learning-rate variable.
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

# Variable initializer op.
init = tf.global_variables_initializer()

# Boolean vector: whether each prediction matches the label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Accuracy = mean of the boolean vector cast to float.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        # Exponentially decay the learning rate: 0.001 * 0.95^epoch.
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Dropout active (keep 70% of units) during training only.
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7})

        learning_rate = sess.run(lr)
        # BUG FIX: disable dropout (keep_prob=1.0) when evaluating accuracy —
        # evaluating with keep_prob=0.7 randomly drops units and underestimates
        # (and adds noise to) the reported accuracy.
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        print("Iter" + str(epoch) + ", Training Accuracy" + str(train_acc) + ", Testing Accuracy" + str(test_acc)
              + ",Learning Rate = " + str(learning_rate))

Iter0, Training Accuracy0.929273, Testing Accuracy0.9299,Learning Rate = 0.001
Iter1, Training Accuracy0.946273, Testing Accuracy0.942,Learning Rate = 0.00095
Iter2, Training Accuracy0.952109, Testing Accuracy0.9485,Learning Rate = 0.0009025
Iter3, Training Accuracy0.958273, Testing Accuracy0.9519,Learning Rate = 0.000857375
Iter4, Training Accuracy0.961582, Testing Accuracy0.9538,Learning Rate = 0.000814506
Iter5, Training Accuracy0.967582, Testing Accuracy0.96,Learning Rate = 0.000773781
Iter6, Training Accuracy0.969055, Testing Accuracy0.9601,Learning Rate = 0.000735092
Iter7, Training Accuracy0.971218, Testing Accuracy0.9629,Learning Rate = 0.000698337
Iter8, Training Accuracy0.972509, Testing Accuracy0.9663,Learning Rate = 0.00066342
Iter9, Training Accuracy0.9746, Testing Accuracy0.9652,Learning Rate = 0.000630249
Iter10, Training Accuracy0.974782, Testing Accuracy0.9662,Learning Rate = 0.000598737
Iter11, Training Accuracy0.976764, Testing Accuracy0.9672,Learning Rate = 0.0005688
Iter12, Training Accuracy0.978055, Testing Accuracy0.9675,Learning Rate = 0.00054036
Iter13, Training Accuracy0.978891, Testing Accuracy0.9674,Learning Rate = 0.000513342
Iter14, Training Accuracy0.978982, Testing Accuracy0.9672,Learning Rate = 0.000487675
Iter15, Training Accuracy0.979945, Testing Accuracy0.9703,Learning Rate = 0.000463291
Iter16, Training Accuracy0.980982, Testing Accuracy0.97,Learning Rate = 0.000440127
Iter17, Training Accuracy0.982145, Testing Accuracy0.9699,Learning Rate = 0.00041812
Iter18, Training Accuracy0.982491, Testing Accuracy0.9696,Learning Rate = 0.000397214
Iter19, Training Accuracy0.982473, Testing Accuracy0.9687,Learning Rate = 0.000377354
Iter20, Training Accuracy0.9842, Testing Accuracy0.9716,Learning Rate = 0.000358486

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值