【单点知识】变分自动编码器VAE介绍

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在此前文章【单点知识】基于PyTorch讲解自动编码器(Autoencoder)中已经介绍了自动编码器,本文将再继续介绍自动编码器的变种——变分自动编码器VAE

1 为什么需要自动变分编码器?

变分自动编码器(VAE)是Diederik P. Kingma和Max Welling在2013年提出的(原文:https://arxiv.org/pdf/1312.6114)。VAE通过引入概率模型来生成数据。编码器输出的是一个概率分布(通常为高斯分布),解码器从这个分布中采样生成数据。

回忆我们之前说明的自动编码器的原理:编码器将输入数据 x x x映射到一个低维的编码表示 z z z,即 g ϕ ( x ) = z g_{\phi}(x)=z gϕ(x)=z,这里 z z z(也称隐变量,潜在变量)是一个固定的值(或者向量,或者矩阵),这就非常影响生成的多样性。

而VAE是一种生成模型,结合了自动编码器(Autoencoder)的思想和概率图模型的概念。与传统的自动编码器不同,VAE 不仅学习数据的压缩表示,还通过引入概率分布生成新的数据样本。VAE 在生成模型领域有着广泛的应用,特别是在图像生成、文本生成和数据增强等方面。

举个例子:如果没有自动变分编码器,那我们向现在的文生图AI工具输入“请画一只猫”,那AI工具永远只会生成一只固定的猫。而有了自动变分编码器,AI工具就可以输出黑猫、白猫、长毛猫、短毛猫等各种猫。

2. VAE工作原理

在这里插入图片描述

2.1 编码器(Encoder)

编码器将输入数据 x x x映射到一个潜在空间中的概率分布参数。具体来说,编码器输出两个向量:均值 μ \mu μ 和方差 σ \sigma σ。这两个向量定义了一个高斯分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx),其中 z z z 是潜在变量,这个分布表示了给定输入数据 x x x时,潜在变量 z z z的概率分布。

2.2 重参数化技巧(Reparameterization Trick)

为了使潜在变量 z z z可以通过梯度下降进行优化,VAE 使用重参数化技巧。具体来说,从标准正态分布 N ( 0 , 1 ) \mathcal{N}(0, 1) N(0,1) (均值为0,标准差为1)中采样一个随机变量 ϵ \epsilon ϵ,然后通过以下公式生成 z z z
z = μ + σ ⋅ ϵ z = \mu + \sigma \cdot \epsilon z=μ+σϵ
这样, z z z 的梯度可以通过 μ \mu μ σ \sigma σ传递,从而实现端到端的训练。

2.3 解码器(Decoder)

解码器将潜在变量 z z z映射回原始数据空间,生成重构数据 x ′ x' x。解码器的目标是最小化重构数据 x ′ x' x和原始数据 x x x之间的差异,通常使用均方误差(MSE)或交叉熵损失函数。

3. 损失函数

变分自动编码器(VAE)的损失函数由两部分组成:重构损失(Reconstruction Loss)和KL散度(Kullback-Leibler Divergence)。这两部分共同指导模型的学习过程,确保生成的样本既忠实于输入数据又具有多样性。

3.1 重构损失(Reconstruction Loss)

重构损失衡量的是模型生成的重构数据 x ′ x' x与原始输入数据 x x x 之间的差异。常用的重构损失函数包括:

  • 均方误差(Mean Squared Error, MSE)
    L recon = 1 N ∑ i = 1 N ( x i − x i ′ ) 2 \mathcal{L}_{\text{recon}} = \frac{1}{N} \sum_{i=1}^{N} (x_i - x'_i)^2 Lrecon=N1i=1N(xixi)2
    其中 ( N ) 是数据点的数量。

  • 二元交叉熵(Binary Cross-Entropy)
    L recon = − 1 N ∑ i = 1 N [ x i log ⁡ x i ′ ) + ( 1 − x i ) log ⁡ ( 1 − x i ′ ) ] \mathcal{L}_{\text{recon}} = -\frac{1}{N} \sum_{i=1}^{N} \left[ x_i \log x'_i) + (1 - x_i) \log(1 - x'_i) \right] Lrecon=N1i=1N[xilogxi)+(1xi)log(1xi)]
    适用于二值数据或概率分布。

3.2 KL 散度(Kullback-Leibler Divergence)

KL 散度衡量的是编码器生成的潜在分布 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx) 与先验分布 p ( z ) p(z) p(z) 之间的差异。在 VAE 中,先验分布 p ( z ) p(z) p(z) 通常选择为标准正态分布 N ( 0 , 1 ) \mathcal{N}(0, 1) N(0,1)

对于高斯分布 q ϕ ( z ∣ x ) = N ( μ , σ 2 ) q_{\phi}(z|x) = \mathcal{N}(\mu, \sigma^2) qϕ(zx)=N(μ,σ2),KL 散度可以表示为:
KL ( q ϕ ( z ∣ x ) ∥ p ( z ) ) = 1 2 ∑ i = 1 D ( σ i 2 + μ i 2 − 1 − log ⁡ ( σ i 2 ) ) \text{KL}(q_{\phi}(z|x) \| p(z)) = \frac{1}{2} \sum_{i=1}^{D} \left( \sigma_i^2 + \mu_i^2 - 1 - \log(\sigma_i^2) \right) KL(qϕ(zx)p(z))=21i=1D(σi2+μi21log(σi2))
其中 D D D 是潜在变量 z z z 的维度, μ i \mu_i μi σ i \sigma_i σi 分别是编码器输出的均值和标准差。

3.3 总的损失函数

VAE 的总损失函数是重构损失和KL散度的加权和:
L = L recon + β ⋅ KL ( q ϕ ( z ∣ x ) ∥ p ( z ) ) \mathcal{L} = \mathcal{L}_{\text{recon}} + \beta \cdot \text{KL}(q_{\phi}(z|x) \| p(z)) L=Lrecon+βKL(qϕ(zx)p(z))
其中 β \beta β 是一个超参数,用于平衡重构损失和KL散度的影响。通常情况下, β = 1 \beta = 1 β=1,但在某些变体中, β \beta β可以调整以改变模型的行为。

4. 总结

VAE通过引入概率分布和重参数化技巧,不仅能够学习数据的高效表示,还能生成新的数据样本。VAE 在生成模型领域有着广泛的应用,特别是在需要生成高质量新样本的任务中表现出色。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

使者大牙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值