学习记录:变分自编码器

基本介绍:

        VAE全称是Variational AutoEncoder,即变分自编码器。它不再是对一个样本直接生成一个隐层空间上的点,而是将经过神经网络编码后的隐藏层假设为一个标准的高斯分布,然后从这个分布中采样一个特征,再用这个特征进行解码,期望得到与原始输入相同的结果。

      (图片来源于:北京邮电大学计算机学院 鲁鹏)

        

         图片中的最小化用于防止VAE退化成最基础的AE,因为在进行训练时后为了减小误差,有可能会使得上图中的exp(σ)*e逐渐趋近于0,这样则不会再有随机采样的噪声。

        在获取了c之后,利用解码器,生成解码结果然后输出。

代码示例:

        (代码参考自:06 变分自编码器_哔哩哔哩_bilibili ,并对代码加上了一些自己的理解)

        环境:Python3.6+Keras2.2.0+Tensorflow-GPU1.9.0

        目标:以mnist数据库为例,学习手写数字图片之后,我们进行模仿,生成一些类似的图片

导入一些库

import numpy as np     #用于数据计算
import matplotlib.pyplot as plt  #用于画图
from keras.layers import Input, Dense, Lambda  #输入层、全连接层、自定义Lambda层(Lambda层不参与训练,只用于计算,用于生成中间的Z)
from keras.models import Model  #引入模型
from keras import backend as K  #后端backend是keras中非常重要的一个概念,它负责实现keras的各种操作
from keras import objectives   #目标函数,也叫损失函数
from keras.datasets import mnist  #引入手写数字库

定义一些常数

batch_size = 100  #分批,每次100个训练
original_dim = 784     #二维图片转为一维  28*28
intermediate_dim = 256 #用了两层全连接, 784-256    256-2   (input ,output)
latent_dim = 2   #第二层的输出 隐层为2层,可以方便后面做平面的可视化
epochs = 50   #轮次

encoder部分

x = Input(shape=(original_dim,))    #输入N个784的输入
h = Dense(intermediate_dim, activation='relu')(x) #全连接层256个输出,使用relu激活函数
z_mean = Dense(latent_dim)(h) #输入为h 输出为2维     隐层
z_log_var = Dense(latent_dim)(h)  #第二路,用于后面与正态分布采样进行融合处理

设计隐层的获取方式,自定义Lambda函数

def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.)#正太分布
    return z_mean + K.exp(z_log_var / 2) * epsilon  #z_mean 与z_log_var与epsilon合成一个新的隐层参数

z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])  #产生新的Z

设计要用到的全连接层

decoder_h = Dense(intermediate_dim, activation='relu')  #全连接层  256个神经元
decoder_mean = Dense(original_dim, activation='sigmoid')  #784个神经元 全连接层使用激活函数一般是使用sigmoid
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

自定义总的损失函数

def vae_loss(x, x_decoded_mean):   #定义一个总的损失函数
    xent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean)
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return xent_loss + kl_loss

vae = Model(x, x_decoded_mean)   #定义一个model,x为输入,x_decoded_mean为输出
vae.compile(optimizer='rmsprop', loss=vae_loss)  #rmsprop为优化算法,vae_loss一个损失函数

开始训练

(x_train, y_train), (x_test, y_test) = mnist.load_data()  #加载数据库中的data

x_train = x_train.astype('float32') / 255.  #归一化到01区间
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))   #len(x_train)为数据个数,  后面维度为28*28
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

vae.fit(x_train, x_train,   #第一个输入,第二个输出
        shuffle=True,   #每次训练的时候是否将其随机打乱
        epochs=epochs,    #50轮迭代
        batch_size=batch_size,   #每128个作为一组进行训练
        validation_data=(x_test, x_test))

展示训练结果 

encoder = Model(x, z_mean)

x_test_encoded = encoder.predict(x_test, batch_size=batch_size)  #利用上面的训练结果预测
plt.figure(figsize=(6, 6))  #6x6大小
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)  #scatter散点图,x_test_encoded[:, 0]代表隐层第一个维度,
                                                                   #        x_test_encoded[:, 1]为第二个维度,c为点的颜色 
plt.colorbar()
plt.show()

对隐层进行解码,再定义一个生成器,从隐层到输出,用于产生新的样本

decoder_input = Input(shape=(latent_dim,))  #解码部分
_h_decoded = decoder_h(decoder_input)    #这是第一个全连接层进行解码
_x_decoded_mean = decoder_mean(_h_decoded) #第二个全连接层进行解码  获得一个28*28=784
generator = Model(decoder_input, _x_decoded_mean)

结果如下,和之前看到的隐层图是一致的,甚至能看到一些数字之间的过渡态

n = 20
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))  #28行、28列 ,每一个图片20像素
grid_x = np.linspace(-4, 4, n)
grid_y = np.linspace(-4, 4, n)

for i, xi in enumerate(grid_x):
    for j, yi in enumerate(grid_y):
        z_sample = np.array([[yi, xi]])
        x_decoded = generator.predict(z_sample)   #获得一个784的向量
        digit = x_decoded[0].reshape(digit_size, digit_size)  #将这个784行的向量转化为28*28的图片
        figure[(n - i - 1) * digit_size: (n - i) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit  #将这个28*28的图片贴到整个大图上

plt.figure(figsize=(10, 10))
plt.imshow(figure)
plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值