# Tensorflow实现MNIST数据自编码(1)

1. import tensorflow as tf
2.
3. #导入数据集合
4. from tensorflow.examples.tutorials.mnist import input_data
6.
7. #整体流程，原始图片像素28*28-784
8. #784-》256-》128-》128-》256-》784
9.
10. learning_rate = 0.01
11. n_hidden_1 = 256     #第一层256个结点
12. n_hidden_2 = 128     #第二层128个结点
13. n_input = 784
14.
15. x = tf.placeholder('float',[None,n_input])
16. y = x
17.
18. weights = {
19.     'encoder_h1':tf.Variable(tf.random_normal([n_input,n_hidden_1])),
20.     'encoder_h2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])),
21.     'decoder_h1':tf.Variable(tf.random_normal([n_hidden_2,n_hidden_1])),
22.     'decoder_h2':tf.Variable(tf.random_normal([n_hidden_1,n_input])),
23. }
24. biases = {
25.     'encoder_b1':tf.Variable(tf.zeros([n_hidden_1])),
26.     'encoder_b2':tf.Variable(tf.zeros([n_hidden_2])),
27.     'decoder_b1':tf.Variable(tf.zeros([n_hidden_1])),
28.     'decoder_b2':tf.Variable(tf.zeros([n_input])),
29. }
30.
31. def encoder(x):
34.     return layer_2
35. def decoder(x):
38.     return layer_2
39.
40. pred = decoder(encoder(x))
41. cost = tf.reduce_mean(tf.pow(y-pred,2))
43.
44. training_epochs = 20  #共迭代20次
45. batch_size = 256      #每次取256个样本
46. display_step = 5      #迭代5次输出一次信息
47.
48. #启动会话
49. with tf.Session() as sess:
50.     sess.run(tf.global_variables_initializer())
51.     total_batch = int(mnist.train.num_examples/batch_size)
52.     #开始训练
53.     for epoch in range(training_epochs):
54.         for i in range(total_batch):
55.             batch_xs,batch_ys = mnist.train.next_batch(batch_size)#取数据
56.             _,c = sess.run([optimizer,cost],feed_dict={x:batch_xs})#训练模型
57.
58.             if epoch % display_step == 0:#输出日志信息
59.                 print("Epoch:",'%4d' % (epoch+1),'cost=',"{:.9f}".format(c))
60.     print('Training Finished!')
61.
62.     correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
63.     accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'))
64.     print('Accuracy:',1-accuracy.eval({x:mnist.test.images,y:mnist.test.images}))

• 广告
• 抄袭
• 版权
• 政治
• 色情
• 无意义
• 其他

120