import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
# Load the MNIST dataset from / into the local "mnist" directory. Labels are
# requested one-hot, but the training loop below only feeds the images.
mnist = input_data.read_data_sets("mnist", one_hot=True)

# Training hyperparameters.
learning_rate = 0.1
# Misspelled name kept as an alias so existing references still resolve.
learninig_rate = learning_rate
training_epochs = 20
batch_size = 100
display_step = 1  # print the cost every `display_step` epochs

# Network dimensions: 784 = 28 * 28 flattened pixels, compressed through two
# hidden layers of n_hidden_1 and n_hidden_2 units.
n_input = 784
n_hidden_1 = 256
n_hidden_2 = 128

# Number of test images to reconstruct and display once training is done.
examples_to_show = 10

# Batch of flattened input images, shape (batch, n_input).
X = tf.placeholder("float", [None, n_input])

# Fully connected layer weights: the decoder mirrors the encoder's shapes.
weights = {
    "encoder_h1": tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    "encoder_h2": tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    "decoder_h1": tf.Variable(tf.random_normal([n_hidden_2, n_hidden_1])),
    "decoder_h2": tf.Variable(tf.random_normal([n_hidden_1, n_input])),
}
# One bias vector per layer, sized to that layer's output width.
biases = {
    "encoder_b1": tf.Variable(tf.random_normal([n_hidden_1])),
    "encoder_b2": tf.Variable(tf.random_normal([n_hidden_2])),
    "decoder_b1": tf.Variable(tf.random_normal([n_hidden_1])),
    "decoder_b2": tf.Variable(tf.random_normal([n_input])),
}
def encoder(x):
    """Compress a batch of flattened images into the hidden code.

    Two fully connected sigmoid layers: n_input -> n_hidden_1 -> n_hidden_2.

    Args:
        x: float tensor of shape (batch, n_input), e.g. the placeholder ``X``.

    Returns:
        Tensor of shape (batch, n_hidden_2) with values in (0, 1).
    """
    # The bias is added to the matmul output (x @ W + b), not to W itself.
    layer_1 = tf.nn.sigmoid(
        tf.add(tf.matmul(x, weights["encoder_h1"]), biases["encoder_b1"])
    )
    layer_2 = tf.nn.sigmoid(
        tf.add(tf.matmul(layer_1, weights["encoder_h2"]), biases["encoder_b2"])
    )
    return layer_2
def decoder(x):
    """Reconstruct flattened images from the hidden code.

    Two fully connected sigmoid layers: n_hidden_2 -> n_hidden_1 -> n_input.

    Args:
        x: float tensor of shape (batch, n_hidden_2), e.g. ``encoder``'s output.

    Returns:
        Tensor of shape (batch, n_input) with values in (0, 1).
    """
    # The bias is added to the matmul output (x @ W + b), not to W itself.
    layer_1 = tf.nn.sigmoid(
        tf.add(tf.matmul(x, weights["decoder_h1"]), biases["decoder_b1"])
    )
    layer_2 = tf.nn.sigmoid(
        tf.add(tf.matmul(layer_1, weights["decoder_h2"]), biases["decoder_b2"])
    )
    return layer_2
# Wire up the autoencoder: the input batch is both what gets encoded and the
# target the reconstruction is compared against.
encoder_op = encoder(X)
y_pred = decoder_op = decoder(encoder_op)  # reconstruction
y_true = X                                  # target is the input itself

# Mean squared reconstruction error over every pixel in the batch.
cost = tf.reduce_mean(tf.square(y_true - y_pred))

# Plain gradient descent; `optimizer` is the training op run once per batch.
_sgd = tf.train.GradientDescentOptimizer(learninig_rate)
optimizer = _sgd.minimize(cost)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # Number of full mini-batches per epoch (any remainder is not visited).
    total_batch = int(mnist.train.num_examples // batch_size)
    if total_batch == 0:
        raise ValueError("batch_size exceeds the number of training examples")
    for epoch in range(training_epochs):
        epoch_cost = 0.0
        for _ in range(total_batch):
            # Labels are ignored: the autoencoder reconstructs its own input.
            batch_xs = mnist.train.next_batch(batch_size)[0]
            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
            epoch_cost += c
        if epoch % display_step == 0:
            # Report the mean batch cost over the epoch; the last batch's
            # cost alone is a noisy estimate.
            print("Epoch:", "%04d" % (epoch + 1),
                  "cost=", "{:.9f}".format(epoch_cost / total_batch))
    print("Optimizer Finished!!!")

    # Apply the trained autoencoder to the first test images.
    encode_decode = sess.run(
        y_pred, feed_dict={X: mnist.test.images[:examples_to_show]}
    )
# Compare the original test images (top row) with the autoencoder's
# reconstructions (bottom row). The column count follows examples_to_show
# instead of a hard-coded 10. squeeze=False keeps `a` two-dimensional so
# a[row][col] indexing works even for a single column.
f, a = plt.subplots(
    2, examples_to_show, figsize=(examples_to_show, 2), squeeze=False
)
for i in range(examples_to_show):
    a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)), cmap="gray")
    a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)), cmap="gray")
# plt.show() runs the GUI event loop until the window is closed. Figure.show()
# plus plt.draw() do not, so in a plain script the figure may never appear.
plt.show()
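# MNIST无监督学习-自编码器 (MNIST unsupervised learning: autoencoder)
# 最新推荐文章于 2024-01-02 01:20:11 发布 (source page footer: "latest recommended article published 2024-01-02 01:20:11")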