离散变分自编码器(dVAE)详解
一、背景
变分自编码器(VAE)是一种强大的生成模型,在许多领域得到广泛应用。但传统 VAE 通常假设潜在空间是连续的,在处理离散数据或需要离散潜在表示时存在局限性。dVAE(Discrete Variational Autoencoder,离散变分自编码器) 应运而生,旨在解决离散潜在变量的学习和生成问题,能够在离散的潜在空间中进行建模,从而更灵活地处理诸如文本、类别标签等离散数据类型,拓展了变分自编码器的应用范围。
二、原理
2.1 离散变分自编码器(dVAE)与Gumbel Softmax
1. dVAE的核心思想
离散变分自编码器(dVAE)是标准VAE的变体,主要区别在于:
- 潜在空间离散化:dVAE使用离散潜在变量而非连续变量
- 结构化表示:离散变量能更好地捕捉数据的分类或分层结构
- 可解释性:离散潜在变量通常具有更明确的语义含义
2. Gumbel Softmax在dVAE中的作用
当潜在变量离散时,传统的重参数化技巧失效,因为:
- 离散采样不可导,梯度无法反向传播
- 无法直接优化离散分布的参数
dVAE和VQVAE总体概念相似,主要不同之处在于引入了Gumbel Softmax进行训练,有效避免了VQ-VAE训练中由于ArgMin操作不能求导而产生的问题(Straight Through Estimator直通估计近似)。
2.2 模型架构
编码器: dVAE 的编码器将输入数据映射到离散潜在变量的概率分布上。不同于 VQ-VAE(Vector Quantized-VAE)等模型直接输出确定性的离散索引,dVAE 输出的是离散变量的概率分布。 一般会使用 Gumbel-Softmax 技巧来对离散采样过程进行松弛,使得在训练过程中可以通过梯度下降来优化模型。具体来说,通过引入 Gumbel 噪声,并结合 Softmax 函数,将离散的采样过程转化为可微的操作,从而可以在反向传播时计算梯度,更新网络参数。
解码器: 解码器则根据从编码器得到的离散潜在变量(采样得到),重建出原始输入数据或者生成新的数据。解码器通常是一个神经网络,它将离散潜在变量映射回数据空间,比如在图像生成任务中,将离散潜在变量映射回图像的像素空间。
离散潜在变量: dVAE 中的离散潜在变量可以是多维的,每个维度都对应着不同的离散取值。这些离散变量的组合构成了潜在空间,模型通过学习这些离散变量的概率分布,来捕捉输入数据的特征和结构。
三、损失函数推导
dVAE 的损失函数主要基于变分推断的原理,目标是最大化证据下界(ELBO,Evidence Lower BOund)。对于给定的输入数据 x x x,其损失函数推导如下:
3.1. 变分下界
根据贝叶斯定理,数据
x
x
x 的对数似然可以表示为:
log
p
(
x
)
=
log
p
(
x
,
z
)
p
(
z
∣
x
)
=
log
p
(
x
,
z
)
−
log
p
(
z
∣
x
)
\log p(x) = \log \frac{p(x,z)}{p(z|x)} = \log p(x, z) - \log p(z|x)
logp(x)=logp(z∣x)p(x,z)=logp(x,z)−logp(z∣x)
通过引入变分分布
q
(
z
∣
x
)
q(z|x)
q(z∣x),可以得到对数似然的一个下界(证据下界):
log
p
(
x
)
≥
E
q
(
z
∣
x
)
[
log
p
(
x
∣
z
)
]
−
D
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
\log p(x) \geq \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x)||p(z))
logp(x)≥Eq(z∣x)[logp(x∣z)]−DKL(q(z∣x)∣∣p(z))
其中,
E
q
(
z
∣
x
)
[
log
p
(
x
∣
z
)
]
\mathbb{E}_{q(z|x)}[\log p(x|z)]
Eq(z∣x)[logp(x∣z)] 是重建损失项,衡量了在给定潜在变量
z
z
z 时,模型重建输入数据
x
x
x 的能力;
D
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
D_{KL}(q(z|x)||p(z))
DKL(q(z∣x)∣∣p(z)) 是KL 散度项,衡量了变分分布
q
(
z
∣
x
)
q(z|x)
q(z∣x) 与先验分布
p
(
z
)
p(z)
p(z) 的差异。
3.2. 具体到 dVAE
-
重建损失:在 dVAE 中,由于使用了离散潜在变量,重建损失通常根据具体任务来定义。例如在图像生成中,可能使用像素空间的均方误差(MSE)或者交叉熵损失。假设 x x x 是输入图像, x ^ \hat{x} x^ 是重建图像,那么重建损失可以表示为:
L r e c = − E q ( z ∣ x ) [ log p ( x ∣ z ) ] = ∑ i = 1 N ℓ ( x i , x ^ i ) L_{rec} = -\mathbb{E}_{q(z|x)}[\log p(x|z)] = \sum_{i=1}^{N} \ell(x_i, \hat{x}_i) Lrec=−Eq(z∣x)[logp(x∣z)]=i=1∑Nℓ(xi,x^i)
其中, N N N 是数据集中样本的数量, ℓ ( x i , x ^ i ) \ell(x_i, \hat{x}_i) ℓ(xi,x^i) 是针对单个样本 i i i 的损失函数。 -
KL 散度损失:对于离散潜在变量的先验分布 p ( z ) p(z) p(z),通常假设为均匀分布或者其他简单的分布。KL 散度项衡量了编码器输出的分布 q ( z ∣ x ) q(z|x) q(z∣x) 与先验分布 p ( z ) p(z) p(z) 的差异。
L K L = D K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) = ∑ z q ( z ∣ x ) log q ( z ∣ x ) p ( z ) L_{KL} = D_{KL}(q(z|x)||p(z)) = \sum_{z} q(z|x) \log \frac{q(z|x)}{p(z)} LKL=DKL(q(z∣x)∣∣p(z))=z∑q(z∣x)logp(z)q(z∣x)
最终,dVAE 的损失函数是重建损失和 KL 散度损失的加权和:
L
=
L
r
e
c
+
λ
L
K
L
L = L_{rec} + \lambda L_{KL}
L=Lrec+λLKL
其中, λ \lambda λ 是一个超参数,用于平衡重建损失和 KL 散度损失的相对重要性 。在训练过程中,通过最小化这个损失函数,dVAE 可以学习到合适的离散潜在变量表示,以及有效的编码器和解码器参数。
四、与VQVAE对比
4.1 架构
组件 | dVAE | VQ-VAE |
---|---|---|
潜在表示 | 概率分布 | 确定性点 |
量化方式 | Gumbel-Softmax (软量化) | 最近邻搜索 (硬量化) |
梯度传播 | 通过 Gumbel-Softmax | 直通估计 (Straight-Through) |
潜在变量 | 多维离散变量 | 空间网格的离散索引 |
4.2 潜在编码
-
dVAE潜在编码
z ∈ { 0 , 1 } K (one-hot 或 softmax) z \in \{0,1\}^K \quad \text{(one-hot 或 softmax)} z∈{0,1}K(one-hot 或 softmax)- 每个位置独立选择类别
- 支持层次化潜在结构
- 可解释性强(每个维度对应特定概念)
-
VQVAE潜在编码
z ∈ Z H × W (空间索引) z \in \mathbb{Z}^{H \times W} \quad \text{(空间索引)} z∈ZH×W(空间索引)- 空间位置独立量化
- 保持空间结构信息
- 更适合图像数据
4.3 KL 散度对比
-
dVAE:
- 显式计算 KL 散度
- 鼓励后验接近先验(通常均匀分布)
- 提供正则化,防止过拟合
-
VQ-VAE:
- 无显式 KL 项
- 量化过程隐含正则化
- 在训练中保持后验熵恒定
- 第二阶段通过先验模型学习分布
4.4 训练动态对比
特性 | dVAE | VQ-VAE |
---|---|---|
收敛速度 | 较慢(温度退火) | 较快 |
训练稳定性 | 较高(可导) | 中等(依赖 EMA) |
表示灵活性 | 概率分布 | 确定性点 |
超参数敏感度 | 温度参数敏感 | 相对鲁棒 |
端到端训练 | 完全可导 | 需要直通估计 |
五、代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class Codebook(nn.Module):
"""类似 VQ-VAE 的嵌入层"""
def __init__(self, num_embeddings, embedding_dim):
"""
参数:
num_embeddings: 嵌入向量的数量 (类别数)
embedding_dim: 每个嵌入向量的维度
"""
super().__init__()
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.embedding.weight.data.uniform_(-1.0/num_embeddings, 1.0/num_embeddings)
def forward(self, z_hard):
"""
参数:
z_hard: 离散索引 [batch_size, latent_dim]
返回:
z_quantized: 量化后的嵌入向量 [batch_size, latent_dim, embedding_dim]
"""
return self.embedding(z_hard)
class dVAE(nn.Module):
"""带嵌入层的离散变分自编码器"""
def __init__(self, input_dim, latent_dim, num_classes, embedding_dim, hidden_dim=512):
"""
参数:
input_dim: 输入数据维度 (如 784 for MNIST)
latent_dim: 潜在变量数量
num_classes: 每个潜在变量的类别数
embedding_dim: 每个类别的嵌入维度
hidden_dim: 隐藏层维度
"""
super().__init__()
self.latent_dim = latent_dim
self.num_classes = num_classes
self.embedding_dim = embedding_dim
# 编码器
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * num_classes) # 输出 logits
)
# 嵌入层 (codebook)
self.codebook = Codebook(num_classes, embedding_dim)
# 解码器
self.decoder = nn.Sequential(
nn.Linear(latent_dim * embedding_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim),
nn.Sigmoid()
)
def encode(self, x):
"""编码输入并返回 logits"""
batch_size = x.size(0)
logits = self.encoder(x)
return logits.view(batch_size, self.latent_dim, self.num_classes)
def quantize(self, logits, temperature=1.0, hard=False):
"""
量化过程:
1. 应用 Gumbel Softmax 获得离散分布
2. 如果是 hard 模式,转换为 one-hot 索引
3. 通过嵌入层获取量化向量
返回:
z_quantized: 量化后的嵌入向量
z_indices: 离散索引 (hard 模式)
z_soft: Gumbel Softmax 输出 (soft 模式)
"""
batch_size, latent_dim, num_classes = logits.size()
# 展平以便批量处理
flat_logits = logits.view(-1, num_classes)
# 应用 Gumbel Softmax
z_soft = gumbel_softmax(flat_logits, temperature, hard=False)
# 获取离散索引 (用于嵌入查找)
z_indices = torch.argmax(z_soft, dim=-1)
# 如果是 hard 模式,使用直通估计
if hard:
# 创建 one-hot 向量
z_hard = F.one_hot(z_indices, num_classes).float()
# 直通技巧: 前向传播离散,反向传播连续
z_soft = (z_hard - z_soft).detach() + z_soft
# 重塑索引
z_indices = z_indices.view(batch_size, latent_dim)
# 通过嵌入层获取量化向量
z_quantized = self.codebook(z_indices)
# 重塑量化向量: [batch_size, latent_dim * embedding_dim]
z_quantized = z_quantized.view(batch_size, -1)
return z_quantized, z_indices, z_soft
def forward(self, x, temperature=1.0, hard=False):
# 编码
logits = self.encode(x)
# 量化
z_quantized, z_indices, z_soft = self.quantize(logits, temperature, hard)
# 解码重建
recon_x = self.decoder(z_quantized)
return recon_x, logits, z_quantized, z_indices
"""
gumbel softmax
"""
def gumbel_softmax(logits, temperature=1.0, eps=1e-10):
"""
Gumbel-Softmax 函数
返回连续近似和离散索引
"""
# 生成 Gumbel 噪声
uniform = torch.rand_like(logits)
gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps
# 添加噪声并应用带温度的 softmax
y = logits + gumbel_noise
y = F.softmax(y / temperature, dim=-1)
# 获取离散索引
indices = torch.argmax(y, dim=-1)
return y, indices
"""
损失函数
"""
def dvae_loss(recon_x, x, logits, z_quantized, z_soft, codebook, beta=0.25):
"""
dVAE 损失函数:
- 重建损失
- KL 散度
- 嵌入层优化损失 (类似 VQ-VAE)
beta: 嵌入损失权重 (通常 0.1-0.5)
"""
# 重建损失 (均方误差)
recon_loss = F.mse_loss(recon_x, x, reduction='sum')
# KL 散度: 后验 q(z|x) 与先验 p(z) 的 KL 散度
batch_size, latent_dim, num_classes = logits.size()
flat_logits = logits.view(-1, num_classes)
# 假设先验是均匀分布
prior_logits = torch.zeros_like(flat_logits)
q_dist = torch.distributions.Categorical(logits=flat_logits)
p_dist = torch.distributions.Categorical(logits=prior_logits)
kl_div = torch.distributions.kl.kl_divergence(q_dist, p_dist).sum()
# 嵌入层优化损失 (类似 VQ-VAE 的 commitment loss)
# 1. 将嵌入向量视为常数,优化编码器输出
# 2. 将编码器输出视为常数,优化嵌入向量
# 获取解码器输入的量化向量
z_quantized_detached = z_quantized.detach()
# 计算编码器输出与量化向量的距离
# 重塑 z_soft: [batch_size * latent_dim, num_classes]
flat_z_soft = z_soft.view(-1, num_classes)
# 计算编码器输出的嵌入表示
# 使用 z_soft 作为权重,计算加权平均嵌入
embedding_weights = codebook.embedding.weight # [num_classes, embedding_dim]
encoder_embedding = torch.matmul(flat_z_soft, embedding_weights) # [batch*latent, embedding_dim]
encoder_embedding = encoder_embedding.view(batch_size, latent_dim, -1)
# 计算编码器嵌入与量化向量的距离
commitment_loss = F.mse_loss(
encoder_embedding.detach(),
z_quantized
) + F.mse_loss(
encoder_embedding,
z_quantized_detached
)
# 总损失
total_loss = recon_loss + kl_div + beta * commitment_loss
return total_loss, recon_loss, kl_div, commitment_loss
"""
训练过程(包含嵌入优化)
"""
def train_dvae(model, dataloader, epochs=50, lr=1e-3, beta=0.25):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.Adam([
{'params': model.encoder.parameters()},
{'params': model.decoder.parameters()},
{'params': model.codebook.parameters(), 'lr': lr * 10} # 嵌入层更高学习率
], lr=lr)
for epoch in range(epochs):
# 温度退火:从 1.0 线性降到 0.1
temperature = max(0.1, 1.0 - 0.9 * epoch / epochs)
total_loss = 0.0
for x, _ in dataloader:
x = x.to(device).view(x.size(0), -1)
# 前向传播 (训练时使用 soft 模式)
recon_x, logits, z_quantized, z_soft = model(x, temperature=temperature, hard=False)
# 计算损失
loss, recon_loss, kl_div, commit_loss = dvae_loss(
recon_x, x, logits, z_quantized, z_soft, model.codebook, beta
)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}/{epochs} | "
f"Temp: {temperature:.3f} | "
f"Loss: {total_loss/len(dataloader):.4f} | "
f"Recon: {recon_loss.item()/x.size(0):.4f} | "
f"KL: {kl_div.item()/x.size(0):.4f} | "
f"Commit: {commit_loss.item():.4f}")
5.1. 代码细节
在Gumbel Softmax理论中使用 l o g ( p i ) + g u m b e l _ n o i s e log(p_i) + gumbel\_noise log(pi)+gumbel_noise,但在实际实现中我们使用 l o g i t s + g u m b e l _ n o i s e logits + gumbel\_noise logits+gumbel_noise,两者在数学上是等价的,但后者更高效且更稳定。
以下是数学等价性证明
步骤 1:建立关系
令:
- z i = logits 值 z_i = \text{logits 值} zi=logits 值
- p i = softmax ( z i ) = exp ( z i ) ∑ j exp ( z j ) p_i = \text{softmax}(z_i) = \frac{\exp(z_i)}{\sum_j \exp(z_j)} pi=softmax(zi)=∑jexp(zj)exp(zi)
则:
log
(
p
i
)
=
log
(
exp
(
z
i
)
∑
j
exp
(
z
j
)
)
=
z
i
−
log
(
∑
j
exp
(
z
j
)
)
\log(p_i) = \log\left( \frac{\exp(z_i)}{\sum_j \exp(z_j)} \right) = z_i - \log\left( \sum_j \exp(z_j) \right)
log(pi)=log(∑jexp(zj)exp(zi))=zi−log(j∑exp(zj))
令
C
=
log
(
∑
j
exp
(
z
j
)
)
C = \log\left( \sum_j \exp(z_j) \right)
C=log(∑jexp(zj))(对所有
i
i
i 相同的常数),则:
log
(
p
i
)
=
z
i
−
C
\log(p_i) = z_i - C
log(pi)=zi−C
步骤 2:Gumbel-Max 等价性
情况 A:使用
log
(
p
)
\boldsymbol{\log(p)}
log(p)
v
i
A
=
log
(
p
i
)
+
g
i
=
(
z
i
−
C
)
+
g
i
v_i^A = \log(p_i) + g_i = (z_i - C) + g_i
viA=log(pi)+gi=(zi−C)+gi
情况 B:使用
logits
\boldsymbol{\text{logits}}
logits
v
i
B
=
z
i
+
g
i
v_i^B = z_i + g_i
viB=zi+gi
取
argmax
\boldsymbol{\text{argmax}}
argmax:
arg
max
i
v
i
A
=
arg
max
i
(
z
i
−
C
+
g
i
)
=
arg
max
i
(
z
i
+
g
i
)
=
arg
max
i
v
i
B
\arg\max_i v_i^A = \arg\max_i (z_i - C + g_i) = \arg\max_i (z_i + g_i) = \arg\max_i v_i^B
argimaxviA=argimax(zi−C+gi)=argimax(zi+gi)=argimaxviB
因为常数偏移 C C C 不影响 argmax \text{argmax} argmax 结果。
步骤 3:Gumbel-Softmax 等价性
情况 A:使用
log
(
p
)
\boldsymbol{\log(p)}
log(p)
y
i
A
=
exp
(
log
(
p
i
)
+
g
i
τ
)
∑
j
exp
(
log
(
p
j
)
+
g
j
τ
)
=
exp
(
(
z
i
−
C
)
+
g
i
τ
)
∑
j
exp
(
(
z
j
−
C
)
+
g
j
τ
)
y_i^A = \frac{\exp\left( \frac{\log(p_i) + g_i}{\tau} \right)}{\sum_j \exp\left( \frac{\log(p_j) + g_j}{\tau} \right)} = \frac{\exp\left( \frac{(z_i - C) + g_i}{\tau} \right)}{\sum_j \exp\left( \frac{(z_j - C) + g_j}{\tau} \right)}
yiA=∑jexp(τlog(pj)+gj)exp(τlog(pi)+gi)=∑jexp(τ(zj−C)+gj)exp(τ(zi−C)+gi)
情况 B:使用
logits
\boldsymbol{\text{logits}}
logits
y
i
B
=
exp
(
z
i
+
g
i
τ
)
∑
j
exp
(
z
j
+
g
j
τ
)
y_i^B = \frac{\exp\left( \frac{z_i + g_i}{\tau} \right)}{\sum_j \exp\left( \frac{z_j + g_j}{\tau} \right)}
yiB=∑jexp(τzj+gj)exp(τzi+gi)
展开
A
\boldsymbol{A}
A:
y
i
A
=
exp
(
z
i
+
g
i
τ
)
exp
(
−
C
τ
)
∑
j
exp
(
z
j
+
g
j
τ
)
exp
(
−
C
τ
)
=
exp
(
z
i
+
g
i
τ
)
∑
j
exp
(
z
j
+
g
j
τ
)
=
y
i
B
y_i^A = \frac{\exp\left( \frac{z_i + g_i}{\tau} \right) \exp\left( -\frac{C}{\tau} \right)}{\sum_j \exp\left( \frac{z_j + g_j}{\tau} \right) \exp\left( -\frac{C}{\tau} \right)} = \frac{\exp\left( \frac{z_i + g_i}{\tau} \right)}{\sum_j \exp\left( \frac{z_j + g_j}{\tau} \right)} = y_i^B
yiA=∑jexp(τzj+gj)exp(−τC)exp(τzi+gi)exp(−τC)=∑jexp(τzj+gj)exp(τzi+gi)=yiB
常数因子 exp ( − C / τ ) \exp(-C/\tau) exp(−C/τ) 在分子和分母中抵消,因此两者完全相等。