VAE通俗理解及公式推导

引言

变分自编码器(Variational Autoencoder, VAE)是一种特殊的神经网络架构,它结合了传统的自动编码器(Autoencoder)思想和概率图模型的概念。通俗来说,VAE尝试学习数据的压缩表示(我们称之为“编码”),同时确保这些编码符合某种概率分布(通常是高斯分布)。这样做有两个好处:一是可以防止过拟合,二是使得我们可以从这个分布中随机采样来生成新的、类似的数据。

通俗理解

假设我们在编码维度为6的大型人脸数据集上训练了一个自编码器模型。理想的自编码器将学习人脸的描述性属性,如肤色、是否戴眼镜等,以试图用一些压缩的表示来描述观察。
在这里插入图片描述
在上面的示例中,我们使用单个值 { 0 , 1 } \{0,1\} {0,1}来描述输入图像在潜在特征上的表现,表示相应的人脸有没有对应的属性。
但在实际情况中,我们可能遇到单值无法准确衡量的情况。例如,当我们输入蒙娜丽莎的照片时,将微笑特征设定为特定的单值(相当于断定蒙娜丽莎笑了或者没笑)显然不如将微笑特征设定为某个概率分布(例如将微笑特征设定为x到y范围内的某个数,这个范围内既有数值可以表示蒙娜丽莎笑了又有数值可以表示蒙娜丽莎没笑)更合适。而变分自编码器便是用“取值的概率分布”代替原先的单值来描述对特征的观察的模型,如下图的右边部分所示,经过变分自编码器的编码,每张图片的微笑特征不再是自编码器中的单值而是一个概率分布。
在这里插入图片描述
通过这种方法,我们可以将每个原始的人脸压缩为一组可准确衡量人脸潜在属性的概率分布,这个过程是利用编码器完成的推断过程。随后,我们将从人脸潜在属性的状态分布中随机采样,生成一个向量作为解码器模型的输入,如下图的decoder部分所示。
在这里插入图片描述
通过构造我们的编码器模型来输出可能值的范围(统计分布),我们将随机采样这些值以供给我们的解码器模型,我们实质上实施了连续,平滑的潜在空间表示。对于潜在分布的所有采样,我们期望我们的解码器模型能够准确重构输入。因此,在潜在空间中彼此相邻的值应该与非常类似的重构相对应。

公式推导

上述过程我们可以简化为原始观察x和隐空间中变量z之间的转化,如下图所示。
在这里插入图片描述
其中, P ( z ∣ x ) P(z|x) P(zx)是编码器推断隐空间的正向过程中所产生的后验分布,表示由特定数据分布 P ( x ) P(x) P(x)生成隐藏属性 P ( z ) P(z) P(z)的概率分布;
与此对应的则是由隐空间生成原始数据的过程中所产生的似然分布 p ( x ∣ z ) p(x|z) p(xz)
我们知道,VAE最重要的一点动机就是产生新数据,与此直接相关的过程就是z->x的生成过程。也就是说,如果能从后验分布中采样隐变量 P ( z ) P(z) P(z),那么我们很容易重建原始分布得到 P ( x ) P(x) P(x),产生新的样本。但现在问题是, P ( z ) P(z) P(z)的分布我们也不知道。
但是,我们可以做一个假设: P ( z ) ∼ N ( 0 , 1 ) P(z)\sim N(0,1) P(z)N(0,1),即假设其服从正态分布。有了 P ( z ) P(z) P(z)的分布,我们就可以计算 p ( x ∣ z ) p(x|z) p(xz)生成新数据了。实际上,P(z)的确是服从高斯分布的,但是其均值和方差却与原始数据x有关,而后依然可以利用参数重整化的手段转换为 P ( z ) ∼ N ( u , σ 2 ) P(z)\sim N(u,\sigma^2) P(z)N(u,σ2).

所以,现在的问题变成如何计算 p ( z ∣ x ) p(z|x) p(zx)得到隐变量z
假设存在一个分布 q ( z ∣ x ) q(z|x) q(zx)近似 p ( z ∣ x ) p(z|x) p(zx),我们将定义它具有可伸缩的分布。如果我们可以定义 q ( z ∣ x ) q(z|x) q(zx)的参数,使其与 p ( z ∣ x ) p(z|x) p(zx)十分相似,就可以用它来对复杂的分布进行近似的推理。KL散度可以衡量两个分布之间的差异,其值越小代表两个分布越接近,所以我们可以通过最小化两个分布之间的KL散度来保证 q ( z ∣ x ) q(z|x) q(zx) p ( z ∣ x ) p(z|x) p(zx)相似。即 m i n   D K L ( q ( z ∣ x ) ∣ ∣ p ( z ∣ x ) ) min \ D_{KL}(q(z|x)||p(z|x)) min DKL(q(zx)∣∣p(zx))
公式的推导过程大家随便搜搜就能看到,这里我们只说结果,可以通过最大化下面式子的方式最小化上述表达式:
E q ( z ∣ x ) l o g   p ( x ∣ z ) − D K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) E_{q(z|x)}log\ p(x|z)-D_{KL}(q(z|x)||p(z)) Eq(zx)log p(xz)DKL(q(zx)∣∣p(z))
在这里,第一个式子代表重构的可能性,第二个则确保我们学习的分布q类似于真实的先验分布p。

为了重新访问我们的图形模型,我们可以使用q来推断可能隐藏的变量(潜在状态),这些变量可以用于生成观察。我们可以进一步将这个模型构造成神经网络结构,其中编码器模型学习从x到z的映射,解码器模型学习从z到x的映射。
在这里插入图片描述
我们网络的损失函数包括两项,第一个惩罚重构误差(可以认为是正如前面所讨论的最大化重构的可能性),第二个鼓励我们学习分布q(z | x)类似于真实的先验分布p(z),我们假设单元高斯分布,每个维度j的潜在空间。
L ( x , x ^ ) + ∑ i K L ( q i ( z ∣ x ) ∣ ∣ p ( z ) ) \mathcal{L}(x,\hat x)+\sum_i KL(q_i (z|x)||p(z)) L(x,x^)+iKL(qi(zx)∣∣p(z))

总结

变分自编码器的工作原理

  • 编码阶段:与普通自编码器一样,VAE也有一个编码器网络,它接收输入数据并将其映射到一个潜在空间中的点。但是,在VAE中,这个潜在空间是概率性的,而不是固定的数值。编码器输出两个向量:一个代表均值(μ),另一个代表方差(σ²),这两个参数定义了一个可能包含原始数据的潜在变量的概率分布。
  • 重参数化技巧:为了能够有效地训练模型,并且允许梯度通过随机抽样过程传递,VAE使用了一种称为重参数化技巧的技术。简单来说,就是用一个额外的随机噪声ε乘以方差σ再加上均值μ,从而得到潜在变量z的样本。这样做的结果是,即使对于相同的输入,每次前向传播都会产生略微不同的z值,这有助于模型学习到更鲁棒的数据表示。
  • 解码阶段:一旦得到了潜在变量z,解码器网络就会试图根据z重建原始输入数据。解码器的目标是最小化输入和输出之间的差异,即重构误差。
  • 损失函数:VAE的损失函数由两部分组成。一部分是重构误差,用来衡量解码器生成的数据与原始输入之间的差距;另一部分是KL散度,用于惩罚编码器产生的分布偏离预设的先验分布的程度。通过最小化这个综合损失函数,VAE能够在保持良好的重构性能的同时,也确保了潜在空间的良好结构。

应用

  • 生成新数据:由于VAE学到了数据的概率分布,因此可以通过在潜在空间中随机抽样来生成全新的、但看起来合理的数据样本。例如,可以生成从未见过的手写数字图片、人脸图像等。
  • 数据插值:在两个已知数据点对应的潜在变量之间进行线性插值,然后通过解码器转换回原始数据空间,可以获得平滑过渡的一系列中间状态。
  • 异常检测:如果某些数据点在重构时表现出较大的误差,那么它们可能是异常点或者不属于训练集中存在的模式。
  • 降维与可视化:类似于PCA或t-SNE,VAE也可以用来将高维数据投影到低维空间,便于观察和理解复杂数据集的内部结构。
  • 半监督学习:利用VAE对未标注数据进行编码和解码,可以帮助提高有少量标签数据的学习任务的效果。

VAE核心代码

# VAE 模型
class VAE(nn.Module):
    def __init__(self, encoder_layer_size, latent_size, decoder_layer_sizes, conditional=False, num_labels=0):
        super().__init__()
        if conditional:
            assert num_labels>0
        assert type(encoder_layer_sizes)==list
        assert type(latent_size)==int
        assert type(decoder_layer_sizes)==list
        self.latent_size = latent_size
        self.encoder = Encoder(encoder_layer_sizes, latent_size, conditional, num_labels)
        self.decoder = Decoder(decoder_layer_sizes, latent_size, conditional, num_labels)
 
    def forward(self, x, c=None):
        if x.dim() > 2:
            x = x.view(-1, 28*28)
        means, log_var = self.encoder(x, c)
        z = self.reparameterize(means, log_var)
        recon_x = self.decoder(z, c)
        return recon_x, means, log_var, z
 
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps*std
 
    def inference(self, z, c=None):
        recon_x = self.decoder(z, c)
        return recon_x

编码器

# 编码器
 
class Encoder(nn.Module):
    def __init__(self, layer_sizes, latent_size, conditional, num_labels):
        super().__init()
        self.conditional = conditional
        if self.conditional:
            layer_sizes[0] += num_labels
 
        self.MLP = nn.Sequential()
 
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
 
        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_log_var = nn.Linear(layer_sizes[-1], latent_size)
 
    def forward(self, x, c=None):
        if self.conditional:
            c = idx2onehot(c, n=10)
            x = torch.cat((x,c), dim=-1)
        x = self.MLP(x)
        means = self.linear_means(x)
        log_vars = self.linear_log_var(x)
        
        return means, log_vars

解码器

# 解码器
 
class Decoder(nn.Module):
    def __init__(self, layer_sizes, latent_size, conditional, num_labels):
        super().__init__()
        self.MLP = nn.Sequential()
        self.conditional = conditional
        if self.conditional:
            input_size = latent_size + num_labels
        else:
            input_size = latent_size
 
        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
            self.MLP.add_module(name="L{:d}".format(i),module=nn.Linear(in_size, out_size))
            if i+1<len(layer_sizes):
                self.MLP.add_module(name="A{:d}".format(i),module=nn.ReLU())
            else:
                self.MLP.add_module(name="Sigmoid", module=nn.Sigmoid())
 
    def forward(self, z, c):
        if self.conditional:
            c = idx2onehot(c, n=10)
            z = torch.cat((z,c), dim=-1)
        x = self.MLP(z)
        
        return x
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Marlowee

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值