The VAE network here is built entirely from fully connected layers. The encoder ends in two parallel heads: a mean layer and a variance layer (the variance is usually predicted in log space for numerical stability).
The latent code is then sampled with the reparameterization trick, Z = MEAN + EPS * exp(0.5 * LOG_VAR), which keeps Z continuous and differentiable so gradients can flow back through the sampling step.
EPS is drawn from a standard normal distribution N(0, 1).
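A minimal sketch of just this sampling step, separate from the full script below (the shapes here are illustrative assumptions, not values from the script):

import tensorflow as tf

mean = tf.zeros([4, 20])                    # pretend encoder output: batch of 4, z_dim of 20
log_var = tf.zeros([4, 20])                 # log(sigma^2); zeros correspond to sigma = 1
eps = tf.random.normal(tf.shape(mean))      # eps ~ N(0, 1), one noise sample per example
z = mean + eps * tf.exp(0.5 * log_var)      # z ~ N(mean, sigma^2), differentiable in mean/log_var
print(z.shape)                              # (4, 20)

The full training script follows.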
import tensorflow as tf
from PIL import Image
import numpy as np
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.mnist.load_data()
print(x_train.shape,y_train.shape)
#(60000, 28, 28) (60000,)
z_dim = 20
batchSize = 128
lr = 0.00001
EPOCH = 300
x_train = tf.cast(x_train,dtype=tf.float32)
x_test = tf.cast(x_test,dtype=tf.float32)
x_train = tf.divide(x_train,255.)
x_test = tf.divide(x_test,255.)
dataBase_train = tf.data.Dataset.from_tensor_slices(x_train).shuffle(10000).batch(batchSize,drop_remainder=True)
dataBase_test = tf.data.Dataset.from_tensor_slices(x_test).shuffle(10000).batch(batchSize,drop_remainder=True)
class VAE(tf.keras.Model):
def __init__(self):
super(VAE,self).__init__()
        # encoder
self.fc1 = tf.keras.layers.Dense(512,activation= tf.nn.relu)
self.fc2 = tf.keras.layers.Dense(256,activation= tf.nn.relu)
self.fc4 = tf.keras.layers.Dense(64,activation= tf.nn.relu)
self.mean = tf.keras.layers.Dense(z_dim)
self.log_var = tf.keras.layers.Dense(z_dim)
#decoder
self.output_fc2 = tf.keras.layers.Dense(128,activation= tf.nn.relu)
self.output_fc5 = tf.keras.layers.Dense(784)
def encoder(self,inputs):
x = tf.reshape(inputs,[-1,784])
x = self.fc1(x)
x = self.fc2(x)
logits = self.fc4(x)
mean = self.mean(logits)
log_var = self.log_var(logits)
return mean,log_var
    def reparameterize(self, mean, log_var):
        # eps ~ N(0, 1), sampled per example so each sample in the batch gets its own noise
        eps = tf.random.normal(tf.shape(log_var))
        # sigma = exp(0.5 * log_var), so z = mean + eps * sigma ~ N(mean, sigma^2)
        z = mean + eps * tf.exp(0.5 * log_var)
        return z
def decoder(self,z):
x = self.output_fc2(z)
logits = self.output_fc5(x)
return logits
def call(self, inputs, training=None, mask=None):
mean,log_var = self.encoder(inputs)
z = self.reparameterize(mean,log_var)
logits = self.decoder(z)
return logits,mean,log_var
model=VAE()
model.build(input_shape=[None,28,28])
model.summary()
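As a quick sanity check (not part of the original script), a dummy batch confirms the three outputs and their shapes:

dummy = tf.zeros([4, 28, 28])               # illustrative dummy batch
out_logits, out_mean, out_log_var = model(dummy)
print(out_logits.shape, out_mean.shape, out_log_var.shape)  # (4, 784) (4, 20) (4, 20)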
optimizer = tf.optimizers.Adam(learning_rate=lr)
path = r"E:\PycharmProjects\untitled\VAE\save_image\writer"
writer = tf.summary.create_file_writer(path)  # create the TensorBoard writer once, not inside the loop
index = 0
for epoch in range(EPOCH):
for step,data in enumerate(dataBase_train):
with tf.GradientTape() as tape:
            logits, mean, log_var = model(data)
            data = tf.reshape(data, [-1, 784])
            # Reconstruction loss: per-pixel binary cross-entropy on the raw decoder logits
            logits_loss = tf.keras.losses.binary_crossentropy(data, logits, from_logits=True)
            # KL divergence pulls the latent distribution q(z|x) = N(mean, sigma^2) toward the N(0, 1) prior:
            # KL(N(mean, sigma^2) || N(0, 1)) = -0.5 * (1 + log(sigma^2) - mean^2 - sigma^2)
            kl = -0.5 * (log_var + 1 - mean ** 2 - tf.exp(log_var))
            kl = tf.reduce_mean(kl)
            # Total loss: reconstruction term plus a weighted (0.9) KL term
            loss = tf.reduce_mean(logits_loss + 0.9 * kl)
grad = tape.gradient(loss,model.trainable_variables)
optimizer.apply_gradients(zip(grad,model.trainable_variables))
if step % 100 == 0:
print("Epoch:",epoch,"Loss:",float(loss))
            with writer.as_default():
                # optimizer.iterations is a global step counter, so the loss curve does not reset each epoch
                tf.summary.scalar("loss", loss, step=optimizer.iterations)
    # Crude hand-rolled schedule: raise the learning rate after epoch 10, then cut it sharply after epoch 40
    if epoch > 10 and epoch <= 20:
        optimizer.learning_rate = lr * 10 ** 0.5
    if epoch > 40 and epoch <= 100:
        optimizer.learning_rate = lr * 0.0001
    # test_image = next(iter(dataBase_test))           # alternative: reconstruct real test digits instead
    test_image = tf.random.normal([batchSize, 28, 28])  # random noise fed through the full model
    image = next(iter(test_image))                     # first 28x28 sample from the batch
    new_img = Image.new('L', (28, 28))
    logits, _, _ = model(image)
    logits = tf.reshape(logits, [-1, 28, 28])
    logits = tf.nn.sigmoid(logits) * 255.              # squash logits to [0, 255] pixel values
    image = Image.fromarray(np.uint8(logits[0]), mode='L')
    new_img.paste(image, (0, 0))
    p = r"E:\VAE\save_image"                           # directory where generated images are saved
    new_img.save(p + "/{a}.png".format(a=index))
    index += 1
    print("Image Saved!")