# TensorFlow step 12: denoising and dimensionality reduction with an autoencoder

# coding=utf-8
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # suppress TensorFlow C++ logging: only show warnings and errors

###data (50000,784),(1000,784),(1000,784):
import pickle
import gzip
import numpy as np

def load_data():
    """Load the pickled MNIST dataset from ../data/mnist.pkl.gz.

    Returns:
        A tuple (training_data, validation_data, test_data); each element is
        a (images, labels) pair as stored in the pickle.
    """
    # NOTE(review): the original source lost the `def` line and the
    # pickle.load call (it closed the file before reading and had a bare
    # module-level `return`); reconstructed here. encoding='latin1' is
    # required to unpickle the Python-2-era MNIST pickle under Python 3.
    with gzip.open('../data/mnist.pkl.gz', 'rb') as f:
        training_data, validation_data, test_data = pickle.load(f, encoding='latin1')
    return (training_data, validation_data, test_data)

# Module-level datasets consumed by the rest of the script.
training_data, validation_data, test_data = load_data()

def vectorized_result(j, num_classes=10):
    """Return a one-hot float vector with a 1.0 at index ``j``.

    Args:
        j: the class index to set hot (must satisfy 0 <= j < num_classes).
        num_classes: length of the returned vector (default 10 for MNIST
            digits; parameter added for generality, callers passing only
            ``j`` are unaffected).

    Returns:
        A numpy float array of shape (num_classes,).
    """
    e = np.zeros(num_classes)
    e[j] = 1.0
    return e

# Split the loaded data into network-ready inputs and one-hot targets.
# Inputs are the raw 784-pixel images; targets are one-hot label vectors.
trainData_in = training_data[0][:50000]
trainData_out = [vectorized_result(label) for label in training_data[1][:50000]]

validData_in = validation_data[0]
validData_out = [vectorized_result(label) for label in validation_data[1]]

# Only the first 100 test examples are kept; just 8 are visualized later.
testData_in = test_data[0][:100]
testData_out = [vectorized_result(label) for label in test_data[1][:100]]

### Denoising/dimension-reduction autoencoder net, 784 x 256 x 128 x 256 x 784:
import tensorflow as tf
import matplotlib.pyplot as plt

# Hyperparameters for the denoising autoencoder.
LEAEING_RATE=0.01    # RMSProp learning rate (NOTE(review): typo for LEARNING_RATE; name kept because later code references it)
TRAINING_EPOCHS=30   # full passes over the training set
BATCH_SIZE=100       # examples per gradient step
DISPLAY_STEP=1       # print the loss every DISPLAY_STEP epochs
EXAMPLES_TO_SHOW=8   # number of test digits to visualize at the end
NUM_INPUT=784        # flattened 28x28 MNIST image
NUM_ENCODER1=256     # units in the first encoder / second decoder layer
NUM_ENCODER2=128     # units in the bottleneck layer

# Graph inputs: the network is fed a noisy image and trained to reproduce
# the corresponding clean image (denoising autoencoder).
x_input = tf.placeholder(tf.float32, [None, NUM_INPUT], name='x_input_noisy')
y_desired = tf.placeholder(tf.float32, [None, NUM_INPUT], name='y_ouput_no_noisy')

# Weights and biases for the 784 -> 256 -> 128 -> 256 -> 784 stack.
weights = {
    'we1': tf.Variable(tf.random_normal([NUM_INPUT, NUM_ENCODER1])),
    'we2': tf.Variable(tf.random_normal([NUM_ENCODER1, NUM_ENCODER2])),
    'wd1': tf.Variable(tf.random_normal([NUM_ENCODER2, NUM_ENCODER1])),
    'wd2': tf.Variable(tf.random_normal([NUM_ENCODER1, NUM_INPUT])),
}

biases = {
    'be1': tf.Variable(tf.random_normal([NUM_ENCODER1])),
    'be2': tf.Variable(tf.random_normal([NUM_ENCODER2])),
    'bd1': tf.Variable(tf.random_normal([NUM_ENCODER1])),
    'bd2': tf.Variable(tf.random_normal([NUM_INPUT])),
}

def _encoder(x):
    """Map a batch of 784-d inputs to the 128-d bottleneck code."""
    h1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['we1']), biases['be1']))
    return tf.nn.sigmoid(tf.add(tf.matmul(h1, weights['we2']), biases['be2']))

def _decoder(code):
    """Map a 128-d bottleneck code back to a 784-d reconstruction."""
    h1 = tf.nn.sigmoid(tf.add(tf.matmul(code, weights['wd1']), biases['bd1']))
    return tf.nn.sigmoid(tf.add(tf.matmul(h1, weights['wd2']), biases['bd2']))

# NOTE(review): y_output was referenced below but never defined in the
# original source; reconstructed from the weight shapes above. Sigmoid
# activations keep the output in [0, 1], matching MNIST pixel values.
y_output = _decoder(_encoder(x_input))

# Mean squared reconstruction error against the clean image.
loss_se = tf.reduce_mean(tf.pow(y_desired - y_output, 2))
optimizer = tf.train.RMSPropOptimizer(LEAEING_RATE).minimize(loss_se)
init = tf.global_variables_initializer()

# Train the autoencoder: feed noise-corrupted batches, regress on the
# clean originals, then visualize a few denoised test digits.
# (Indentation reconstructed — the original paste had lost it, and the
# session-dependent evaluation/plotting must run inside the `with` block.)
with tf.Session() as sess:
    sess.run(init)
    num_batches = len(trainData_in) // BATCH_SIZE
    for epoch in range(TRAINING_EPOCHS):
        for i in range(num_batches):
            batch_x = trainData_in[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            # Corrupt the input with Gaussian noise; the target stays clean.
            batch_x_noisy = batch_x + 0.3 * np.random.randn(BATCH_SIZE, 784)
            _, cost = sess.run([optimizer, loss_se],
                               feed_dict={x_input: batch_x_noisy,
                                          y_desired: batch_x})
        if epoch % DISPLAY_STEP == 0:
            # `cost` is the loss of the last batch of this epoch.
            print('Epoch:%04d' % epoch, 'cost={:.9f}'.format(cost))
    print('Optimization Finished!')

    # Denoise a few test digits and compare noisy input vs reconstruction.
    testData_in_noisy_show = (testData_in[:EXAMPLES_TO_SHOW]
                              + 0.3 * np.random.randn(EXAMPLES_TO_SHOW, 784))
    y_decode_test = sess.run(y_output,
                             feed_dict={x_input: testData_in_noisy_show})

    # Top row: noisy inputs; bottom row: network reconstructions.
    fig, axes = plt.subplots(2, 8, figsize=(10, 3))
    for i in range(EXAMPLES_TO_SHOW):
        axes[0][i].imshow(np.reshape(testData_in_noisy_show[i], (28, 28)))
        axes[1][i].imshow(np.reshape(y_decode_test[i], (28, 28)))
    fig.show()
    plt.draw()
    plt.show()