# -*- coding: utf-8 -*-
"""
Created on Sun Sep 3 13:48:19 2017
@author: piaodexin
"""
from __future__ import division, print_function, absolute_import

import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from the directory named by the MNIST_DIR environment variable,
# falling back to the original author's hard-coded Windows path so existing
# setups keep working unchanged.
# one_hot=True only affects how labels are encoded; this autoencoder never
# reads the labels.
mnist = input_data.read_data_sets(
    os.environ.get('MNIST_DIR', 'E:\\mnist'), one_hot=True)
# Network architecture: a fully connected autoencoder with symmetric layer
# sizes.
#   input layer     (28, 28) = 784 units
#   hidden layer 1  500 units
#   hidden layer 2  100 units (the narrowest layer)
#   hidden layer 3  500 units
#   output layer    784 units
# An autoencoder lets the network learn image features on its own and then
# rebuild the original image from those learned features, so the output
# layer must be the same size as the input: (28, 28) = 784.
input_n = 784
hidden1_n = 500
hidden2_n = 100
hidden3_n = 500
# Tie the output width to the input width. The reconstruction is compared
# element-wise against the input, so the two sizes must never diverge.
output_n = input_n
learn_rate = 0.01
batch_size = 100
# NOTE: despite its name, this is the number of mini-batch training steps,
# not the number of full passes over the training set.
train_epoch = 30000

# x: the network input. y: the reconstruction target, which for an
# autoencoder is the input image itself. y is compared against the
# network output, so it is sized by output_n.
x = tf.placeholder(tf.float32, [None, input_n])
y = tf.placeholder(tf.float32, [None, output_n])

# Layer parameters: truncated-normal weights (stddev 0.1) and biases
# initialised to 0.1. Creation order is kept as-is, so graph variable names
# are unchanged.
weights1 = tf.Variable(tf.truncated_normal([input_n, hidden1_n], stddev=0.1))
bias1 = tf.Variable(tf.constant(0.1, shape=[hidden1_n]))
weights2 = tf.Variable(tf.truncated_normal([hidden1_n, hidden2_n], stddev=0.1))
bias2 = tf.Variable(tf.constant(0.1, shape=[hidden2_n]))
weights3 = tf.Variable(tf.truncated_normal([hidden2_n, hidden3_n], stddev=0.1))
bias3 = tf.Variable(tf.constant(0.1, shape=[hidden3_n]))
weights4 = tf.Variable(tf.truncated_normal([hidden3_n, output_n], stddev=0.1))
bias4 = tf.Variable(tf.constant(0.1, shape=[output_n]))
def get_result(x, weights1, bias1, weights2, bias2,
               weights3, bias3, weights4, bias4):
    """Build the autoencoder forward pass and return its output tensor.

    Every layer computes ``sigmoid(matmul(activation, W) + b)``. The input
    ``x`` flows through the (weight, bias) pairs in order:
    (weights1, bias1), (weights2, bias2), (weights3, bias3), (weights4, bias4).

    Args:
        x: input tensor, used as the first layer's activation.
        weights1..weights4: per-layer weight matrices.
        bias1..bias4: per-layer bias vectors.

    Returns:
        The sigmoid activation produced by the last layer.
    """
    layer_params = (
        (weights1, bias1),
        (weights2, bias2),
        (weights3, bias3),
        (weights4, bias4),
    )
    activation = x
    for weight, bias in layer_params:
        activation = tf.nn.sigmoid(tf.matmul(activation, weight) + bias)
    return activation
# NOTE (original author): building y_ step by step at the top level, instead
# of calling get_result, raised an error; the cause was never identified.
y_ = get_result(x, weights1, bias1, weights2, bias2,
                weights3, bias3, weights4, bias4)
# Mean squared reconstruction error between the network output and the target.
loss = tf.reduce_mean(tf.pow(y_ - y, 2))
train_op = tf.train.RMSPropOptimizer(learn_rate).minimize(loss)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(train_epoch):
        xs, ys = mnist.train.next_batch(batch_size)
        # The target is the input itself; the labels (ys) are not used.
        feed = {x: xs, y: xs}
        if i % 1000 == 0:
            # Loss on the current batch, evaluated before this step's update.
            print('epoch:', i)
            print('loss:', sess.run(loss, feed_dict=feed))
        sess.run(train_op, feed_dict=feed)

    # Reconstruct the first five test images while the session is open.
    xt = mnist.test.images[:5]
    yt = xt
    encode_decode = sess.run(y_, feed_dict={x: xt, y: yt})

# Show the results: the top row holds the original test images and the bottom
# row holds the autoencoder's reconstructions. Grayscale matches how MNIST
# digits are normally viewed.
f, a = plt.subplots(2, 5, figsize=(10, 2))
for i in range(5):
    a[0][i].imshow(np.reshape(xt[i], (28, 28)), cmap='gray')
    a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)), cmap='gray')
# plt.show() runs the GUI event loop and blocks until the window is closed.
# Figure.show() followed by plt.draw() does not, so a plain script run would
# exit before the figure was ever displayed.
plt.show()