离散变分自编码器(dVAE)详解(采用Gumbel Softmax)

离散变分自编码器(dVAE)详解

一、背景

变分自编码器(VAE)是一种强大的生成模型,在许多领域得到广泛应用。但传统 VAE 通常假设潜在空间是连续的,在处理离散数据或需要离散潜在表示时存在局限性。dVAE(Discrete Variational Autoencoder,离散变分自编码器) 应运而生,旨在解决离散潜在变量的学习和生成问题,能够在离散的潜在空间中进行建模,从而更灵活地处理诸如文本、类别标签等离散数据类型,拓展了变分自编码器的应用范围。

二、原理

2.1 离散变分自编码器(dVAE)与Gumbel Softmax

1. dVAE的核心思想

离散变分自编码器(dVAE)是标准VAE的变体,主要区别在于:

  • 潜在空间离散化:dVAE使用离散潜在变量而非连续变量
  • 结构化表示:离散变量能更好地捕捉数据的分类或分层结构
  • 可解释性:离散潜在变量通常具有更明确的语义含义

2. Gumbel Softmax在dVAE中的作用

当潜在变量离散时,传统的重参数化技巧失效,因为:

  • 离散采样不可导,梯度无法反向传播
  • 无法直接优化离散分布的参数

dVAE和VQVAE总体概念相似,主要不同之处在于引入了Gumbel Softmax进行训练,有效避免了VQ-VAE训练中由于ArgMin操作不能求导而产生的问题(Straight Through Estimator直通估计近似)。

2.2 模型架构

编码器: dVAE 的编码器将输入数据映射到离散潜在变量的概率分布上。不同于 VQ-VAE(Vector Quantized-VAE)等模型直接输出确定性的离散索引,dVAE 输出的是离散变量的概率分布。 一般会使用 Gumbel-Softmax 技巧来对离散采样过程进行松弛,使得在训练过程中可以通过梯度下降来优化模型。具体来说,通过引入 Gumbel 噪声,并结合 Softmax 函数,将离散的采样过程转化为可微的操作,从而可以在反向传播时计算梯度,更新网络参数。

解码器: 解码器则根据从编码器得到的离散潜在变量(采样得到),重建出原始输入数据或者生成新的数据。解码器通常是一个神经网络,它将离散潜在变量映射回数据空间,比如在图像生成任务中,将离散潜在变量映射回图像的像素空间。

离散潜在变量: dVAE 中的离散潜在变量可以是多维的,每个维度都对应着不同的离散取值。这些离散变量的组合构成了潜在空间,模型通过学习这些离散变量的概率分布,来捕捉输入数据的特征和结构。

三、损失函数推导

dVAE 的损失函数主要基于变分推断的原理,目标是最大化证据下界(ELBO,Evidence Lower BOund)。对于给定的输入数据 x x x,其损失函数推导如下:

3.1. 变分下界

根据贝叶斯定理,数据 x x x 的对数似然可以表示为:
log ⁡ p ( x ) = log ⁡ p ( x , z ) p ( z ∣ x ) = log ⁡ p ( x , z ) − log ⁡ p ( z ∣ x ) \log p(x) = \log \frac{p(x,z)}{p(z|x)} = \log p(x, z) - \log p(z|x) logp(x)=logp(zx)p(x,z)=logp(x,z)logp(zx)

通过引入变分分布 q ( z ∣ x ) q(z|x) q(zx),可以得到对数似然的一个下界(证据下界):
log ⁡ p ( x ) ≥ E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] − D K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) \log p(x) \geq \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x)||p(z)) logp(x)Eq(zx)[logp(xz)]DKL(q(zx)∣∣p(z))

其中, E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] \mathbb{E}_{q(z|x)}[\log p(x|z)] Eq(zx)[logp(xz)]重建损失项,衡量了在给定潜在变量 z z z 时,模型重建输入数据 x x x 的能力;
D K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) D_{KL}(q(z|x)||p(z)) DKL(q(zx)∣∣p(z))KL 散度项,衡量了变分分布 q ( z ∣ x ) q(z|x) q(zx) 与先验分布 p ( z ) p(z) p(z) 的差异。

3.2. 具体到 dVAE

  • 重建损失:在 dVAE 中,由于使用了离散潜在变量,重建损失通常根据具体任务来定义。例如在图像生成中,可能使用像素空间的均方误差(MSE)或者交叉熵损失。假设 x x x 是输入图像, x ^ \hat{x} x^ 是重建图像,那么重建损失可以表示为:
    L r e c = − E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] = ∑ i = 1 N ℓ ( x i , x ^ i ) L_{rec} = -\mathbb{E}_{q(z|x)}[\log p(x|z)] = \sum_{i=1}^{N} \ell(x_i, \hat{x}_i) Lrec=Eq(zx)[logp(xz)]=i=1N(xi,x^i)
    其中, N N N 是数据集中样本的数量, ℓ ( x i , x ^ i ) \ell(x_i, \hat{x}_i) (xi,x^i) 是针对单个样本 i i i 的损失函数。

  • KL 散度损失:对于离散潜在变量的先验分布 p ( z ) p(z) p(z),通常假设为均匀分布或者其他简单的分布。KL 散度项衡量了编码器输出的分布 q ( z ∣ x ) q(z|x) q(zx) 与先验分布 p ( z ) p(z) p(z) 的差异。
    L K L = D K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) = ∑ z q ( z ∣ x ) log ⁡ q ( z ∣ x ) p ( z ) L_{KL} = D_{KL}(q(z|x)||p(z)) = \sum_{z} q(z|x) \log \frac{q(z|x)}{p(z)} LKL=DKL(q(zx)∣∣p(z))=zq(zx)logp(z)q(zx)

最终,dVAE 的损失函数是重建损失和 KL 散度损失的加权和:
L = L r e c + λ L K L L = L_{rec} + \lambda L_{KL} L=Lrec+λLKL

其中, λ \lambda λ 是一个超参数,用于平衡重建损失和 KL 散度损失的相对重要性 。在训练过程中,通过最小化这个损失函数,dVAE 可以学习到合适的离散潜在变量表示,以及有效的编码器和解码器参数。

四、与VQVAE对比

4.1 架构

组件dVAEVQ-VAE
潜在表示概率分布确定性点
量化方式Gumbel-Softmax (软量化)最近邻搜索 (硬量化)
梯度传播通过 Gumbel-Softmax直通估计 (Straight-Through)
潜在变量多维离散变量空间网格的离散索引

4.2 潜在编码

  • dVAE潜在编码
    z ∈ { 0 , 1 } K (one-hot 或 softmax) z \in \{0,1\}^K \quad \text{(one-hot 或 softmax)} z{0,1}K(one-hot  softmax)

    • 每个位置独立选择类别
    • 支持层次化潜在结构
    • 可解释性强(每个维度对应特定概念)
  • VQVAE潜在编码
    z ∈ Z H × W (空间索引) z \in \mathbb{Z}^{H \times W} \quad \text{(空间索引)} zZH×W(空间索引)

    • 空间位置独立量化
    • 保持空间结构信息
    • 更适合图像数据

4.3 KL 散度对比

  • dVAE:

    • 显式计算 KL 散度
    • 鼓励后验接近先验(通常均匀分布)
    • 提供正则化,防止过拟合
  • VQ-VAE:

    • 无显式 KL 项
    • 量化过程隐含正则化
    • 在训练中保持后验熵恒定
    • 第二阶段通过先验模型学习分布

4.4 训练动态对比

特性dVAEVQ-VAE
收敛速度较慢(温度退火)较快
训练稳定性较高(可导)中等(依赖 EMA)
表示灵活性概率分布确定性点
超参数敏感度温度参数敏感相对鲁棒
端到端训练完全可导需要直通估计

五、代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class Codebook(nn.Module):
    """类似 VQ-VAE 的嵌入层"""
    def __init__(self, num_embeddings, embedding_dim):
        """
        参数:
            num_embeddings: 嵌入向量的数量 (类别数)
            embedding_dim: 每个嵌入向量的维度
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0/num_embeddings, 1.0/num_embeddings)
    
    def forward(self, z_hard):
        """
        参数:
            z_hard: 离散索引 [batch_size, latent_dim]
        返回:
            z_quantized: 量化后的嵌入向量 [batch_size, latent_dim, embedding_dim]
        """
        return self.embedding(z_hard)

class dVAE(nn.Module):
    """带嵌入层的离散变分自编码器"""
    
    def __init__(self, input_dim, latent_dim, num_classes, embedding_dim, hidden_dim=512):
        """
        参数:
            input_dim: 输入数据维度 (如 784 for MNIST)
            latent_dim: 潜在变量数量
            num_classes: 每个潜在变量的类别数
            embedding_dim: 每个类别的嵌入维度
            hidden_dim: 隐藏层维度
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        
        # 编码器
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * num_classes)  # 输出 logits
        )
        
        # 嵌入层 (codebook)
        self.codebook = Codebook(num_classes, embedding_dim)
        
        # 解码器
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim * embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )
    
    def encode(self, x):
        """编码输入并返回 logits"""
        batch_size = x.size(0)
        logits = self.encoder(x)
        return logits.view(batch_size, self.latent_dim, self.num_classes)
    
    def quantize(self, logits, temperature=1.0, hard=False):
        """
        量化过程:
        1. 应用 Gumbel Softmax 获得离散分布
        2. 如果是 hard 模式,转换为 one-hot 索引
        3. 通过嵌入层获取量化向量
        
        返回:
            z_quantized: 量化后的嵌入向量
            z_indices: 离散索引 (hard 模式)
            z_soft: Gumbel Softmax 输出 (soft 模式)
        """
        batch_size, latent_dim, num_classes = logits.size()
        
        # 展平以便批量处理
        flat_logits = logits.view(-1, num_classes)
        
        # 应用 Gumbel Softmax
        z_soft = gumbel_softmax(flat_logits, temperature, hard=False)
        
        # 获取离散索引 (用于嵌入查找)
        z_indices = torch.argmax(z_soft, dim=-1)
        
        # 如果是 hard 模式,使用直通估计
        if hard:
            # 创建 one-hot 向量
            z_hard = F.one_hot(z_indices, num_classes).float()
            # 直通技巧: 前向传播离散,反向传播连续
            z_soft = (z_hard - z_soft).detach() + z_soft
        
        # 重塑索引
        z_indices = z_indices.view(batch_size, latent_dim)
        
        # 通过嵌入层获取量化向量
        z_quantized = self.codebook(z_indices)
        
        # 重塑量化向量: [batch_size, latent_dim * embedding_dim]
        z_quantized = z_quantized.view(batch_size, -1)
        
        return z_quantized, z_indices, z_soft
    
    def forward(self, x, temperature=1.0, hard=False):
        # 编码
        logits = self.encode(x)
        
        # 量化
        z_quantized, z_indices, z_soft = self.quantize(logits, temperature, hard)
        
        # 解码重建
        recon_x = self.decoder(z_quantized)
        
        return recon_x, logits, z_quantized, z_indices

"""
gumbel softmax
"""
def gumbel_softmax(logits, temperature=1.0, eps=1e-10):
    """
    Gumbel-Softmax 函数
    返回连续近似和离散索引
    """
    # 生成 Gumbel 噪声
    uniform = torch.rand_like(logits)
    gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps
    
    # 添加噪声并应用带温度的 softmax
    y = logits + gumbel_noise
    y = F.softmax(y / temperature, dim=-1)
    
    # 获取离散索引
    indices = torch.argmax(y, dim=-1)
    
    return y, indices

"""
损失函数
"""
def dvae_loss(recon_x, x, logits, z_quantized, z_soft, codebook, beta=0.25):
    """
    dVAE 损失函数:
    - 重建损失
    - KL 散度
    - 嵌入层优化损失 (类似 VQ-VAE)
    
    beta: 嵌入损失权重 (通常 0.1-0.5)
    """
    # 重建损失 (均方误差)
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    
    # KL 散度: 后验 q(z|x) 与先验 p(z) 的 KL 散度
    batch_size, latent_dim, num_classes = logits.size()
    flat_logits = logits.view(-1, num_classes)
    
    # 假设先验是均匀分布
    prior_logits = torch.zeros_like(flat_logits)
    
    q_dist = torch.distributions.Categorical(logits=flat_logits)
    p_dist = torch.distributions.Categorical(logits=prior_logits)
    
    kl_div = torch.distributions.kl.kl_divergence(q_dist, p_dist).sum()
    
    # 嵌入层优化损失 (类似 VQ-VAE 的 commitment loss)
    # 1. 将嵌入向量视为常数,优化编码器输出
    # 2. 将编码器输出视为常数,优化嵌入向量
    
    # 获取解码器输入的量化向量
    z_quantized_detached = z_quantized.detach()
    
    # 计算编码器输出与量化向量的距离
    # 重塑 z_soft: [batch_size * latent_dim, num_classes]
    flat_z_soft = z_soft.view(-1, num_classes)
    
    # 计算编码器输出的嵌入表示
    # 使用 z_soft 作为权重,计算加权平均嵌入
    embedding_weights = codebook.embedding.weight  # [num_classes, embedding_dim]
    encoder_embedding = torch.matmul(flat_z_soft, embedding_weights)  # [batch*latent, embedding_dim]
    encoder_embedding = encoder_embedding.view(batch_size, latent_dim, -1)
    
    # 计算编码器嵌入与量化向量的距离
    commitment_loss = F.mse_loss(
        encoder_embedding.detach(), 
        z_quantized
    ) + F.mse_loss(
        encoder_embedding,
        z_quantized_detached
    )
    
    # 总损失
    total_loss = recon_loss + kl_div + beta * commitment_loss
    
    return total_loss, recon_loss, kl_div, commitment_loss

"""
训练过程(包含嵌入优化)
"""
def train_dvae(model, dataloader, epochs=50, lr=1e-3, beta=0.25):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam([
        {'params': model.encoder.parameters()},
        {'params': model.decoder.parameters()},
        {'params': model.codebook.parameters(), 'lr': lr * 10}  # 嵌入层更高学习率
    ], lr=lr)
    
    for epoch in range(epochs):
        # 温度退火:从 1.0 线性降到 0.1
        temperature = max(0.1, 1.0 - 0.9 * epoch / epochs)
        
        total_loss = 0.0
        for x, _ in dataloader:
            x = x.to(device).view(x.size(0), -1)
            
            # 前向传播 (训练时使用 soft 模式)
            recon_x, logits, z_quantized, z_soft = model(x, temperature=temperature, hard=False)
            
            # 计算损失
            loss, recon_loss, kl_div, commit_loss = dvae_loss(
                recon_x, x, logits, z_quantized, z_soft, model.codebook, beta
            )
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{epochs} | "
              f"Temp: {temperature:.3f} | "
              f"Loss: {total_loss/len(dataloader):.4f} | "
              f"Recon: {recon_loss.item()/x.size(0):.4f} | "
              f"KL: {kl_div.item()/x.size(0):.4f} | "
              f"Commit: {commit_loss.item():.4f}")

5.1. 代码细节

在Gumbel Softmax理论中使用 l o g ( p i ) + g u m b e l _ n o i s e log(p_i) + gumbel\_noise log(pi)+gumbel_noise,但在实际实现中我们使用 l o g i t s + g u m b e l _ n o i s e logits + gumbel\_noise logits+gumbel_noise,两者在数学上是等价的,但后者更高效且更稳定。

以下是数学等价性证明

步骤 1:建立关系

令:

  • z i = logits 值 z_i = \text{logits 值} zi=logits 
  • p i = softmax ( z i ) = exp ⁡ ( z i ) ∑ j exp ⁡ ( z j ) p_i = \text{softmax}(z_i) = \frac{\exp(z_i)}{\sum_j \exp(z_j)} pi=softmax(zi)=jexp(zj)exp(zi)

则:
log ⁡ ( p i ) = log ⁡ ( exp ⁡ ( z i ) ∑ j exp ⁡ ( z j ) ) = z i − log ⁡ ( ∑ j exp ⁡ ( z j ) ) \log(p_i) = \log\left( \frac{\exp(z_i)}{\sum_j \exp(z_j)} \right) = z_i - \log\left( \sum_j \exp(z_j) \right) log(pi)=log(jexp(zj)exp(zi))=zilog(jexp(zj))

C = log ⁡ ( ∑ j exp ⁡ ( z j ) ) C = \log\left( \sum_j \exp(z_j) \right) C=log(jexp(zj))(对所有 i i i 相同的常数),则:
log ⁡ ( p i ) = z i − C \log(p_i) = z_i - C log(pi)=ziC

步骤 2:Gumbel-Max 等价性

情况 A:使用 log ⁡ ( p ) \boldsymbol{\log(p)} log(p)
v i A = log ⁡ ( p i ) + g i = ( z i − C ) + g i v_i^A = \log(p_i) + g_i = (z_i - C) + g_i viA=log(pi)+gi=(ziC)+gi

情况 B:使用 logits \boldsymbol{\text{logits}} logits
v i B = z i + g i v_i^B = z_i + g_i viB=zi+gi

argmax \boldsymbol{\text{argmax}} argmax
arg ⁡ max ⁡ i v i A = arg ⁡ max ⁡ i ( z i − C + g i ) = arg ⁡ max ⁡ i ( z i + g i ) = arg ⁡ max ⁡ i v i B \arg\max_i v_i^A = \arg\max_i (z_i - C + g_i) = \arg\max_i (z_i + g_i) = \arg\max_i v_i^B argimaxviA=argimax(ziC+gi)=argimax(zi+gi)=argimaxviB

因为常数偏移 C C C 不影响 argmax \text{argmax} argmax 结果。

步骤 3:Gumbel-Softmax 等价性

情况 A:使用 log ⁡ ( p ) \boldsymbol{\log(p)} log(p)
y i A = exp ⁡ ( log ⁡ ( p i ) + g i τ ) ∑ j exp ⁡ ( log ⁡ ( p j ) + g j τ ) = exp ⁡ ( ( z i − C ) + g i τ ) ∑ j exp ⁡ ( ( z j − C ) + g j τ ) y_i^A = \frac{\exp\left( \frac{\log(p_i) + g_i}{\tau} \right)}{\sum_j \exp\left( \frac{\log(p_j) + g_j}{\tau} \right)} = \frac{\exp\left( \frac{(z_i - C) + g_i}{\tau} \right)}{\sum_j \exp\left( \frac{(z_j - C) + g_j}{\tau} \right)} yiA=jexp(τlog(pj)+gj)exp(τlog(pi)+gi)=jexp(τ(zjC)+gj)exp(τ(ziC)+gi)

情况 B:使用 logits \boldsymbol{\text{logits}} logits
y i B = exp ⁡ ( z i + g i τ ) ∑ j exp ⁡ ( z j + g j τ ) y_i^B = \frac{\exp\left( \frac{z_i + g_i}{\tau} \right)}{\sum_j \exp\left( \frac{z_j + g_j}{\tau} \right)} yiB=jexp(τzj+gj)exp(τzi+gi)

展开 A \boldsymbol{A} A
y i A = exp ⁡ ( z i + g i τ ) exp ⁡ ( − C τ ) ∑ j exp ⁡ ( z j + g j τ ) exp ⁡ ( − C τ ) = exp ⁡ ( z i + g i τ ) ∑ j exp ⁡ ( z j + g j τ ) = y i B y_i^A = \frac{\exp\left( \frac{z_i + g_i}{\tau} \right) \exp\left( -\frac{C}{\tau} \right)}{\sum_j \exp\left( \frac{z_j + g_j}{\tau} \right) \exp\left( -\frac{C}{\tau} \right)} = \frac{\exp\left( \frac{z_i + g_i}{\tau} \right)}{\sum_j \exp\left( \frac{z_j + g_j}{\tau} \right)} = y_i^B yiA=jexp(τzj+gj)exp(τC)exp(τzi+gi)exp(τC)=jexp(τzj+gj)exp(τzi+gi)=yiB

常数因子 exp ⁡ ( − C / τ ) \exp(-C/\tau) exp(C/τ) 在分子和分母中抵消,因此两者完全相等。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

贝塔西塔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值