生成模型StackGAN模型详解

1. StackGAN模型详解

论文链接:https://arxiv.org/pdf/1612.03242v1

在这里插入图片描述

StackGAN 是一种用于生成高分辨率图像的生成对抗网络(GAN)模型,尤其在文本到图像生成任务中表现出色。StackGAN基于条件GAN(cGAN)框架,将文本描述作为条件输入,但通过分阶段设计解决了单一Conditional GAN生成高分辨率图像的困难。以下从背景、原理、损失函数推导三个方面详细解析该模型:


一、背景

1. 问题背景

传统GAN在生成高分辨率图像时存在以下问题:

  • 模式崩溃model collapse:生成器倾向于生成单一模式的图像,缺乏多样性。
  • 低分辨率限制:直接生成高分辨率图像会导致训练不稳定,细节难以捕捉。
  • 文本-图像对齐:文本描述与生成图像的内容一致性难以保证。

StackGAN 提出 分阶段生成 策略,将生成过程分解为两个阶段:

  • Stage-I:根据文本描述生成低分辨率图像(64x64或128x128),捕捉基本轮廓和颜色。
  • Stage-II:基于Stage-I的输出和文本描述生成高分辨率图像(256x256或更高),添加细节并提升真实感。

2. 核心思想

  • 逐步细化:通过多阶段生成逐步提升分辨率,避免直接生成高分辨率图像的困难。
  • 条件增强(Conditioning Augmentation):引入随机性增强文本条件的鲁棒性,防止过拟合。

二、原理与架构

在这里插入图片描述

1. 分阶段生成的思想

  • 人类绘画时通常会先勾勒草图,再逐步细化细节。StackGAN借鉴了这一思想,将生成过程分解为两个阶段:
  • Stage-I GAN:生成低分辨率图像(如64x64),捕捉文本描述的全局结构和基本布局。
  • Stage-II GAN:基于Stage-I的输出,生成高分辨率图像(如256x256),补充细节并提升真实感。

1. Stage-I:生成低分辨率图像

  • 输入:文本编码(通过预训练模型如BERT提取) + 噪声向量。
  • 生成器(G₁)
    • 将文本编码通过条件增强映射为潜在变量 s ∼ N ( μ ( c ) , Σ ( c ) ) s \sim \mathcal{N}(\mu(c), \Sigma(c)) sN(μ(c),Σ(c))
    • 结合(concatenate)噪声向量生成低分辨率图像。
  • 判别器(D₁)
    • 判断图像是否真实,同时验证图像与文本的对齐性。

2. Stage-II:生成高分辨率图像

  • 输入:Stage-I的输出图像 + 文本编码(条件增强)。
  • 生成器(G₂)
    • 将文本编码通过条件增强映射为潜在变量 s 1 ∼ N ( μ ( c ) , Σ ( c ) ) s_1 \sim \mathcal{N}(\mu(c), \Sigma(c)) s1N(μ(c),Σ(c)),同Stage-I模块。
    • 结合(concatenate)Stage-I生成的低分辨率图像。
    • 细化低分辨率图像,添加细节(如纹理、阴影)。
    • 通过残差网络结构保留Stage-I的全局信息。
  • 判别器(D₂)
    • 判断图像真实性,并确保与文本描述一致。

3. 条件增强(Conditioning Augmentation)

  • 目的:解决文本-图像数据对的稀疏性问题,增强生成多样性。
  • 方法
    • 对文本嵌入进行随机扰动(通过高斯分布采样),增加生成样本的多样性,避免过拟合
  • 步骤
    • 将文本编码 c c c 映射为高斯分布 N ( μ ( c ) , Σ ( c ) ) \mathcal{N}(\mu(c), \Sigma(c)) N(μ(c),Σ(c))
    • 采样潜在变量 s 2 ∼ N ( μ ( c ) , Σ ( c ) ) s_2 \sim \mathcal{N}(\mu(c), \Sigma(c)) s2N(μ(c),Σ(c)) 作为生成器输入。
    • 通过KL散度约束潜在变量分布接近标准正态分布。
  • KL散度约束的方法类似VAE的latent space规范化

三、损失函数推导

StackGAN的损失函数由两阶段对抗损失和条件增强的KL散度组成。

1. Stage-I 损失函数

  • 对抗损失(条件GAN):
    L D 1 = − E x ∼ p data [ log ⁡ D 1 ( x , c ) ] − E z ∼ p z [ log ⁡ ( 1 − D 1 ( G 1 ( z , s 1 ) , c ) ) ] \mathcal{L}_{D_1} = -\mathbb{E}_{x \sim p_{\text{data}}}[\log D_1(x, c)] - \mathbb{E}_{z \sim p_z}[\log(1 - D_1(G_1(z, s_1), c))] LD1=Expdata[logD1(x,c)]Ezpz[log(1D1(G1(z,s1),c))]
    L G 1 = − E z ∼ p z [ log ⁡ D 1 ( G 1 ( z , s 1 ) , c ) ] \mathcal{L}_{G_1} = -\mathbb{E}_{z \sim p_z}[\log D_1(G_1(z, s_1), c)] LG1=Ezpz[logD1(G1(z,s1),c)]
  • 条件增强的KL散度
    L KL 1 = D KL ( N ( μ ( c ) , Σ ( c ) ) ∥ N ( 0 , I ) ) \mathcal{L}_{\text{KL}_1} = D_{\text{KL}}\left(\mathcal{N}(\mu(c), \Sigma(c)) \Vert \mathcal{N}(0, I)\right) LKL1=DKL(N(μ(c),Σ(c))N(0,I))
  • 总损失
    L Stage-I = L G 1 + λ k l L KL 1 \mathcal{L}_{\text{Stage-I}} = \mathcal{L}_{G_1} + \lambda_{kl} \mathcal{L}_{\text{KL}_1} LStage-I=LG1+λklLKL1
    其中 λ k l \lambda_{kl} λkl 为权重系数。

2. Stage-II 损失函数

  • 对抗损失
    L D 2 = − E x ∼ p data [ log ⁡ D 2 ( x , c ) ] − E z ∼ p z , f a k e I ∼ G 1 [ log ⁡ ( 1 − D 2 ( G 2 ( f a k e I , s 2 ) , c ) ) ] \mathcal{L}_{D_2} = -\mathbb{E}_{x \sim p_{\text{data}}}[\log D_2(x, c)] - \mathbb{E}_{z \sim p_z, fake_I \sim G_1}[\log(1 - D_2(G_2(fake_I, s_2), c))] LD2=Expdata[logD2(x,c)]Ezpz,fakeIG1[log(1D2(G2(fakeI,s2),c))]
    L G 2 = − E z ∼ p z , f a k e I ∼ G 1 [ log ⁡ D 2 ( G 2 ( f a k e I , s 2 ) , c ) ] \mathcal{L}_{G_2} = -\mathbb{E}_{z \sim p_z, fake_I \sim G_1}[\log D_2(G_2(fake_I, s_2), c)] LG2=Ezpz,fakeIG1[logD2(G2(fakeI,s2),c)]
  • 条件增强的KL散度
    L KL 2 = D KL ( N ( μ ( c ) , Σ ( c ) ) ∥ N ( 0 , I ) ) \mathcal{L}_{\text{KL}_2} = D_{\text{KL}}\left(\mathcal{N}(\mu(c), \Sigma(c)) \Vert \mathcal{N}(0, I)\right) LKL2=DKL(N(μ(c),Σ(c))N(0,I))
  • 总损失
    L Stage-II = L G 2 + λ k l L KL 2 \mathcal{L}_{\text{Stage-II}} = \mathcal{L}_{G_2} + \lambda_{kl} \mathcal{L}_{\text{KL}_2} LStage-II=LG2+λklLKL2

3. KL散度推导

条件增强中,潜在变量 s s s 服从分布 N ( μ ( c ) , Σ ( c ) ) \mathcal{N}(\mu(c), \Sigma(c)) N(μ(c),Σ(c)),通过KL散度约束其接近标准正态分布:(推导过程参见VAE博客https://blog.csdn.net/sjtu_wyy/article/details/147063416?spm=1001.2014.3001.5501)
D KL = 1 2 ( tr ( Σ ( c ) ) + ∥ μ ( c ) ∥ 2 − dim ⁡ ( s ) − log ⁡ det ⁡ ( Σ ( c ) ) ) D_{\text{KL}} = \frac{1}{2}\left(\text{tr}(\Sigma(c)) + \|\mu(c)\|^2 - \dim(s) - \log\det(\Sigma(c))\right) DKL=21(tr(Σ(c))+μ(c)2dim(s)logdet(Σ(c)))


四、关键改进与优势

  1. 分阶段生成:降低高分辨率生成的难度。
  2. 残差结构:保留低级特征,提升细节生成能力。
  3. 条件增强:增强鲁棒性和多样性。
  4. 文本-图像对齐:通过联合嵌入空间确保内容一致性。

五、总结

StackGAN通过分阶段生成和条件增强技术,有效解决了高分辨率图像生成的难题,尤其在文本到图像任务中实现了质量和多样性平衡。其核心在于将复杂问题分解为多个简单子问题,并通过对抗训练逐步优化。后续工作如StackGAN++进一步优化了多阶段生成结构。

六、代码

"""
    模型结构定义
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------
# Stage-I 生成器
# ----------------------------
class StageI_Generator(nn.Module):
    def __init__(self, text_dim=256, noise_dim=100, ca_dim=128):
        super().__init__()
        # 条件增强(Conditioning Augmentation)
        self.fc = nn.Linear(text_dim, 2 * ca_dim)
        self.ca_dim = ca_dim
        
        # 主网络:输入噪声 + 条件向量
        self.model = nn.Sequential(
            nn.Linear(noise_dim + ca_dim, 4*4*512),
            nn.BatchNorm1d(4*4*512),
            nn.ReLU(),
            nn.Unflatten(1, (512, 4, 4)),
            # 上采样块(4x4 → 64x64)
            nn.ConvTranspose2d(512, 256, 4, 2, 1),  # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),     # 64x64
            nn.Tanh()
        )

    def forward(self, z, text_embed):
        # 条件增强:文本嵌入 → (μ, Σ) → 采样
        h = self.fc(text_embed)
        mu, logvar = h[:, :self.ca_dim], h[:, self.ca_dim:]
        c = self.reparameterize(mu, torch.exp(0.5*logvar))
        # 拼接噪声和条件
        x = torch.cat([z, c], dim=1)
        return self.model(x), mu, logvar

    def reparameterize(self, mu, sigma):
        eps = torch.randn_like(sigma)
        return mu + eps * sigma

# ----------------------------
# Stage-II 生成器
# ----------------------------
class StageII_Generator(nn.Module):
    def __init__(self, text_dim=256, ca_dim=128):
        super().__init__()
        # 条件增强(复用Stage-I的增强结果)
        self.fc = nn.Linear(text_dim, 2 * ca_dim)
        
        # 残差块网络
        self.res_blocks = nn.Sequential(
            ResidualBlock(512 + ca_dim, 512),  # 输入:Stage-I特征 + 条件
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),      # 64x64 → 128x128
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),      # 128x128 → 256x256
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, stageI_img, text_embed):
        # 条件增强
        h = self.fc(text_embed)
        mu, logvar = h[:, :self.ca_dim], h[:, self.ca_dim:]
        c = self.reparameterize(mu, torch.exp(0.5*logvar))
        # 拼接Stage-I图像特征和条件
        c = c.unsqueeze(2).unsqueeze(3).expand(-1, -1, stageI_img.size(2), stageI_img.size(3))
        x = torch.cat([stageI_img, c], dim=1)
        return self.res_blocks(x), mu, logvar

# ----------------------------
# 残差块(Residual Block)
# ----------------------------
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels)
        )
        self.shortcut = nn.Identity() if in_channels == out_channels else \
                        nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return F.relu(self.conv(x) + self.shortcut(x))

# ----------------------------
# 判别器(Stage-I & Stage-II共用结构)
# ----------------------------
class Discriminator(nn.Module):
    def __init__(self, text_dim=256, img_channels=3):
        super().__init__()
        # 图像特征提取
        self.img_encoder = nn.Sequential(
            nn.Conv2d(img_channels, 64, 4, 2, 1),  # 64x64 → 32x32
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),           # 32x32 → 16x16
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),           # 16x16 → 8x8
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 4, 2, 1),            # 8x8 → 4x4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Flatten()
        )
        # 文本嵌入处理
        self.text_encoder = nn.Sequential(
            nn.Linear(text_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256)
        )
        # 联合判别
        self.joint = nn.Sequential(
            nn.Linear(512*4*4 + 256, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, img, text_embed):
        img_feat = self.img_encoder(img)
        text_feat = self.text_encoder(text_embed)
        joint_feat = torch.cat([img_feat, text_feat], dim=1)
        return self.joint(joint_feat)

"""
     训练流程
"""
# 超参数
text_dim = 256       # 文本嵌入维度(需预训练如BERT)
noise_dim = 100      # 噪声维度
ca_dim = 128         # 条件增强维度
lambda_kl = 2.0      # KL散度权重
lambda_l1 = 50.0     # L1重构损失权重

# 初始化模型
G1 = StageI_Generator(text_dim, noise_dim, ca_dim)
G2 = StageII_Generator(text_dim, ca_dim)
D1 = Discriminator(text_dim, img_channels=3)
D2 = Discriminator(text_dim, img_channels=3)

# 优化器
opt_G1 = optim.Adam(G1.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_G2 = optim.Adam(G2.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_D1 = optim.Adam(D1.parameters(), lr=0.0004, betas=(0.5, 0.999))
opt_D2 = optim.Adam(D2.parameters(), lr=0.0004, betas=(0.5, 0.999))

# 训练循环(伪代码)
for epoch in range(epochs):
    for real_img, text_embed in dataloader:  # 假设text_embed已预处理
        # ----------------------------
        # Stage-I 训练
        # ----------------------------
        # 生成Stage-I图像
        z = torch.randn(batch_size, noise_dim)
        fake_img1, mu1, logvar1 = G1(z, text_embed)
        
        # 更新D1
        opt_D1.zero_grad()
        real_validity1 = D1(real_img, text_embed)
        fake_validity1 = D1(fake_img1.detach(), text_embed)
        d1_loss = -torch.mean(torch.log(real_validity1 + 1e-8) + torch.mean(torch.log(1 - fake_validity1 + 1e-8))
        d1_loss.backward()
        opt_D1.step()
        
        # 更新G1
        opt_G1.zero_grad()
        fake_validity1 = D1(fake_img1, text_embed)
        kl_loss1 = -0.5 * torch.sum(1 + logvar1 - mu1.pow(2) - logvar1.exp())
        g1_loss = -torch.mean(torch.log(fake_validity1 + 1e-8)) + lambda_kl * kl_loss1
        g1_loss.backward()
        opt_G1.step()
        
        # ----------------------------
        # Stage-II 训练
        # ----------------------------
        # 生成Stage-II图像
        fake_img2, mu2, logvar2 = G2(fake_img1.detach(), text_embed)
        
        # 更新D2
        opt_D2.zero_grad()
        real_validity2 = D2(real_img, text_embed)
        fake_validity2 = D2(fake_img2.detach(), text_embed)
        d2_loss = -torch.mean(torch.log(real_validity2 + 1e-8)) + torch.mean(torch.log(1 - fake_validity2 + 1e-8))
        d2_loss.backward()
        opt_D2.step()
        
        # 更新G2
        opt_G2.zero_grad()
        fake_validity2 = D2(fake_img2, text_embed)
        kl_loss2 = -0.5 * torch.sum(1 + logvar2 - mu2.pow(2) - logvar2.exp())
        l1_loss = F.l1_loss(fake_img2, real_img)  # 假设有高分辨率真实图像
        g2_loss = -torch.mean(torch.log(fake_validity2 + 1e-8)) + lambda_kl * kl_loss2 + lambda_l1 * l1_loss
        g2_loss.backward()
        opt_G2.step()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

贝塔西塔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值