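# Autoencoder (encoder + decoder): show input images next to their reconstructions,
# then plot the encoder's output vectors colored by digit class.
# (original title: encoder+decoder+show输入和输出的图片+encoder之后的向量分类)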
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# MNIST read from /tmp/data/. one_hot=False keeps labels as integer class ids,
# which the latent-space scatter plot later uses as point colors.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)

# hyper parameters
batch_size = 32          # images per optimizer step
save_path = 'model'      # checkpoint directory
max_train_epoch = 5      # total epochs, counted from the (restorable) global step
lr = 0.001               # Adam learning rate
n_inputs = 784           # flattened image size (images are reshaped to 28x28 for display)
classes = 10             # number of digit classes
# Encoder layer widths; hidden2 is the bottleneck. Only its first two columns
# are plotted, so it is kept 2-D.
hidden1 = 256
hidden2 = 2
# Decoder layer widths. The last layer is the reconstruction of the input, so
# its width must equal n_inputs for the reconstruction loss to compare them.
d_hidden1 = 256
d_hidden2 = n_inputs
example_to_show = 10     # number of test images shown in the reconstruction figure
#network structure
class nn(object):
    """Dense autoencoder: a two-layer encoder followed by a two-layer decoder.

    Every layer uses a sigmoid activation. The model is trained to reproduce
    its own input, so the reconstruction target ``y_gt`` is ``inputs`` itself.

    Attributes (set in ``__init__``):
        global_step: non-trainable int32 counter, intended to be handed to the
            optimizer's ``minimize`` so it advances once per training step.
        encoder_vec: output of the encoder's bottleneck layer.
        y_pred: output of the decoder (the reconstruction).
        y_gt: the ``inputs`` tensor (reconstruction target).
        loss: mean squared error between ``y_pred`` and ``y_gt``.
    """

    def __init__(self, inputs, name='nn', trainning=True, reuse=False,
                 layer_units=None):
        """Build the autoencoder graph on top of ``inputs``.

        Args:
            inputs: 2-D float tensor of flattened images; presumably
                (batch, n_inputs) -- confirm against the caller.
            name: variable scope wrapping all of the model's variables.
            trainning: accepted for API compatibility; currently unused.
            reuse: forwarded to ``tf.variable_scope`` so another instance can
                share the dense-layer weights.
            layer_units: optional 4-tuple of layer widths
                (encoder layer 1, bottleneck, decoder layer 1, decoder output).
                Defaults to the module-level ``hidden1``, ``hidden2``,
                ``d_hidden1`` and ``d_hidden2``. The decoder output width must
                match the input width for the loss to compare them.
        """
        if layer_units is None:
            layer_units = (hidden1, hidden2, d_hidden1, d_hidden2)
        enc1_units, enc2_units, dec1_units, dec2_units = layer_units

        with tf.variable_scope(name, reuse=reuse):
            # Plain tf.Variable (not get_variable): variable_scope reuse does
            # not share it, so each instance gets its own counter.
            self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32)
            with tf.variable_scope('encoder'):
                with tf.variable_scope('hidden1'):
                    hd1 = tf.layers.dense(inputs, enc1_units,
                                          activation=tf.nn.sigmoid, name='hidden1')
                with tf.variable_scope('hidden2'):
                    hd2 = tf.layers.dense(hd1, enc2_units,
                                          activation=tf.nn.sigmoid, name='hidden2')
            with tf.variable_scope('decoder'):
                with tf.variable_scope('d_hidden1'):
                    d_hd1 = tf.layers.dense(hd2, dec1_units,
                                            activation=tf.nn.sigmoid, name='d_hidden1')
                with tf.variable_scope('d_hidden2'):
                    # NOTE: this layer is named 'd_hidden1' (copied from the layer
                    # above). It sits under its own 'd_hidden2' scope, so nothing
                    # clashes. Renaming it would change the variable names stored
                    # in existing checkpoints, so the name is kept.
                    d_hd2 = tf.layers.dense(d_hd1, dec2_units,
                                            activation=tf.nn.sigmoid, name='d_hidden1')
            self.y_pred = d_hd2
            self.y_gt = inputs
            self.encoder_vec = hd2
            with tf.variable_scope('loss'):
                # reconstruction error: mean of squared per-element differences
                self.loss = tf.reduce_mean(tf.square(self.y_pred - self.y_gt))

    def summary(self):
        """Register a TensorBoard scalar summary named 'loss' for ``self.loss``."""
        tf.summary.scalar('loss', self.loss)
# --- placeholders ---
# Flattened input images, one row of n_inputs floats per example.
input_x = tf.placeholder(dtype=tf.float32, shape=(None, n_inputs), name='inputs')
# NOTE(review): never fed in this script. Its (None, classes) float shape also
# does not match the integer labels that read_data_sets(one_hot=False) returns.
labels = tf.placeholder(dtype=tf.float32, shape=(None, classes), name='labels')

# --- model ---
train_model = nn(input_x)

# --- optimizer ---
# Adam on the reconstruction loss. minimize() also increments the model's
# global_step by one each time train_up runs.
train_up = tf.train.AdamOptimizer(lr).minimize(
    train_model.loss,
    global_step=train_model.global_step,
)

# --- checkpointing ---
# Built after the model and optimizer, so it covers all variables created so far.
saver = tf.train.Saver()
# Saver.save does not create missing directories, so make sure the checkpoint
# directory exists before training starts saving into it.
if not os.path.isdir(save_path):
    os.makedirs(save_path)

with tf.Session() as sess:
    # TensorBoard: register the loss summary and a writer for the 'log' directory.
    train_model.summary()
    train_writer = tf.summary.FileWriter('log', graph=sess.graph)
    merged = tf.summary.merge_all()

    # Resume from the latest checkpoint in save_path; otherwise start fresh.
    ckpt = tf.train.get_checkpoint_state(save_path)
    if ckpt:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(save_path, ckpt_name))
    else:
        sess.run(tf.global_variables_initializer())

    # Training loop. The epoch counter comes from the (possibly restored)
    # global step, so a resumed run continues toward max_train_epoch.
    global_step_val = sess.run(train_model.global_step)
    tot_batch = int(mnist.train.num_examples / batch_size)
    now_epoch = int(global_step_val / tot_batch)
    while now_epoch < max_train_epoch:
        print('epoch:', now_epoch)
        loss_sum = 0.0
        for _ in range(tot_batch):
            batch_images = mnist.train.next_batch(batch_size)[0]
            _, tmp_loss = sess.run([train_up, train_model.loss],
                                   feed_dict={input_x: batch_images})
            loss_sum += tmp_loss
            # mirrors the increment that minimize() applies to global_step
            global_step_val += 1
            if global_step_val % 100 == 0:
                saver.save(sess, os.path.join(save_path, 'nn.ckpt'), global_step_val)
                # the logged summary is evaluated on the first batch_size test images
                merged_summary = sess.run(
                    merged, feed_dict={input_x: mnist.test.images[:batch_size]})
                train_writer.add_summary(merged_summary, global_step_val)
        # mean training loss over the epoch (not just the last batch)
        epoch_loss = loss_sum / tot_batch
        print('loss', epoch_loss)
        now_epoch += 1
    # flush pending TensorBoard events and release the event file
    train_writer.close()

    # Reconstruction figure: row 0 shows original test images, row 1 their
    # reconstructions. squeeze=False keeps `axes` 2-D even for one column.
    n_show = example_to_show
    result = sess.run(train_model.y_pred,
                      feed_dict={input_x: mnist.test.images[:n_show]})
    _fig, axes = plt.subplots(2, n_show, figsize=(n_show, 2), squeeze=False)
    for i in range(n_show):
        axes[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
        axes[1][i].imshow(np.reshape(result[i], (28, 28)))
    plt.show()

    # Latent-space plot: first two encoder-output columns of every test image,
    # colored by its integer digit label (assumes a 2-D bottleneck).
    all_vec = sess.run(train_model.encoder_vec,
                       feed_dict={input_x: mnist.test.images})
    plt.scatter(all_vec[:, 0], all_vec[:, 1], c=mnist.test.labels)
    plt.colorbar()
    plt.show()