1. StackGAN模型详解
论文链接:https://arxiv.org/pdf/1612.03242v1
StackGAN 是一种用于生成高分辨率图像的生成对抗网络(GAN)模型,尤其在文本到图像生成任务中表现出色。StackGAN基于条件GAN(cGAN)框架,将文本描述作为条件输入,但通过分阶段设计解决了单一Conditional GAN生成高分辨率图像的困难。以下从背景、原理、损失函数推导三个方面详细解析该模型:
一、背景
1. 问题背景
传统GAN在生成高分辨率图像时存在以下问题:
- 模式崩溃model collapse:生成器倾向于生成单一模式的图像,缺乏多样性。
- 低分辨率限制:直接生成高分辨率图像会导致训练不稳定,细节难以捕捉。
- 文本-图像对齐:文本描述与生成图像的内容一致性难以保证。
StackGAN 提出 分阶段生成 策略,将生成过程分解为两个阶段:
- Stage-I:根据文本描述生成低分辨率图像(64x64或128x128),捕捉基本轮廓和颜色。
- Stage-II:基于Stage-I的输出和文本描述生成高分辨率图像(256x256或更高),添加细节并提升真实感。
2. 核心思想
- 逐步细化:通过多阶段生成逐步提升分辨率,避免直接生成高分辨率图像的困难。
- 条件增强(Conditioning Augmentation):引入随机性增强文本条件的鲁棒性,防止过拟合。
二、原理与架构
1. 分阶段生成的思想:
- 人类绘画时通常会先勾勒草图,再逐步细化细节。StackGAN借鉴了这一思想,将生成过程分解为两个阶段:
- Stage-I GAN:生成低分辨率图像(如64x64),捕捉文本描述的全局结构和基本布局。
- Stage-II GAN:基于Stage-I的输出,生成高分辨率图像(如256x256),补充细节并提升真实感。
1. Stage-I:生成低分辨率图像
- 输入:文本编码(通过预训练模型如BERT提取) + 噪声向量。
- 生成器(G₁):
- 将文本编码通过条件增强映射为潜在变量 s ∼ N ( μ ( c ) , Σ ( c ) ) s \sim \mathcal{N}(\mu(c), \Sigma(c)) s∼N(μ(c),Σ(c))。
- 结合(concatenate)噪声向量生成低分辨率图像。
- 判别器(D₁):
- 判断图像是否真实,同时验证图像与文本的对齐性。
2. Stage-II:生成高分辨率图像
- 输入:Stage-I的输出图像 + 文本编码(条件增强)。
- 生成器(G₂):
- 将文本编码通过条件增强映射为潜在变量 s 1 ∼ N ( μ ( c ) , Σ ( c ) ) s_1 \sim \mathcal{N}(\mu(c), \Sigma(c)) s1∼N(μ(c),Σ(c)),同Stage-I模块。
- 结合(concatenate)Stage-I生成的低分辨率图像。
- 细化低分辨率图像,添加细节(如纹理、阴影)。
- 通过残差网络结构保留Stage-I的全局信息。
- 判别器(D₂):
- 判断图像真实性,并确保与文本描述一致。
3. 条件增强(Conditioning Augmentation)
- 目的:解决文本-图像数据对的稀疏性问题,增强生成多样性。
- 方法:
- 对文本嵌入进行随机扰动(通过高斯分布采样),增加生成样本的多样性,避免过拟合
- 步骤
- 将文本编码 c c c 映射为高斯分布 N ( μ ( c ) , Σ ( c ) ) \mathcal{N}(\mu(c), \Sigma(c)) N(μ(c),Σ(c))。
- 采样潜在变量 s 2 ∼ N ( μ ( c ) , Σ ( c ) ) s_2 \sim \mathcal{N}(\mu(c), \Sigma(c)) s2∼N(μ(c),Σ(c)) 作为生成器输入。
- 通过KL散度约束潜在变量分布接近标准正态分布。
- KL散度约束的方法类似VAE的latent space规范化
三、损失函数推导
StackGAN的损失函数由两阶段对抗损失和条件增强的KL散度组成。
1. Stage-I 损失函数
- 对抗损失(条件GAN):
L D 1 = − E x ∼ p data [ log D 1 ( x , c ) ] − E z ∼ p z [ log ( 1 − D 1 ( G 1 ( z , s 1 ) , c ) ) ] \mathcal{L}_{D_1} = -\mathbb{E}_{x \sim p_{\text{data}}}[\log D_1(x, c)] - \mathbb{E}_{z \sim p_z}[\log(1 - D_1(G_1(z, s_1), c))] LD1=−Ex∼pdata[logD1(x,c)]−Ez∼pz[log(1−D1(G1(z,s1),c))]
L G 1 = − E z ∼ p z [ log D 1 ( G 1 ( z , s 1 ) , c ) ] \mathcal{L}_{G_1} = -\mathbb{E}_{z \sim p_z}[\log D_1(G_1(z, s_1), c)] LG1=−Ez∼pz[logD1(G1(z,s1),c)] - 条件增强的KL散度:
L KL 1 = D KL ( N ( μ ( c ) , Σ ( c ) ) ∥ N ( 0 , I ) ) \mathcal{L}_{\text{KL}_1} = D_{\text{KL}}\left(\mathcal{N}(\mu(c), \Sigma(c)) \Vert \mathcal{N}(0, I)\right) LKL1=DKL(N(μ(c),Σ(c))∥N(0,I)) - 总损失:
L Stage-I = L G 1 + λ k l L KL 1 \mathcal{L}_{\text{Stage-I}} = \mathcal{L}_{G_1} + \lambda_{kl} \mathcal{L}_{\text{KL}_1} LStage-I=LG1+λklLKL1
其中 λ k l \lambda_{kl} λkl 为权重系数。
2. Stage-II 损失函数
- 对抗损失:
L D 2 = − E x ∼ p data [ log D 2 ( x , c ) ] − E z ∼ p z , f a k e I ∼ G 1 [ log ( 1 − D 2 ( G 2 ( f a k e I , s 2 ) , c ) ) ] \mathcal{L}_{D_2} = -\mathbb{E}_{x \sim p_{\text{data}}}[\log D_2(x, c)] - \mathbb{E}_{z \sim p_z, fake_I \sim G_1}[\log(1 - D_2(G_2(fake_I, s_2), c))] LD2=−Ex∼pdata[logD2(x,c)]−Ez∼pz,fakeI∼G1[log(1−D2(G2(fakeI,s2),c))]
L G 2 = − E z ∼ p z , f a k e I ∼ G 1 [ log D 2 ( G 2 ( f a k e I , s 2 ) , c ) ] \mathcal{L}_{G_2} = -\mathbb{E}_{z \sim p_z, fake_I \sim G_1}[\log D_2(G_2(fake_I, s_2), c)] LG2=−Ez∼pz,fakeI∼G1[logD2(G2(fakeI,s2),c)] - 条件增强的KL散度:
L KL 2 = D KL ( N ( μ ( c ) , Σ ( c ) ) ∥ N ( 0 , I ) ) \mathcal{L}_{\text{KL}_2} = D_{\text{KL}}\left(\mathcal{N}(\mu(c), \Sigma(c)) \Vert \mathcal{N}(0, I)\right) LKL2=DKL(N(μ(c),Σ(c))∥N(0,I)) - 总损失:
L Stage-II = L G 2 + λ k l L KL 2 \mathcal{L}_{\text{Stage-II}} = \mathcal{L}_{G_2} + \lambda_{kl} \mathcal{L}_{\text{KL}_2} LStage-II=LG2+λklLKL2
3. KL散度推导
条件增强中,潜在变量
s
s
s 服从分布
N
(
μ
(
c
)
,
Σ
(
c
)
)
\mathcal{N}(\mu(c), \Sigma(c))
N(μ(c),Σ(c)),通过KL散度约束其接近标准正态分布:(推导过程参见VAE博客https://blog.csdn.net/sjtu_wyy/article/details/147063416?spm=1001.2014.3001.5501)
D
KL
=
1
2
(
tr
(
Σ
(
c
)
)
+
∥
μ
(
c
)
∥
2
−
dim
(
s
)
−
log
det
(
Σ
(
c
)
)
)
D_{\text{KL}} = \frac{1}{2}\left(\text{tr}(\Sigma(c)) + \|\mu(c)\|^2 - \dim(s) - \log\det(\Sigma(c))\right)
DKL=21(tr(Σ(c))+∥μ(c)∥2−dim(s)−logdet(Σ(c)))
四、关键改进与优势
- 分阶段生成:降低高分辨率生成的难度。
- 残差结构:保留低级特征,提升细节生成能力。
- 条件增强:增强鲁棒性和多样性。
- 文本-图像对齐:通过联合嵌入空间确保内容一致性。
五、总结
StackGAN通过分阶段生成和条件增强技术,有效解决了高分辨率图像生成的难题,尤其在文本到图像任务中实现了质量和多样性平衡。其核心在于将复杂问题分解为多个简单子问题,并通过对抗训练逐步优化。后续工作如StackGAN++进一步优化了多阶段生成结构。
六、代码
"""
模型结构定义
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# ----------------------------
# Stage-I 生成器
# ----------------------------
class StageI_Generator(nn.Module):
def __init__(self, text_dim=256, noise_dim=100, ca_dim=128):
super().__init__()
# 条件增强(Conditioning Augmentation)
self.fc = nn.Linear(text_dim, 2 * ca_dim)
self.ca_dim = ca_dim
# 主网络:输入噪声 + 条件向量
self.model = nn.Sequential(
nn.Linear(noise_dim + ca_dim, 4*4*512),
nn.BatchNorm1d(4*4*512),
nn.ReLU(),
nn.Unflatten(1, (512, 4, 4)),
# 上采样块(4x4 → 64x64)
nn.ConvTranspose2d(512, 256, 4, 2, 1), # 8x8
nn.BatchNorm2d(256),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, 4, 2, 1), # 16x16
nn.BatchNorm2d(128),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, 2, 1), # 32x32
nn.BatchNorm2d(64),
nn.ReLU(),
nn.ConvTranspose2d(64, 3, 4, 2, 1), # 64x64
nn.Tanh()
)
def forward(self, z, text_embed):
# 条件增强:文本嵌入 → (μ, Σ) → 采样
h = self.fc(text_embed)
mu, logvar = h[:, :self.ca_dim], h[:, self.ca_dim:]
c = self.reparameterize(mu, torch.exp(0.5*logvar))
# 拼接噪声和条件
x = torch.cat([z, c], dim=1)
return self.model(x), mu, logvar
def reparameterize(self, mu, sigma):
eps = torch.randn_like(sigma)
return mu + eps * sigma
# ----------------------------
# Stage-II 生成器
# ----------------------------
class StageII_Generator(nn.Module):
def __init__(self, text_dim=256, ca_dim=128):
super().__init__()
# 条件增强(复用Stage-I的增强结果)
self.fc = nn.Linear(text_dim, 2 * ca_dim)
# 残差块网络
self.res_blocks = nn.Sequential(
ResidualBlock(512 + ca_dim, 512), # 输入:Stage-I特征 + 条件
ResidualBlock(512, 512),
ResidualBlock(512, 512),
nn.Conv2d(512, 256, 3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Upsample(scale_factor=2), # 64x64 → 128x128
nn.Conv2d(256, 128, 3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Upsample(scale_factor=2), # 128x128 → 256x256
nn.Conv2d(128, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 3, 3, padding=1),
nn.Tanh()
)
def forward(self, stageI_img, text_embed):
# 条件增强
h = self.fc(text_embed)
mu, logvar = h[:, :self.ca_dim], h[:, self.ca_dim:]
c = self.reparameterize(mu, torch.exp(0.5*logvar))
# 拼接Stage-I图像特征和条件
c = c.unsqueeze(2).unsqueeze(3).expand(-1, -1, stageI_img.size(2), stageI_img.size(3))
x = torch.cat([stageI_img, c], dim=1)
return self.res_blocks(x), mu, logvar
# ----------------------------
# 残差块(Residual Block)
# ----------------------------
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels)
)
self.shortcut = nn.Identity() if in_channels == out_channels else \
nn.Conv2d(in_channels, out_channels, 1)
def forward(self, x):
return F.relu(self.conv(x) + self.shortcut(x))
# ----------------------------
# 判别器(Stage-I & Stage-II共用结构)
# ----------------------------
class Discriminator(nn.Module):
def __init__(self, text_dim=256, img_channels=3):
super().__init__()
# 图像特征提取
self.img_encoder = nn.Sequential(
nn.Conv2d(img_channels, 64, 4, 2, 1), # 64x64 → 32x32
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, 4, 2, 1), # 32x32 → 16x16
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 256, 4, 2, 1), # 16x16 → 8x8
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),
nn.Conv2d(256, 512, 4, 2, 1), # 8x8 → 4x4
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2),
nn.Flatten()
)
# 文本嵌入处理
self.text_encoder = nn.Sequential(
nn.Linear(text_dim, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 256)
)
# 联合判别
self.joint = nn.Sequential(
nn.Linear(512*4*4 + 256, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 1),
nn.Sigmoid()
)
def forward(self, img, text_embed):
img_feat = self.img_encoder(img)
text_feat = self.text_encoder(text_embed)
joint_feat = torch.cat([img_feat, text_feat], dim=1)
return self.joint(joint_feat)
"""
训练流程
"""
# 超参数
text_dim = 256 # 文本嵌入维度(需预训练如BERT)
noise_dim = 100 # 噪声维度
ca_dim = 128 # 条件增强维度
lambda_kl = 2.0 # KL散度权重
lambda_l1 = 50.0 # L1重构损失权重
# 初始化模型
G1 = StageI_Generator(text_dim, noise_dim, ca_dim)
G2 = StageII_Generator(text_dim, ca_dim)
D1 = Discriminator(text_dim, img_channels=3)
D2 = Discriminator(text_dim, img_channels=3)
# 优化器
opt_G1 = optim.Adam(G1.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_G2 = optim.Adam(G2.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_D1 = optim.Adam(D1.parameters(), lr=0.0004, betas=(0.5, 0.999))
opt_D2 = optim.Adam(D2.parameters(), lr=0.0004, betas=(0.5, 0.999))
# 训练循环(伪代码)
for epoch in range(epochs):
for real_img, text_embed in dataloader: # 假设text_embed已预处理
# ----------------------------
# Stage-I 训练
# ----------------------------
# 生成Stage-I图像
z = torch.randn(batch_size, noise_dim)
fake_img1, mu1, logvar1 = G1(z, text_embed)
# 更新D1
opt_D1.zero_grad()
real_validity1 = D1(real_img, text_embed)
fake_validity1 = D1(fake_img1.detach(), text_embed)
d1_loss = -torch.mean(torch.log(real_validity1 + 1e-8) + torch.mean(torch.log(1 - fake_validity1 + 1e-8))
d1_loss.backward()
opt_D1.step()
# 更新G1
opt_G1.zero_grad()
fake_validity1 = D1(fake_img1, text_embed)
kl_loss1 = -0.5 * torch.sum(1 + logvar1 - mu1.pow(2) - logvar1.exp())
g1_loss = -torch.mean(torch.log(fake_validity1 + 1e-8)) + lambda_kl * kl_loss1
g1_loss.backward()
opt_G1.step()
# ----------------------------
# Stage-II 训练
# ----------------------------
# 生成Stage-II图像
fake_img2, mu2, logvar2 = G2(fake_img1.detach(), text_embed)
# 更新D2
opt_D2.zero_grad()
real_validity2 = D2(real_img, text_embed)
fake_validity2 = D2(fake_img2.detach(), text_embed)
d2_loss = -torch.mean(torch.log(real_validity2 + 1e-8)) + torch.mean(torch.log(1 - fake_validity2 + 1e-8))
d2_loss.backward()
opt_D2.step()
# 更新G2
opt_G2.zero_grad()
fake_validity2 = D2(fake_img2, text_embed)
kl_loss2 = -0.5 * torch.sum(1 + logvar2 - mu2.pow(2) - logvar2.exp())
l1_loss = F.l1_loss(fake_img2, real_img) # 假设有高分辨率真实图像
g2_loss = -torch.mean(torch.log(fake_validity2 + 1e-8)) + lambda_kl * kl_loss2 + lambda_l1 * l1_loss
g2_loss.backward()
opt_G2.step()