本文以图像的重构为例对VAE进行梳理
一、文章思路
我们首先了解一下传统AE的不足之处:隐空间没有良好的“规则性”,因为神经网络是一个非线性变化的过程;接着我们试图引入噪声来解决传统AE的局限性,但只引入噪声的话,隐空间总得不到完全的覆盖,因此我们将encoder的点输出转换成分布输出;之后,一是为了避免重构误差过大,二是为了避免VAE退化成传统AE,我们引入KL Loss,以实现将分布逼近标准正态分布的目的(此处重构 Loss和KL Loss其实包含有一种对抗的思想,重构 Loss不希望有噪声,而KL Loss希望有一定的噪声);接着我们从数学的角度推导出了上述两种误差的由来;最后,我们介绍了重参数技巧,并给出图片和代码以辅助梳理VAE的网络结构。
二、AE的局限性
传统的AE其实很简单,由于神经网络具有非常强大的拟合能力,因此将神经网络作为encoder和decoder,从而实现以下的效果:
- 利用encoder,能将高维的图片向量encode到一个低维空间(隐空间)中
- 利用decoder,能将隐空间中的低维向量decode回高维图片向量
如上图所示,我们即构造了一个具有强大重构能力的AE。对于一个合格的生成模型来说,decoder部分应该是能够单独提取出来的,并且隐空间下的任意一个向量,decoder都能将其恢复成一张有意义的图片。但是很遗憾,上述的AE并不能达到这个标准。下面我们举一个简单的例子:
假设我们训练好了一个AE,效果如上图所示:
- 对于全月图,它重构的效果很好
- 对于半月图,它重构的效果也很好
现在,假设我们取了隐空间中夹在全月图和半月图对应的隐状态中间的一个状态,如果是一个合格的生成模型,我们期望它decode得到的是一个3/4月图。但是实际上,该AE恢复得到的是一张模糊而且无法辨认的乱码图。一个合理的解释是:
我们采用神经网络作为encoder和decoder,这使得encode和decode的过程其实是一个非线性变换的过程,所以在隐空间上,不同状态之间的迁移是没有规律的。
对于一个合格的生成模型来说,我们期望它的隐空间是具有“规则性”的,具体表现为:
- 连续性:隐空间中的两个相邻状态decode后不应呈现两个完全不同的内容,应该具有一定的相似性
- 完整性:任意从隐空间中采样的状态,经过decode之后恢复得到的应该是“有意义”的内容
很遗憾,从这两个表现来看,传统的AE并不是一个合格的生成模型。
三、VAE的引入
如何解决上述传统AE存在的问题呢?一个可行的办法就是引入噪声,使得一张图片对应的隐空间状态的区域扩大,从而掩盖失真的空白状态区域,如下图所示:
现在,全月图对应的隐状态是左边的绿色箭头,半月图对应的隐状态是右边的绿色箭头。我们仍然关注之前的那个状态点,现在decoder既想让它恢复成全月图,又想让它恢复成半月图,于是它的恢复效果就是这两种图的折中——3/4月图。
但是这么做还不够,引入噪声并不能使得隐空间的所有状态都得到覆盖,例如上图中的黄色点对应的仍然是一个失真点。为了解决这个问题,我们可以把噪声无限拉长,使其覆盖整个隐空间,不过我们得保证在原编码点取到的概率最高,离原编码点越远,取到的概率越低,因此我们很自然地引入了分布:
至此,我们完成AE向VAE进化的第一步:
encoder的输出不再是隐空间的一个点,而是一个分布!(对于正态分布而言,其分布可由均值和方差描述,噪声强度即由方差控制)。
但是这么做实际上还存在一个问题,虽然我们将encoder的输出转变成了一个分布,但是由于噪声是由神经网络学出来的,神经网络的训练目标是减小重构误差,也就是减小噪声,所以神经网络会想方设法让噪声变为0,也就是分布的方差变为0,即encoder的输出又变成了一个点,也就是分布的均值。这样VAE就慢慢退化成了AE!
这当然是我们不想看到的,因此我们还得控制分布服从一个标准的正态分布,这样就能保证方差不会太大,也就是噪声不会太大,符合重构误差的喜好;又能保证encoder的输出仍是一个分布,维持隐空间的“规则性”。
那么如何控制分布服从一个标准的正态分布呢?这就引出了AE向VAE进化的第二步:
增加 K L ( N ( μ , σ 2 ) ∣ ∣ N ( 0 , 1 ) ) KL(N(\mu,\sigma^2)||N(0,1)) KL(N(μ,σ2)∣∣N(0,1))作为损失
综上,AE向VAE的进化之路其实也就是两步:
- encoder的输出转换成一个分布
- 在重构损失的基础上,增加KL损失作为代价
四、数学推导
VAE为什么叫做变分自编码器?“变分”其实就体现在数学推导过程用到了变分推断的思想。
具体推导过程如下图所示:
下面仔细分析一下这两项:
可以发现这两项其实就是VAE两个Loss函数的构成
五、网络架构 + 代码实现
我们先利用下面这张图梳理一下VAE的整个流程:
- encode过程:将 X X X丢进encoder,每一个样本 x x x都会生成一个专属的分布 P ( Z ∣ X k ) P(Z|X_k) P(Z∣Xk),注意我们要保证每个专属分布 P ( Z ∣ X k ) P(Z|X_k) P(Z∣Xk)都尽量逼近标准正态分布,这一点我们通过增加 KL Loss实现
- decode过程:从每一个专属分布 P ( Z ∣ X k ) P(Z|X_k) P(Z∣Xk)采样出一个 Z k Z_k Zk,将所有 Z k Z_k Zk丢进decoder,生成 X k ^ \hat{X_k} Xk^
但是上述这张图在训练的时候存在一个问题,我们“从专属分布 P ( Z ∣ X k ) P(Z|X_k) P(Z∣Xk)采样出 Z k Z_k Zk”这个操作在反向传播的过程中是不可导的,我们可以通过重参数技巧解决这个问题,即:
从 N ( μ , σ 2 ) N(\mu,\sigma^2) N(μ,σ2)中采样一个 Z Z Z,相当于从 N ( 0 , I ) N(0,I) N(0,I)中采样一个 ε \varepsilon ε,然后让 Z = μ + ε σ Z=\mu+\varepsilon\sigma Z=μ+εσ
所以现在网络结构变成了如下这张图所示:
下面我们通过截取部分keras代码梳理一下这个网络结构,但注意在代码实际实现的过程中,有一个需要注意的点:
正常情况下,专属分布的方差 σ 2 = n n ( x ) \sigma^2=nn(x) σ2=nn(x)总是非负的,需要加激活函数进行处理;
在代码实现中,实际拟合的是 l o g σ 2 = n n ( x ) log\sigma^2=nn(x) logσ2=nn(x),由于 l o g log log可正可负,所以不需要加激活函数进行处理,因此KL Loss损失的计算也会发生相应的变化,但本质是一样的。
from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
original_dim = 784
latent_dim = 2 # 隐变量取2维只是为了方便后面画图
intermediate_dim = 256
# encoder部分
x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)
# 算p(Z|X)的均值和方差
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)
# 采样部分(重参数技巧)
def sampling(args):
z_mean, z_log_var = args
epsilon = K.random_normal(shape=K.shape(z_mean))
return z_mean + K.exp(z_log_var / 2) * epsilon
# 重参数层,相当于给输入加入噪声
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
# decoder部分
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)
# 损失函数部分:xent_loss是重构loss,kl_loss是KL loss
xent_loss = K.sum(K.binary_crossentropy(x, x_decoded_mean), axis=-1)
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(xent_loss + kl_loss)