This post implements a plain autoencoder (Auto Encoder, AE), not a variational autoencoder (Variational Auto Encoder, VAE). The code can therefore only auto-encode the MNIST dataset into similar reconstructions of handwritten digits; it cannot generate entirely new digits by feeding the decoder latent variables sampled from a random Gaussian distribution.
1、Code
# @Time : 2022/8/22 21:21
# @Author : CSDN User: ctrl A_ctrl C_ctrl V
# @Function: validate an AE (Auto Encoder) using the MNIST dataset
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import numpy as np
import random
# hyper parameter
epochs = 10
batchSize = 512
# load dataset
(x_train, _), (x_valid, _) = keras.datasets.mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_valid.shape == (10000, 28, 28)
x_train = x_train.reshape(x_train.shape[0], -1) # (60000,784)
x_valid = x_valid.reshape(x_valid.shape[0], -1)
# normalization
x_train = tf.cast(x_train, tf.float32) / 255
x_valid = tf.cast(x_valid, tf.float32) / 255
# encoder and decoder layer
inputSize = 784
hiddenSize = 32
outputSize = 784
inputDim = keras.layers.Input(shape=(inputSize,))
encodeLayer = keras.layers.Dense(hiddenSize, activation='relu')(inputDim)
decoderLayer = keras.layers.Dense(outputSize, activation='sigmoid')(encodeLayer)
# build model
model = keras.Model(inputs=inputDim, outputs=decoderLayer)
model.summary()  # summary() prints the table itself; wrapping it in print() adds a stray "None"
# get encoder and decoder from model
encoder = keras.Model(inputs=inputDim, outputs=encodeLayer)
decoderInput = keras.layers.Input(shape=(hiddenSize,))
decoderOutput = model.layers[-1](decoderInput)
decoder = keras.Model(inputs=decoderInput, outputs=decoderOutput)
# train
# An AE has no labels: the input image itself serves as the label, hence x = y = x_train
model.compile(optimizer='adam', loss='mse')
model.fit(x=x_train, y=x_train, epochs=epochs, batch_size=batchSize, shuffle=True, validation_data=(x_valid, x_valid))
# validation
# display ten images randomly for visualization
encoder_valid = encoder.predict(x_valid)
decoder_valid = decoder.predict(encoder_valid)
x_valid = x_valid.numpy()
visualNum = 10
startNum = random.randint(0, 10000 - visualNum)
plt.figure(figsize=(20, 4))
for i in range(1, visualNum + 1):
    # top row: original images; bottom row: reconstructions
    # use startNum + i - 1 so the largest index is 9999, not 10000 (out of range)
    plt.subplot(2, visualNum, i)
    plt.imshow(x_valid[startNum + i - 1].reshape(28, 28))
    plt.subplot(2, visualNum, visualNum + i)
    plt.imshow(decoder_valid[startNum + i - 1].reshape(28, 28))
plt.show()
# test
# test with a random matrix; display ten images for visualization
test_tensor = np.random.rand(visualNum, hiddenSize)
test_output = decoder.predict(test_tensor)
plt.figure(figsize=(20, 4))
for i in range(1, visualNum + 1):
    plt.subplot(2, visualNum, i)
    plt.imshow(test_output[i - 1].reshape(28, 28))
plt.show()
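As a quick sanity check on the `model.summary()` output above, the parameter counts of the two Dense layers can be computed by hand (weight matrix plus bias vector per layer). This sketch is independent of TensorFlow and just redoes the arithmetic:

```python
# Hand-computed parameter counts for the two Dense layers of the AE above.
inputSize, hiddenSize, outputSize = 784, 32, 784

# encoder Dense: a 784x32 weight matrix plus 32 biases
encoderParams = inputSize * hiddenSize + hiddenSize   # 25120
# decoder Dense: a 32x784 weight matrix plus 784 biases
decoderParams = hiddenSize * outputSize + outputSize  # 25872

totalParams = encoderParams + decoderParams
print(totalParams)  # 50992
```

These are exactly the per-layer and total parameter counts that `model.summary()` reports for this architecture.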
2、Generated results
(1) epoch=1
Validation set (the reconstructed images are quite blurry, but the basic outlines are there):
Test with a randomly generated matrix (no pattern at all):
(2) epoch=5
Validation set (the reconstructed images are somewhat blurry, but the outlines are very clear):
Test with a randomly generated matrix (still no pattern):
(3) epoch=10
Validation set (the reconstructed images are now very clear):
Test with a randomly generated matrix (still no pattern):
(4) Summary
As noted at the start, an AE can only reproduce images, not generate them. So as the number of epochs grows, the reconstructed images become sharper and sharper, but a random matrix fed to the decoder still cannot produce the digits we want. True image generation requires a VAE or a GAN.
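For reference, the key mechanism a VAE adds is the reparameterization trick: instead of a point code, the encoder outputs a mean and a log-variance, and the latent is sampled as z = mu + sigma * eps with eps drawn from a standard normal. Below is a minimal NumPy sketch of just that sampling step (the function name and shapes are illustrative, not from the code above):

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps, with eps ~ N(0, I).

    The randomness enters only through eps, so gradients can flow
    through mu and log_var during training -- this is what lets a VAE
    learn a latent space you can later sample from with a Gaussian.
    """
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

rng = np.random.default_rng(0)
mu = np.zeros((10, 32))        # batch of 10 latent means (32-dim, as in the AE above)
log_var = np.zeros((10, 32))   # log-variance 0 -> sigma = 1
z = reparameterize(mu, log_var, rng)
print(z.shape)  # (10, 32)
```

Because the VAE's training objective also pushes the encoded latents toward a standard normal, decoding such samples yields plausible new digits, which is exactly what the plain AE above cannot do.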