Factor VAE模型详解

Factor VAE模型详解

论文链接:https://arxiv.org/abs/1802.05983

FactorVAE是一种旨在提升潜在变量解耦能力的生成模型,通过显式地减少潜在变量间的依赖性来实现。以下从背景、原理、损失函数推导及与Beta-VAE的对比进行详细说明。
在这里插入图片描述


一、 FactorVAE模型背景

传统变分自编码器(VAE)通过最大化证据下界(ELBO)学习数据的潜在表示,但其潜在变量可能存在纠缠(entanglement),即单个变量控制多个特征。Beta-VAE 通过增大 KL 散度权重 β \beta β间接惩罚总相关性(TC)以实现解耦,但这种方式存在两个主要问题:

  • 信息丢失:过大的 β \beta β不仅抑制 TC,还会压缩索引码互信息 I q ( z ; x ) I_q(z;x) Iq(z;x),导致潜在变量丢失重要数据信息。
    ‌- 解耦效率低‌:Beta-VAE 隐式控制 TC,无法直接优化变量间的独立性,解耦效果受限。
    ‌FactorVAE 的改进‌:
  • FactorVAE ‌显式地惩罚总相关性(TC)‌,在保持重构能力的同时,更高效地促进潜在变量独立,解决了 Beta-VAE 的信息丢失问题。

二、 FactorVAE原理

2.1. 总相关(TC)的定义

TC是多变量间依赖性的度量,定义为联合分布与边缘分布乘积的KL散度:
TC ( q ( z ) ) = K L ( q ( z ) ∥ ∏ i q ( z i ) ) \text{TC}(q(z)) = KL\left( q(z) \parallel \prod_{i} q(z_i) \right) TC(q(z))=KL(q(z)iq(zi))
最小化TC可促使潜在变量相互独立。

2.2. KL散度的分解

VAE的KL散度项可分解为:
K L ( q ( z ∣ x ) ∥ p ( z ) ) = TC ( q ( z ) ) ⏟ 变量间依赖 + ∑ i K L ( q ( z i ) ∥ p ( z i ) ) ⏟ 维度独立KL散度 + I q ( z ; x ) ⏟ 索引码互信息 KL(q(z|x)\parallel p(z)) = \underbrace{\text{TC}(q(z))}_{\text{变量间依赖}} + \underbrace{\sum_i KL(q(z_i)\parallel p(z_i))}_{\text{维度独立KL散度}} + \underbrace{I_q(z;x)}_{索引码互信息} KL(q(zx)p(z))=变量间依赖 TC(q(z))+维度独立KL散度 iKL(q(zi)p(zi))+索引码互信息 Iq(z;x)
FactorVAE显式惩罚TC项,而Beta-VAE整体缩放整个KL项。


三、 损失函数推导

FactorVAE的损失函数在ELBO基础上加入TC正则项:
L FactorVAE = E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] − ∑ i K L ( q ( z i ∣ x ) ∥ p ( z i ) ) − γ ⋅ TC ( q ( z ) ) \mathcal{L}_{\text{FactorVAE}} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \sum_i KL(q(z_i|x)\parallel p(z_i)) - \gamma \cdot \text{TC}(q(z)) LFactorVAE=Eq(zx)[logp(xz)]iKL(q(zix)p(zi))γTC(q(z))
其中:

  • 第一项为重构损失,保证生成质量。
  • 第二项确保各潜在变量与先验分布(如标准正态)匹配。
  • 第三项(γ > 0)惩罚变量间依赖性,γ控制解耦强度。

3.1. TC的估计

直接计算TC困难,FactorVAE通过密度比估计Density Ratio Estimation来近似估计:
引入判别器 D ( z ) D(z) D(z)区分联合样本 z ∼ q ( z ) z\sim q(z) zq(z)和独立样本 z ∼ Π i = 1 d q ( z i ) z\sim \Pi_{i=1}^dq(z_i) zΠi=1dq(zi),通过对抗训练估计密度比,当判别器 D ( z ) D(z) D(z)最优时,判别器的输出等于 q ( z ) Π d = 1 d q ( z i ) \frac{q(z)}{\Pi_{d=1}^dq(z_i)} Πd=1dq(zi)q(z)
T C = E q ( z ) log ⁡ q ( z ) Π d = 1 d q ( z i ) = E q ( z ) log ⁡ D ( z ) \begin{aligned} TC&=\mathbb{E}_q(z)\log \frac{q(z)}{\Pi_{d=1}^dq(z_i)}\\ &=\mathbb{E}_{q(z)}\log D(z) \end{aligned} TC=Eq(z)logΠd=1dq(zi)q(z)=Eq(z)logD(z)

  1. 采样策略:从联合分布 q ( z ) q(z) q(z)采样真实样本 z z z,并通过打乱各维度生成近似独立样本 z ~ ∼ ∏ i q ( z i ) \tilde{z} \sim \prod_i q(z_i) z~iq(zi)
  2. 训练鉴别器 D D D:区分真实样本 z z z与打乱样本 z ~ \tilde{z} z~,损失函数为:
    L D = E z ∼ q ( z ) [ log ⁡ D ( z ) ] + E z ~ ∼ ∏ q ( z i ) [ log ⁡ ( 1 − D ( z ~ ) ) ] \mathcal{L}_D = \mathbb{E}_{z \sim q(z)}[\log D(z)] + \mathbb{E}_{\tilde{z} \sim \prod q(z_i)}[\log (1-D(\tilde{z}))] LD=Ezq(z)[logD(z)]+Ez~q(zi)[log(1D(z~))]
  3. 估计TC:利用鉴别器输出计算:
    TC ≈ E z ∼ q ( z ) [ log ⁡ D ( z ) 1 − D ( z ) ] \text{TC} \approx \mathbb{E}_{z \sim q(z)}\left[\log \frac{D(z)}{1-D(z)}\right] TCEzq(z)[log1D(z)D(z)]

四、 FactorVAE vs. Beta-VAE

4.1 区别

特性Beta-VAEFactorVAE
正则化目标整体缩放KL散度( β ⋅ K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) ) \beta·KL(q(z|x)||p(z))) βKL(q(zx)∣∣p(z))显式惩罚TC项( γ ⋅ T C ( q ( z ) \gamma\cdot TC(q(z) γTC(q(z)))
解耦控制间接,隐式通过增大 β \beta β间接影响 TC,可能影响各变量与先验的匹配显式优化TC作为独立正则项,直接针对变量间依赖,保留各变量独立性
重构质量 β \beta β值易导致重构模糊更优平衡,解耦与重构质量兼顾
实现复杂度简单,仅需调整 β \beta β复杂,需额外训练判别器估计 TC

4.2 FactorVAE优势

  1. 解耦效果更强:显式优化TC,直接减少变量间依赖,在解耦指标(如MIG)上表现更优。
  2. 灵活调控 γ \gamma γ独立控制TC项,不影响各变量与先验的匹配(如Beta-VAE中 β \beta β增大会强制所有变量趋近先验)。
  3. 生成质量保留:避免过度约束KL散度,重构图像更清晰。

五、 总结

FactorVAE通过分解KL散度并显式惩罚总相关,在解耦表示学习中实现了更精细的控制。相比Beta-VAE,其优势在于直接针对变量间依赖性进行优化,避免了全局KL缩放带来的信息损失,从而在解耦能力和生成质量间达到更优平衡。

六、代码

import torch
import torch.nn as nn
import torch.nn.functional as F

# 超参数
latent_dim = 64     # 潜在空间维度
gamma = 10.0        # TC 正则项权重
alpha = 1.0         # KL 散度权重

# 编码器(同 Beta-VAE)
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),  # 64x64 → 32x32
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1), # 32x32 → 16x16
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),# 16x16 → 8x8
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1), # 8x8 → 4x4
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(256*4*4, latent_dim)
        self.fc_logvar = nn.Linear(256*4*4, latent_dim)

    def forward(self, x):
        h = self.conv(x)
        return self.fc_mu(h), self.fc_logvar(h)

# 解码器(同 Beta-VAE)
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256*4*4)
        self.conv = nn.Sequential(
            nn.Unflatten(1, (256, 4, 4)),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 4x4 → 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 8x8 → 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # 16x16 → 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),     # 32x32 → 64x64
            nn.Tanh()
        )

    def forward(self, z):
        h = self.fc(z)
        return self.conv(h)

# TC 判别器(估计密度比)
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, z):
        return self.model(z)

# FactorVAE 主模型
class FactorVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.discriminator = Discriminator()

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar

    def loss_function(self, x, x_recon, mu, logvar, z):
        # 重构损失
        recon_loss = F.mse_loss(x_recon, x, reduction='sum')

        # KL 散度(各维度独立)
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # 估计 TC(通过判别器)
        log_qz = self.discriminator(z)
        permuted_z = z[:, torch.randperm(z.size(1))]  # 打乱维度,近似 q(z_i)
        log_qz_product = self.discriminator(permuted_z)
        tc_loss = (log_qz - log_qz_product).mean()

        # 总损失
        total_loss = recon_loss + alpha * kl_loss + gamma * tc_loss
        return total_loss, recon_loss, kl_loss, tc_loss

# 训练示例
if __name__ == "__main__":
    model = FactorVAE()
    optimizer_G = torch.optim.Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=1e-4)
    optimizer_D = torch.optim.Adam(model.discriminator.parameters(), lr=1e-4)

    # 假设 dataloader 已定义
    for epoch in range(100):
        for x, _ in dataloader:
            # 训练生成器(编码器+解码器)
            x_recon, mu, logvar = model(x)
            z = model.reparameterize(mu, logvar)
            total_loss, recon_loss, kl_loss, tc_loss = model.loss_function(x, x_recon, mu, logvar, z)
            
            optimizer_G.zero_grad()
            total_loss.backward(retain_graph=True) # 保留计算图
            optimizer_G.step()

            # 训练判别器
            with torch.no_grad():
                permuted_z = z[:, torch.randperm(z.size(1))]
            
            log_qz = model.discriminator(z)
            log_qz_product = model.discriminator(permuted_z)
            d_loss = 0.5 * (F.binary_cross_entropy_with_logits(log_qz, torch.ones_like(log_qz)) +
                          F.binary_cross_entropy_with_logits(log_qz_product, torch.zeros_like(log_qz_product)))

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

代码实现细节

  • log_qz:判别器对原始潜在变量 z z z(来自联合分布 q ( z ) q(z) q(z))的输出。
  • log_qz_product:判别器对维度打乱后的潜在变量 z p e r m u t e d z_{permuted} zpermuted(近似来自乘积分布 Π i = 1 d q ( z i ) \Pi_{i=1}^d q(z_i) Πi=1dq(zi))的输出。
  • TC Loss:通过以下公式近似:
    T C L o s s = E q ( z ) log ⁡ q ( z ) Π i = 1 d q ( z i ) ≈ ( l o g _ q z − l o g _ q z _ p r o d u c t ) . m e a n ( ) \begin{aligned} TC Loss&=\mathbb{E}_{q(z)}\log \frac{q(z)}{\Pi_{i=1}^d q(z_i)}\\ &\approx(log\_qz-log\_qz\_product).mean() \end{aligned} TCLoss=Eq(z)logΠi=1dq(zi)q(z)(log_qzlog_qz_product).mean()
  • 在PyTorch中,二元交叉熵损失函数 F.binary_cross_entropy_with_logits 内部已包含Sigmoid激活,因此
    • 判别器输出无需额外激活:直接输出logits(未归一化的分数),函数内部会自动应用Sigmoid。

    • 数值稳定性:直接使用logits计算损失比手动应用Sigmoid更稳定(避免数值溢出)。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

贝塔西塔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值