指导图像生成的除了conditional信息,还有unconditional信息。(在深度学习领域,当我们谈论"无条件信息"时,通常指的是模型在生成数据时不依赖于任何特定的输入或条件变量。例如,在无条件生成模型中(如Variational Autoencoders (VAEs) 或 Generative Adversarial Networks (GANs) 的无条件版本),模型的目标是学习整个数据分布,并能够基于该分布生成新的样本。)
-
条件(Conditional):这里的“条件”通常指的是对模型生成输出的一种指导或约束。例如,在文本引导的图像生成任务中,条件可以是基于输入文本的嵌入表示,模型会根据这些嵌入去生成与文本描述相符合的图像。
-
无条件(Unconditional):相对地,“无条件”是指不依赖特定输入进行生成的情况。在这种模式下,模型自由地生成各种可能的输出,没有特定的文本或其他条件限制。
unconditional的意义,一是可以基于训练数据分布来生成新的样本,
生成新的样本的目的:
-
数据增强与扩充:通过生成新的、逼真的样本,确实可以在一定程度上增加训练数据的多样性,特别是在处理小样本问题时有助于改善模型泛化能力。
以图像生成为例,无条件生成模型并不需要给定一个具体的描述或标签作为条件来生成图片,而是根据训练数据中的图像总体分布特性自动生成看起来像真实世界图像的新图片。同样,在文本生成任务中,无条件模型会生成新的文本序列,这些序列不依赖于特定的上下文提示,而是基于语料库整体的语言规律和模式。纯粹基于数据本身的统计属性得到的信息表示。
Stable Diffusion 训练过程与c和uc相关的参数 ----CFG
以下是如何在训练阶段应用Classifier-Free Guidance的一般步骤:
-
同时训练条件和无条件模型:
- 训练一个扩散模型,该模型能够基于噪声变量(通常是高斯噪声)生成图像。
- 同时训练这个模型来处理两种任务:一是从纯随机噪声中无条件地生成图像;二是根据给定的条件信息(例如文本描述、类别标签等)生成图像。
-
共享权重与采样策略:
- 在同一个网络架构中,无论是进行有条件还是无条件的生成过程,网络参数是共享的。
- 在训练过程中,每次迭代都会同时优化模型对于无条件数据样本以及带有条件信息的数据样本的拟合程度。
-
计算损失并更新权重:
- 对于每个批次的数据,一部分样本仅使用噪声作为输入进行无条件学习。
- 另一部分样本则将条件信息与噪声结合,共同作为输入进行条件学习。
- 模型会根据两部分样本分别计算损失,并合并这些损失以更新网络权重。
-
采样时使用CFG混合得分:
- 在训练后的推理或采样阶段,CFG通过混合模型对于条件和无条件数据的预测得分来进行指导。
- 生成图像时,首先为相同的噪声样本计算出条件和无条件的预测结果,然后按照guidance scale的比例因子加权求和得到最终的修改过的噪声预测值,用于逆扩散过程生成更符合条件信息的图像。
import torch
from torch import nn
from torch.nn import functional as F
# 假设有一个定义好的UNet类
class UNet(nn.Module):
# ... 定义UNet网络结构 ...
# 初始化模型、损失函数和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# 假设有无条件数据加载器和条件数据加载器
uncond_dataloader = ...
cond_dataloader = ...
num_epochs = 100
for epoch in range(num_epochs):
for uncond_batch, cond_batch in zip(uncond_dataloader, cond_dataloader):
# 前向传播:无条件预测
uncond_inputs = uncond_batch["noises"].to(device)
uncond_outputs = model(uncond_inputs)
# 前向传播:条件预测
cond_inputs = cond_batch["noises"].to(device)
cond_labels = cond_batch["labels"].to(device) # 条件信息
cond_outputs = model(cond_inputs, cond_labels)
# 计算无条件和条件损失
uncond_loss = loss_fn(uncond_outputs, uncond_batch["images"].to(device))
cond_loss = loss_fn(cond_outputs, cond_batch["images"].to(device))
# 总损失是条件和无条件损失的组合
total_loss = uncond_loss + cond_loss
# 反向传播并更新权重
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
# 打印当前损失等信息以监控训练过程
print(f"Epoch: {epoch}, Total Loss: {total_loss.item()}")
# 在训练完成后,模型即可应用于Classifier-Free Guidance采样过程