零样本文本到图像生成:一篇经典论文的深度解析
对于深度学习研究者来说,Transformer 和大语言模型(LLM)早已是耳熟能详的工具。然而,在这些技术成为主流之前,一些开创性的工作奠定了它们在多模态领域的基石。今天,我们来回顾一篇经典论文——《Zero-Shot Text-to-Image Generation》,由 Aditya Ramesh 等人在 OpenAI 发表。这篇论文提出了一种基于大规模自回归 Transformer 的简单方法,用于从文本生成图像,并在零样本(zero-shot)场景下展现了惊艳的表现。以下将详细介绍其核心思想、技术细节和数学公式,适合熟悉 Transformer 和 LLM 的研究者深入理解。
原论文:https://arxiv.org/pdf/2102.12092
核心思想:用自回归 Transformer 统一建模文本和图像
论文的核心创新在于将文本到图像生成任务简化为一个序列建模问题,通过一个大规模自回归 Transformer 来联合建模文本和图像的 token。这种方法摒弃了传统方法中复杂的架构设计(如多尺度生成器或额外的辅助损失),转而依靠数据规模和模型容量来提升性能。其关键假设是:如果有足够的数据和计算资源,一个简单的自回归模型就能在零样本条件下与专门设计的领域特定模型竞争。
具体来说,作者将文本和图像表示为一个连续的 token 序列:
- 文本:通过 BPE(Byte Pair Encoding)编码为最多 256 个 token。
- 图像:通过一个离散变分自编码器(dVAE)将 256×256 的 RGB 图像压缩为 32×32 的 token 网格,每个 token 从 8192 个可能值中取值。
- 联合建模:将文本 token 和图像 token 拼接成一个长序列(最多 256 + 1024 = 1280 个 token),然后用 Transformer 自回归地预测下一个 token。
这种方法的好处在于,它将多模态任务转化为 Transformer 擅长的单模态序列预测问题,同时通过 dVAE 的压缩大幅减少了图像的上下文长度(从像素级别的 256×256×3 到 token 级别的 32×32),从而降低了计算成本。
技术细节:两阶段训练流程
为了实现这一目标,论文提出了一个两阶段训练流程:
阶段 1:训练离散变分自编码器(dVAE)
dVAE 的作用是将高分辨率图像压缩为离散 token,同时保留主要视觉特征。训练目标是最大化证据下界(Evidence Lower Bound, ELB),其数学形式为:
ln p θ , ψ ( x , y ) ⩾ E z ∼ q ϕ ( z ∣ x ) [ ln p θ ( x ∣ y , z ) − β D K L ( q ϕ ( z ∣ x ) ∥ p ψ ( y , z ) ) ] \ln p_{\theta, \psi}(x, y) \geqslant \mathbb{E}_{z \sim q_\phi(z \mid x)} \left[ \ln p_\theta(x \mid y, z) - \beta D_{\mathrm{KL}}(q_\phi(z \mid x) \| p_\psi(y, z)) \right] lnpθ,ψ(x,y)⩾Ez∼qϕ(z∣x)[lnpθ(x∣y,z)−βDKL(qϕ(z∣x)∥pψ(y,z))]
其中:
- ( x x x) 是图像,( y y y) 是文本,( z z z) 是 dVAE 编码器生成的图像 token。
- ( q ϕ ( z ∣ x ) q_\phi(z \mid x) qϕ(z∣x)) 是编码器分布,输出 32×32 个离散 token,每个 token 从 8192 个码本向量中选择。
- ( p θ ( x ∣ y , z ) p_\theta(x \mid y, z) pθ(x∣y,z)) 是解码器分布,重建原始图像。
- ( p ψ ( y , z ) p_\psi(y, z) pψ(y,z)) 是 Transformer 建模的先验分布。
- ( β \beta β) 是 KL 散度的权重(实践中设为 6.6,以促进码本使用)。
由于 ( q ϕ q_\phi qϕ) 是离散分布,无法直接用重参数化技巧优化,作者采用了 Gumbel-Softmax 松弛(relaxation)(具体可以参考笔者的另一篇博客:Gumbel 噪声与 Gumbel-Softmax 松弛:从离散到连续的桥梁),将离散采样近似为连续分布,并在训练中逐渐降低温度 ( τ \tau τ)(从 1 到 1/16),使松弛趋于真实离散分布。重建损失则使用了一种新颖的 logit-Laplace 分布,其概率密度函数为(具体可以参考笔者的另一篇博客:Logit-Laplace 分布:解决图像生成中像素值范围匹配问题的创新分布):
f ( x ∣ μ , b ) = 1 2 b x ( 1 − x ) exp ( − ∣ logit ( x ) − μ ∣ b ) f(x \mid \mu, b) = \frac{1}{2 b x (1 - x)} \exp \left( -\frac{|\operatorname{logit}(x) - \mu|}{b} \right) f(x∣μ,b)=2bx(1−x)1exp(−b∣logit(x)−μ∣)
这种分布定义在 (0, 1) 区间,解决了传统 Laplace 或 Gaussian 分布支持域与像素值范围不匹配的问题。
阶段 2:训练自回归 Transformer
在固定 dVAE 参数后,作者用一个 120 亿参数的稀疏 Transformer(参考 Child et al., 2019)建模 ( p ψ ( y , z ) p_\psi(y, z) pψ(y,z))。输入是拼接的文本和图像 token 序列,输出是下一个 token 的概率分布。Transformer 采用解码器架构,包含 64 个自注意力层,每层有 62 个注意力头,支持三种注意力掩码:
- 文本到文本:标准因果掩码。
- 图像到图像:行、列或卷积掩码(最后一层使用 11×11 卷积掩码)。
- 文本到图像:图像 token 可关注所有文本 token。
训练数据包含 2.5 亿个从互联网收集的文本-图像对,远超传统数据集(如 MS-COCO)的规模。作者还引入了混合精度训练和分布式优化(如 PowerSGD 梯度压缩)来应对大规模计算需求。
数学公式与优化细节
证据下界的分解
整体目标是最大化图像 ( x x x) 和文本 ( y y y) 的联合似然 ( p θ , ψ ( x , y ) p_{\theta, \psi}(x, y) pθ,ψ(x,y))。通过引入隐变量 ( z z z)(图像 token),似然被分解为:
p θ , ψ ( x , y , z ) = p θ ( x ∣ y , z ) p ψ ( y , z ) p_{\theta, \psi}(x, y, z) = p_\theta(x \mid y, z) p_\psi(y, z) pθ,ψ(x,y,z)=pθ(x∣y,z)pψ(y,z)
ELB 的推导基于变分推断,确保训练目标是可优化的下界。实践中,(\beta > 1)(如 6.6)有助于平衡重建质量和码本利用率。
Gumbel-Softmax 松弛
对于离散分布 ( q ϕ ( z ∣ x ) q_\phi(z \mid x) qϕ(z∣x)),Gumbel-Softmax 将离散选择近似为:
z i = softmax ( ( g i + logits i ) / τ ) z_i = \text{softmax}((g_i + \text{logits}_i) / \tau) zi=softmax((gi+logitsi)/τ)
其中 ( g i g_i gi) 是 Gumbel 噪声,( logits i \text{logits}_i logitsi) 是编码器输出的 logits,( τ \tau τ) 是温度参数。低 ( τ \tau τ) 使分布更接近 one-hot 向量。
分布式优化与梯度压缩
为应对 120 亿参数的内存需求,作者使用了参数分片和 PowerSGD 梯度压缩。压缩率定义为:
Compression Rate = 1 − 5 r 8 d model \text{Compression Rate} = 1 - \frac{5r}{8 d_{\text{model}}} Compression Rate=1−8dmodel5r
其中 ( r r r) 是压缩秩,( d model d_{\text{model}} dmodel) 是 Transformer 隐藏层维度。实验表明,约 85% 的压缩率在不同模型规模下均有效。
实验结果与能力
零样本性能
在 MS-COCO 数据集上,该模型无需训练标签即可生成高质量图像,人力评估显示其样本 90% 的时间优于先前方法(如 DF-GAN)。FID 分数接近最佳先前方法,且在轻微模糊后表现更优。
意外能力
模型展现了意想不到的泛化能力,例如:
- 概念组合:如“手风琴制成的貘”。
- 图像转换:如“顶部猫的底部素描”或“红色猫”。
这些能力表明,大规模自回归模型在多模态任务中具有强大的隐式推理能力。
总结与启发
这篇论文展示了规模化(数据、模型、计算)的力量,将复杂的文本到图像生成简化为一个统一的序列建模问题。对于熟悉 Transformer 和 LLM 的研究者,这是一个值得重温的经典:它不仅预示了 DALL·E 等后续工作的方向,还揭示了自回归建模在多模态领域的潜力。无论是 dVAE 的离散表示,还是 Transformer 的序列联合建模,都为今天的生成式 AI 提供了宝贵启示。
如果你对代码实现感兴趣,论文提到开源了部分实现(https://github.com/openai/DALL-E),不妨一探究竟!
代码实现
以下是基于《Zero-Shot Text-to-Image Generation》论文的核心思想,使用 Python 实现的训练代码和零样本生成图像代码。由于论文中描述的是一个大规模系统(120 亿参数 Transformer 和 2.5 亿图像-文本对),完整复现需要大量计算资源和数据。这里提供一个简化的版本,使用 PyTorch 实现 dVAE 和 Transformer 的核心逻辑,适合在较小规模上运行和理解。
实现说明
- dVAE:实现一个简单的离散变分自编码器,将图像压缩为离散 token。
- Transformer:实现一个自回归 Transformer,联合建模文本和图像 token。
- 数据:假设使用一个小型数据集(如 MNIST)作为示例,实际应用需替换为大规模图像-文本对。
- 限制:由于资源限制,代码简化了模型规模和训练细节,但保留了论文的核心思想。
训练代码
以下代码训练 dVAE 和 Transformer,假设输入是 28×28 的 MNIST 图像和简单文本描述。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from torchvision import datasets, transforms
from torchtext.data.utils import get_tokenizer
import numpy as np
# 超参数
IMG_SIZE = 28 # MNIST 图像大小
CODEBOOK_SIZE = 512 # dVAE 码本大小
LATENT_SIZE = 7 # 压缩后网格大小 (28 / 4 = 7)
TEXT_VOCAB_SIZE = 1000 # 文本词汇表大小
MAX_TEXT_LEN = 10 # 最大文本长度
HIDDEN_DIM = 256 # Transformer 隐藏维度
NUM_LAYERS = 4 # Transformer 层数
NUM_HEADS = 4 # 注意力头数
# dVAE 模型
class DVAE(nn.Module):
def __init__(self):
super(DVAE, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, 4, stride=2, padding=1), # 28x28 -> 14x14
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2, padding=1), # 14x14 -> 7x7
nn.ReLU(),
nn.Conv2d(64, CODEBOOK_SIZE, 1) # 7x7x512
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(CODEBOOK_SIZE, 64, 4, stride=2, padding=1, output_padding=0),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1, output_padding=0),
nn.ReLU(),
nn.Conv2d(32, 1, 3, padding=1),
nn.Sigmoid()
)
self.codebook = nn.Parameter(torch.randn(CODEBOOK_SIZE, 64))
def forward(self, x):
logits = self.encoder(x) # [B, 512, 7, 7]
dist = Categorical(logits=logits.permute(0, 2, 3, 1)) # [B, 7, 7, 512]
z = dist.sample() # [B, 7, 7]
z_embed = self.codebook[z] # [B, 7, 7, 64]
z_embed = z_embed.permute(0, 3, 1, 2) # [B, 64, 7, 7]
recon = self.decoder(z_embed) # [B, 1, 28, 28]
return recon, logits, z
# Transformer 模型
class TextImageTransformer(nn.Module):
def __init__(self):
super(TextImageTransformer, self).__init__()
self.text_embedding = nn.Embedding(TEXT_VOCAB_SIZE, HIDDEN_DIM)
self.image_embedding = nn.Embedding(CODEBOOK_SIZE, HIDDEN_DIM)
self.pos_embedding = nn.Parameter(torch.randn(1, MAX_TEXT_LEN + LATENT_SIZE * LATENT_SIZE, HIDDEN_DIM))
transformer_layer = nn.TransformerDecoderLayer(HIDDEN_DIM, NUM_HEADS, dim_feedforward=512)
self.transformer = nn.TransformerDecoder(transformer_layer, NUM_LAYERS)
self.text_head = nn.Linear(HIDDEN_DIM, TEXT_VOCAB_SIZE)
self.image_head = nn.Linear(HIDDEN_DIM, CODEBOOK_SIZE)
def forward(self, text, image_tokens, mask):
B, T = text.shape
B, I = image_tokens.shape
text_embed = self.text_embedding(text) # [B, T, H]
image_embed = self.image_embedding(image_tokens) # [B, I, H]
seq = torch.cat([text_embed, image_embed], dim=1) # [B, T+I, H]
seq += self.pos_embedding[:, :T+I]
out = self.transformer(seq, seq, tgt_mask=mask) # [B, T+I, H]
text_logits = self.text_head(out[:, :T]) # [B, T, TEXT_VOCAB_SIZE]
image_logits = self.image_head(out[:, T:]) # [B, I, CODEBOOK_SIZE]
return text_logits, image_logits
# 数据加载(示例使用 MNIST 和简单文本)
transform = transforms.Compose([transforms.ToTensor()])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(mnist, batch_size=32, shuffle=True)
tokenizer = get_tokenizer('basic_english')
text_vocab = {'<pad>': 0, '<eos>': 1} # 简单词汇表
for i, (img, label) in enumerate(mnist):
if i >= 1000: break
text_vocab[f"digit_{label}"] = len(text_vocab)
# 训练 dVAE
dvae = DVAE().cuda()
optimizer = optim.Adam(dvae.parameters(), lr=1e-3)
for epoch in range(5):
for img, _ in loader:
img = img.cuda()
recon, logits, z = dvae(img)
recon_loss = nn.functional.binary_cross_entropy(recon, img)
kl_loss = Categorical(logits=logits.permute(0, 2, 3, 1)).entropy().mean()
loss = recon_loss + 6.6 * kl_loss # beta = 6.6
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item()}")
# 训练 Transformer
transformer = TextImageTransformer().cuda()
optimizer = optim.Adam(transformer.parameters(), lr=1e-4)
for epoch in range(5):
for img, label in loader:
img = img.cuda()
_, _, z = dvae(img) # [B, 7, 7]
z = z.flatten(1) # [B, 49]
text = torch.tensor([[text_vocab[f"digit_{l.item()}"]] for l in label], dtype=torch.long).cuda()
text = torch.cat([text, torch.ones_like(text) * text_vocab['<eos>']], dim=1) # [B, 2]
text = nn.functional.pad(text, (0, MAX_TEXT_LEN - 2), value=text_vocab['<pad>']) # [B, 10]
seq_len = MAX_TEXT_LEN + LATENT_SIZE * LATENT_SIZE
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().cuda()
text_logits, image_logits = transformer(text, z, mask)
text_loss = nn.functional.cross_entropy(text_logits.view(-1, TEXT_VOCAB_SIZE), text.view(-1))
image_loss = nn.functional.cross_entropy(image_logits.view(-1, CODEBOOK_SIZE), z.view(-1))
loss = text_loss + image_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item()}")
torch.save(dvae.state_dict(), "dvae.pth")
torch.save(transformer.state_dict(), "transformer.pth")
零样本生成图像代码
以下代码加载训练好的模型,并根据输入文本生成图像。
import torch
import torch.nn as nn
from torchvision.utils import save_image
# 加载模型(假设已定义 DVAE 和 TextImageTransformer 类)
dvae = DVAE().cuda()
transformer = TextImageTransformer().cuda()
dvae.load_state_dict(torch.load("dvae.pth"))
transformer.load_state_dict(torch.load("transformer.pth"))
dvae.eval()
transformer.eval()
# 生成函数
def generate_image(text_input, text_vocab, max_len=MAX_TEXT_LEN, latent_size=LATENT_SIZE):
# 文本预处理
tokens = tokenizer(text_input.lower())[:max_len-1]
text = [text_vocab.get(token, text_vocab['<pad>']) for token in tokens] + [text_vocab['<eos>']]
text = text + [text_vocab['<pad>']] * (max_len - len(text))
text = torch.tensor([text], dtype=torch.long).cuda() # [1, max_len]
# 初始化图像 token
image_tokens = torch.zeros(1, latent_size * latent_size, dtype=torch.long).cuda() # [1, 49]
seq_len = max_len + latent_size * latent_size
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().cuda()
# 自回归生成
with torch.no_grad():
for i in range(max_len, seq_len):
text_logits, image_logits = transformer(text, image_tokens, mask)
next_token = torch.argmax(image_logits[:, i-max_len], dim=-1) # [1]
image_tokens[:, i-max_len] = next_token
# 解码图像
z = image_tokens.view(1, latent_size, latent_size) # [1, 7, 7]
z_embed = dvae.codebook[z] # [1, 7, 7, 64]
z_embed = z_embed.permute(0, 3, 1, 2) # [1, 64, 7, 7]
img = dvae.decoder(z_embed) # [1, 1, 28, 28]
return img
# 示例生成
text_input = "digit_5"
img = generate_image(text_input, text_vocab)
save_image(img, "generated_image.png")
print("Image generated and saved as 'generated_image.png'")
使用说明
- 环境:
- 安装依赖:
pip install torch torchvision torchtext
- 确保有 GPU 支持(代码默认使用 CUDA)。
- 安装依赖:
- 运行训练:
- 保存
train_text_to_image.py
,然后运行python train_text_to_image.py
。 - 训练完成后会保存
dvae.pth
和transformer.pth
。
- 保存
- 运行生成:
- 保存
generate_image.py
,确保与训练代码在同一目录(因需要text_vocab
)。 - 运行
python generate_image.py
,生成图像保存为generated_image.png
。
- 保存
- 扩展:
- 替换 MNIST 为真实图像-文本数据集(如 MS-COCO)。
- 增加模型规模(隐藏维度、层数、头数)以接近论文中的 120 亿参数。
- 实现混合精度训练和分布式优化(需参考 PyTorch 文档)。
注意事项
- 简化:此代码使用 MNIST(28×28 单通道图像)而非论文中的 256×256 RGB 图像,网格大小从 32×32 简化为 7×7,码本大小从 8192 减为 512。
- Gumbel-Softmax:训练中未显式实现松弛,而是直接用 Categorical 采样,实际应用可添加 Gumbel-Softmax。
- 数据:需自行准备大规模文本-图像对,论文使用 2.5 亿对数据。
这个实现展示了论文的核心流程:dVAE 压缩图像 + Transformer 自回归建模。希望对研究者理解和复现这一经典工作有所帮助!
具体解析
下面将详细解释这段代码中训练 Transformer 的过程,阐明 dVAE 的作用,输出的 z
是什么,以及如何实现文本和图像的联动(即如何让模型根据文本生成对应的图像内容)。这段代码是基于《Zero-Shot Text-to-Image Generation》论文的核心思想设计的,适合熟悉 Transformer 的研究者理解其多模态建模机制。
训练 Transformer 的过程解释
这段代码的目标是训练一个自回归 Transformer,使其能够联合建模文本和图像 token,并在给定文本的情况下生成对应的图像内容。以下是逐行解析:
1. 初始化 Transformer 和优化器
transformer = TextImageTransformer().cuda()
optimizer = optim.Adam(transformer.parameters(), lr=1e-4)
- 作用:创建一个
TextImageTransformer
实例并将其移动到 GPU 上,使用 Adam 优化器(学习率 1e-4)来更新模型参数。 - 模型结构:
TextImageTransformer
是一个解码器形式的 Transformer,包含文本嵌入层、图像嵌入层和多层自注意力机制,输出文本和图像的预测 logits。
2. 训练循环
for epoch in range(5):
for img, label in loader:
- 作用:进行 5 个 epoch 的训练,每次迭代处理一个批次的数据。
loader
提供 MNIST 数据集的图像 (img
) 和标签 (label
)。
3. 使用 dVAE 生成图像 token(z
)
img = img.cuda()
_, _, z = dvae(img) # [B, 7, 7]
z = z.flatten(1) # [B, 49]
- dVAE 的作用:dVAE(离散变分自编码器)将输入图像(这里是 28×28 的 MNIST 图像)压缩为离散的 token 表示。论文中提到,dVAE 将 256×256 的 RGB 图像压缩为 32×32 的网格,每个位置从 8192 个码本向量中选择一个值。这里简化为 7×7 的网格(因为 28 / 4 ≈ 7),码本大小为 512。
- 输出
z
是什么:z
是 dVAE 编码器生成的离散 token,表示图像的压缩表示。具体来说:- 输入图像
[B, 1, 28, 28]
通过编码器生成 logits[B, 512, 7, 7]
。 - 对 logits 应用 Categorical 分布采样,得到
[B, 7, 7]
的 token 索引,每个值是 0 到 511 之间的整数。 z.flatten(1)
将其展平为[B, 49]
(7×7=49),便于与文本 token 拼接成序列。
- 输入图像
- 意义:
z
是图像的低维表示,保留了主要视觉特征,同时大幅减少了上下文长度(从 28×28=784 个像素到 49 个 token),适合 Transformer 处理。
4. 准备文本输入
text = torch.tensor([[text_vocab[f"digit_{l.item()}"]] for l in label], dtype=torch.long).cuda()
text = torch.cat([text, torch.ones_like(text) * text_vocab['<eos>']], dim=1) # [B, 2]
text = nn.functional.pad(text, (0, MAX_TEXT_LEN - 2), value=text_vocab['<pad>']) # [B, 10]
- 作用:为每个图像生成对应的文本描述(这里简化为 “digit_X”,X 是标签),并将其转换为 token 序列。
- 步骤:
- 从
text_vocab
中获取每个标签的 token(如digit_5
对应的整数)。 - 添加结束符
<eos>
,形成[B, 2]
的序列。 - 用
<pad>
填充到固定长度MAX_TEXT_LEN=10
,结果为[B, 10]
。
- 从
- 意义:文本 token 表示条件输入,Transformer 将基于这些 token 预测图像 token。
5. 生成注意力掩码
seq_len = MAX_TEXT_LEN + LATENT_SIZE * LATENT_SIZE
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().cuda()
- 作用:创建一个自回归掩码,长度为
10 + 49 = 59
,确保每个位置只能关注前面的 token。 - 细节:
torch.triu
生成上三角矩阵,diagonal=1
表示当前 token 不关注自己,仅关注之前的 token,符合自回归建模。
6. 前向传播
text_logits, image_logits = transformer(text, z, mask)
- 作用:将文本 token
[B, 10]
和图像 token[B, 49]
输入 Transformer,输出预测的文本 logits[B, 10, TEXT_VOCAB_SIZE]
和图像 logits[B, 49, CODEBOOK_SIZE]
。 - 过程:
- 文本和图像 token 被嵌入为
[B, 59, HIDDEN_DIM]
的序列。 - Transformer 使用自注意力机制处理序列,输出每个位置的隐藏状态。
- 通过线性层分别预测文本和图像的下一个 token 分布。
- 文本和图像 token 被嵌入为
7. 计算损失
text_loss = nn.functional.cross_entropy(text_logits.view(-1, TEXT_VOCAB_SIZE), text.view(-1))
image_loss = nn.functional.cross_entropy(image_logits.view(-1, CODEBOOK_SIZE), z.view(-1))
loss = text_loss + image_loss
- 作用:计算交叉熵损失,监督 Transformer 正确预测文本和图像 token。
- 细节:
text_loss
:让模型复现输入文本(自监督)。image_loss
:让模型根据文本预测正确的图像 token(z
)。- 总损失是两者的和,未加权重(论文中可能调整权重)。
8. 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
- 作用:清零梯度,计算梯度并更新 Transformer 参数,使其逐步学会文本和图像的联合分布。
9. 打印损失
print(f"Epoch {epoch}, Loss: {loss.item()}")
- 作用:监控训练进度。
dVAE 的作用
dVAE 是整个流程的关键预处理步骤,其作用包括:
- 压缩图像:将高维图像(28×28 或论文中的 256×256×3)压缩为低维离散 token(7×7 或 32×32),减少 Transformer 的序列长度,提高计算效率。
- 离散表示:通过码本(codebook)将连续像素值映射为离散索引,便于 Transformer 建模(类似语言模型中的词汇表)。
- 保留信息:通过训练优化重建损失和 KL 散度,确保
z
保留图像的主要视觉特征(尽管细节可能丢失,如论文中提到的猫毛纹理)。
在训练 Transformer 时,dVAE 已预训练好,作用是提供图像的 token 表示(z
),作为监督信号和输入的一部分。
输出 z
的含义
z
是 dVAE 对图像的编码结果,形状为[B, 7, 7]
(展平后[B, 49]
),每个值是从 0 到 511 的整数,表示码本中的索引。- 物理意义:它是图像的离散表示,每个 token 对应 7×7 网格中的一个区域,类似于“图像词汇”。
- 后续处理:Transformer 用
z
作为目标,学习预测这些 token;生成时,Transformer 输出类似的 token,再由 dVAE 解码器重建图像。
如何实现文本和图像的联动
Transformer 通过自回归建模实现文本和图像的联动,即学会根据文本生成对应的图像内容。以下是具体机制:
1. 联合序列建模
- 输入序列是
[text_tokens, image_tokens]
(如[B, 59]
),文本在前,图像在后。 - Transformer 自回归地预测下一个 token:文本部分复现输入文本,图像部分预测
z
。 - 关键点:图像 token 可以关注所有文本 token(论文中通过注意力掩码实现),从而捕捉文本描述与图像内容的关系。
2. 注意力机制
- Transformer 的自注意力层允许模型学习文本和图像 token 之间的依赖关系。例如:
- 输入文本 “digit_5”,模型通过注意力机制将 “digit_5” 的嵌入与后续图像 token 关联。
- 训练时,图像 token(
z
)作为目标,迫使模型根据前文预测正确的图像内容。
3. 损失函数驱动
image_loss
监督 Transformer 输出与z
一致,z
是由 dVAE 从真实图像生成的。- 通过最小化损失,模型逐渐学会:给定文本 “digit_5”,生成与数字 5 图像对应的 token 序列。
4. 零样本生成
- 训练完成后,给定新文本(如 “digit_7”),Transformer 从文本开始自回归生成图像 token。
- 生成的 token 通过 dVAE 解码器重建图像,实现文本到图像的转换。
- 零样本能力:由于训练数据多样(论文中 2.5 亿对),模型能泛化到未见过的文本描述。
5. 实现联动的本质
- 数据驱动:大量文本-图像对(如 “digit_5” 和 5 的图像)让模型学习统计相关性。
- 自回归性质:Transformer 预测下一个 token 的能力使其能从文本逐步构建图像。
- dVAE 的桥梁作用:dVAE 将图像转化为 token,与文本统一为序列,Transformer 无需区分模态,直接建模联合分布。
总结
- 训练过程:dVAE 提供图像 token(
z
),Transformer 学习文本和z
的联合分布,通过自回归预测实现联动。 - dVAE 作用:压缩图像为离散 token,降低计算复杂度并统一模态。
z
的意义:图像的离散表示,作为 Transformer 的输入和目标。- 文本-图像联动:通过自注意力、联合建模和损失监督,Transformer 学会从文本生成对应图像 token,再由 dVAE 解码为图像。
这种方法的核心在于将多模态任务转化为序列建模问题,充分利用 Transformer 的强大序列处理能力,体现了论文“简单但有效”的设计哲学。
后记
2025年3月26日15点13分于上海,在grok 3大模型辅助下完成。