1. Introduction: The Rise of Diffusion Models
In recent years, diffusion models have surged to prominence in generative AI, becoming the most closely watched image-generation technique since GANs. From OpenAI's DALL·E 2 to Stable Diffusion and Midjourney, these striking AI painting tools all build on diffusion models as their core technology. This article dissects the mathematical principles and implementation details of diffusion models and walks through their workflow with code examples.
2. Fundamentals of Diffusion Models
2.1 Core Idea
The core idea of diffusion models comes from the physical process of diffusion: by gradually adding noise to data, a complex data distribution is transformed into a simple Gaussian distribution (the forward process); the model then learns to reverse this corruption (the reverse process), so that new data samples can be generated starting from pure noise.
2.2 Forward Diffusion Process
The forward process is a fixed Markov chain that adds Gaussian noise to the data step by step. Given a data point x₀ (e.g., an image), the forward process turns it into pure noise x_T over T steps.
Mathematically:
q(x_t|x_{t-1}) = N(x_t; √(1-β_t)x_{t-1}, β_tI)
where β_t is the noise-schedule parameter that controls how much noise is added at each step.
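A useful consequence of this definition is that x_t can be sampled from x₀ in a single step rather than by iterating through all intermediate states. Writing α_t = 1-β_t and ᾱ_t = ∏_{s=1}^{t} α_s, the marginal is:
q(x_t|x_0) = N(x_t; √(ᾱ_t)x_0, (1-ᾱ_t)I)
or, equivalently, x_t = √(ᾱ_t)x_0 + √(1-ᾱ_t)ε with ε ~ N(0, I). The training and sampling code later in this article relies on this closed form.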
2.3 Reverse Diffusion Process
The reverse process learns to "denoise": starting from x_T, it gradually recovers x₀. This requires a parameterized model (typically a neural network) that estimates:
p_θ(x_{t-1}|x_t) = N(x_{t-1}; μ_θ(x_t,t), Σ_θ(x_t,t))
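In practice, most implementations (following DDPM) do not predict μ_θ directly. Instead, the network predicts the noise ε_θ(x_t, t) and the mean is derived from it:
μ_θ(x_t, t) = (1/√α_t)(x_t - (β_t/√(1-ᾱ_t))·ε_θ(x_t, t))
while Σ_θ is usually fixed to β_tI or to the posterior variance rather than learned.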
2.4 Loss Function
Diffusion models are trained by minimizing a variational bound on the negative log-likelihood (equivalently, maximizing the ELBO). In practice this reduces to a simple noise-prediction objective:
L = E_{t,x_0,ε}[||ε - ε_θ(x_t,t)||²]
where ε is the noise that was added and ε_θ is the noise predicted by the model.
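Substituting the closed-form x_t = √(ᾱ_t)x_0 + √(1-ᾱ_t)ε from Section 2.2 makes the objective fully explicit:
L_simple = E_{t,x_0,ε}[||ε - ε_θ(√(ᾱ_t)x_0 + √(1-ᾱ_t)ε, t)||²]
This "simplified" loss drops the per-timestep weighting of the full variational bound, but it trains better in practice and is exactly what the training code in Section 4 implements.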
3. Architecture and Implementation
3.1 The U-Net Architecture
Most diffusion models use a modified U-Net as the backbone network, with the following characteristics:
- An encoder-decoder structure with skip connections
- Timestep embeddings (time embedding)
- Self-attention layers
- Residual connections
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of the diffusion timestep t."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        inv_freq = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000) / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, t):
        t = t.view(-1).float()
        sinusoid_in = torch.outer(t, self.inv_freq)
        pos_emb = torch.cat([sinusoid_in.sin(), sinusoid_in.cos()], dim=-1)
        return pos_emb


class ResidualBlock(nn.Module):
    """Convolutional block with a residual connection and additive time conditioning."""
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.time_mlp = nn.Linear(time_dim, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, t):
        h = self.conv1(x)
        h = h + self.time_mlp(t)[:, :, None, None]  # broadcast the time embedding over H and W
        h = F.relu(h)
        h = self.conv2(h)
        return h + self.shortcut(x)
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, channels=[64, 128, 256, 512], time_dim=256):
        super().__init__()
        self.time_mlp = nn.Sequential(
            TimeEmbedding(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU()
        )
        # Encoder: a residual block followed by a strided-conv downsample at each scale
        self.encoder = nn.ModuleList()
        in_ch = in_channels
        for ch in channels:
            self.encoder.append(ResidualBlock(in_ch, ch, time_dim))
            self.encoder.append(nn.Conv2d(ch, ch, kernel_size=3, stride=2, padding=1))
            in_ch = ch
        # Bottleneck (the 1x1-conv block below is a simplified stand-in for self-attention)
        self.mid_block1 = ResidualBlock(channels[-1], channels[-1], time_dim)
        self.mid_attn = nn.Sequential(
            nn.GroupNorm(8, channels[-1]),
            nn.Conv2d(channels[-1], channels[-1], kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(channels[-1], channels[-1], kernel_size=1)
        )
        self.mid_block2 = ResidualBlock(channels[-1], channels[-1], time_dim)
        # Decoder: upsample, concatenate the skip connection, then a residual block
        self.decoder = nn.ModuleList()
        in_ch = channels[-1]
        for ch in reversed(channels):
            self.decoder.append(nn.ConvTranspose2d(in_ch, ch, kernel_size=4, stride=2, padding=1))
            self.decoder.append(ResidualBlock(ch * 2, ch, time_dim))
            in_ch = ch
        self.decoder.append(nn.Conv2d(channels[0], out_channels, kernel_size=3, padding=1))

    def forward(self, x, t):
        # Time embedding
        t = self.time_mlp(t)
        # Encoder
        skips = []
        for layer in self.encoder:
            if isinstance(layer, ResidualBlock):
                x = layer(x, t)
                skips.append(x)
            else:
                x = layer(x)
        # Bottleneck
        x = self.mid_block1(x, t)
        x = x + self.mid_attn(x)
        x = self.mid_block2(x, t)
        # Decoder
        for layer in self.decoder:
            if isinstance(layer, nn.ConvTranspose2d):
                x = layer(x)
                x = torch.cat([x, skips.pop()], dim=1)
            elif isinstance(layer, ResidualBlock):
                x = layer(x, t)
            else:  # final 3x3 projection to the output channels
                x = layer(x)
        return x
3.2 Noise Schedules
The noise schedule determines how β_t changes over time. Common choices include:
- Linear schedule: β_t increases linearly from β₁ to β_T
- Cosine schedule: varies smoothly according to a cosine function
- Quadratic ("scaled linear") schedule: √β_t increases linearly, so β_t itself grows quadratically
import math
import torch

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    # Linearly spaced betas between beta_start and beta_end
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    # Cosine schedule defined via the cumulative product of alphas
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
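The training and sampling code below also relies on a small helper, extract, which gathers the schedule value corresponding to each sample's timestep t and reshapes it so it broadcasts over an image batch. It is not defined elsewhere in this article, so here is a minimal version following the convention used in common DDPM implementations:

def extract(a, t, x_shape):
    # Gather a[t] for every element of the batch and reshape to [batch, 1, 1, ...]
    # so it broadcasts against a tensor of shape x_shape.
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))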
4. Training a Diffusion Model
4.1 Training Algorithm
- Sample a clean image x₀ from the dataset
- Pick a random timestep t ∈ [1, T]
- Sample noise ε ~ N(0, I)
- Compute the noised image x_t
- Have the network predict the noise ε_θ(x_t, t)
- Compute the mean-squared-error loss ||ε - ε_θ||²
def train_step(model, x0, t, loss_fn, betas):
    # Precompute the forward-process quantities
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
    # Sample noise
    noise = torch.randn_like(x0)
    # Compute x_t in closed form: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
    sqrt_alpha = extract(sqrt_alphas_cumprod, t, x0.shape)
    sqrt_one_minus_alpha = extract(sqrt_one_minus_alphas_cumprod, t, x0.shape)
    x_t = sqrt_alpha * x0 + sqrt_one_minus_alpha * noise
    # Predict the noise
    predicted_noise = model(x_t, t)
    # Noise-prediction loss
    loss = loss_fn(noise, predicted_noise)
    return loss
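To tie the pieces together, here is a minimal training loop sketch. The dataloader, number of epochs, and optimizer settings are illustrative assumptions, not values from this article:

# Minimal training loop sketch; `dataloader` is assumed to yield image batches scaled to [-1, 1].
timesteps = 1000
betas = linear_beta_schedule(timesteps)
model = UNet(in_channels=3, out_channels=3)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
num_epochs = 10  # illustrative value

for epoch in range(num_epochs):
    for x0 in dataloader:
        t = torch.randint(0, timesteps, (x0.shape[0],), dtype=torch.long)
        loss = train_step(model, x0, t, F.mse_loss, betas)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()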
4.2 Sampling Procedure
Sampling generates an image by denoising pure noise step by step:
- Start from x_T ~ N(0, I)
- For t from T down to 1:
  a. Predict the noise ε_θ(x_t, t)
  b. Compute the denoised x_{t-1}
- Return the final x₀
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3, timesteps=1000):
    # Precompute the noise-schedule quantities
    betas = linear_beta_schedule(timesteps)
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    # Coefficients of the posterior mean q(x_{t-1} | x_t, x_0)
    posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
    posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
    # Start from pure noise
    img = torch.randn((batch_size, channels, image_size, image_size))
    for t in reversed(range(0, timesteps)):
        # Current timestep for every sample in the batch
        t_tensor = torch.full((batch_size,), t, dtype=torch.long)
        # Predict the noise component
        pred_noise = model(img, t_tensor)
        # Estimate x0 from x_t and the predicted noise
        sqrt_alpha_cumprod_t = extract(sqrt_alphas_cumprod, t_tensor, img.shape)
        sqrt_one_minus_alpha_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t_tensor, img.shape)
        x0_estimate = (img - sqrt_one_minus_alpha_cumprod_t * pred_noise) / sqrt_alpha_cumprod_t
        # Posterior mean of q(x_{t-1} | x_t, x0_estimate)
        mean = (
            extract(posterior_mean_coef1, t_tensor, img.shape) * x0_estimate
            + extract(posterior_mean_coef2, t_tensor, img.shape) * img
        )
        posterior_variance_t = extract(posterior_variance, t_tensor, img.shape)
        if t > 0:
            noise = torch.randn_like(img)
            img = mean + torch.sqrt(posterior_variance_t) * noise
        else:
            img = mean
    return img
5. Variants and Improvements
5.1 DDPM (Denoising Diffusion Probabilistic Models)
DDPM is one of the earliest formulations of diffusion models; it uses a Gaussian with fixed variance as the transition distribution of the reverse process.
5.2 DDIM (Denoising Diffusion Implicit Models)
DDIM defines a non-Markovian diffusion process that allows much faster sampling without retraining the model:
@torch.no_grad()
def ddim_sample(model, image_size, batch_size=16, channels=3, timesteps=1000, ddim_timesteps=50, eta=0.0):
    # Precompute the noise-schedule quantities
    betas = linear_beta_schedule(timesteps)
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Choose an evenly spaced sub-sequence of timesteps for DDIM
    times = torch.linspace(0, timesteps - 1, ddim_timesteps + 1).long().flip(0)
    time_pairs = list(zip(times[:-1], times[1:]))
    # Start from pure noise
    img = torch.randn((batch_size, channels, image_size, image_size))
    for t, t_next in time_pairs:
        t, t_next = int(t), int(t_next)
        t_tensor = torch.full((batch_size,), t, dtype=torch.long)
        t_next_tensor = torch.full((batch_size,), t_next, dtype=torch.long)
        # Predict the noise component
        pred_noise = model(img, t_tensor)
        # Estimate x0
        alpha_cumprod_t = extract(alphas_cumprod, t_tensor, img.shape)
        sqrt_one_minus_alpha_cumprod_t = extract(torch.sqrt(1. - alphas_cumprod), t_tensor, img.shape)
        x0_estimate = (img - sqrt_one_minus_alpha_cumprod_t * pred_noise) / torch.sqrt(alpha_cumprod_t)
        # Stochasticity of the DDIM step (eta = 0 gives deterministic sampling)
        alpha_cumprod_t_next = extract(alphas_cumprod, t_next_tensor, img.shape)
        sigma = eta * torch.sqrt((1 - alpha_cumprod_t_next) / (1 - alpha_cumprod_t)) * torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_t_next)
        # Direction pointing towards x_{t_next}
        dir_xt = torch.sqrt(1 - alpha_cumprod_t_next - sigma ** 2) * pred_noise
        # DDIM update
        noise = torch.randn_like(img) if t_next > 0 else torch.zeros_like(img)
        img = torch.sqrt(alpha_cumprod_t_next) * x0_estimate + dir_xt + sigma * noise
    return img
5.3 Latent Diffusion Models (LDM)
This is the approach used by Stable Diffusion: run diffusion in a learned latent space and decode with a VAE afterwards (a minimal usage sketch with the Diffusers library follows the list):
- A VAE encoder compresses the image into a latent space
- The diffusion process runs in that latent space
- A VAE decoder maps the result back to pixel space
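For reference, the Hugging Face Diffusers library (mentioned at the end of this article) packages the whole latent-diffusion pipeline: VAE, U-Net, scheduler, and text encoder. A minimal usage sketch, assuming the library is installed and using the CompVis/stable-diffusion-v1-4 checkpoint (any compatible Stable Diffusion checkpoint would do):

import torch
from diffusers import StableDiffusionPipeline

# Load a pretrained latent diffusion (Stable Diffusion) pipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# Text-to-image generation: diffusion runs in the VAE latent space,
# and the VAE decoder maps the result back to pixels
image = pipe("a watercolor painting of a lighthouse at dawn", num_inference_steps=50).images[0]
image.save("lighthouse.png")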
6. Application: Text-to-Image Generation
Modern AI painting systems such as Stable Diffusion combine a diffusion model with a CLIP text encoder:
- A text encoder turns the prompt into embedding vectors
- Cross-attention injects the text information into the diffusion process
- Conditional generation runs in the latent space
class TextConditionedUNet(nn.Module):
    def __init__(self, text_embed_dim, **unet_kwargs):
        super().__init__()
        self.unet = UNet(**unet_kwargs)
        time_dim = unet_kwargs.get('time_dim', 256)
        mid_channels = unet_kwargs.get('channels', [64, 128, 256, 512])[-1]
        self.text_proj = nn.Linear(text_embed_dim, time_dim)
        # Cross-attention over the conditioning context (keys/values live in time_dim space)
        self.attn = nn.MultiheadAttention(embed_dim=mid_channels, num_heads=4, kdim=time_dim, vdim=time_dim)

    def forward(self, x, t, text_embed):
        # Time embedding
        t_embed = self.unet.time_mlp(t)
        # Text embedding projected into the same space
        text_embed = self.text_proj(text_embed)
        cond = t_embed + text_embed  # combined conditioning for the residual blocks
        # Context sequence for cross-attention: [seq_len=2, batch, time_dim]
        context = torch.stack([t_embed, text_embed], dim=0)
        # U-Net encoder
        skips = []
        for layer in self.unet.encoder:
            if isinstance(layer, ResidualBlock):
                x = layer(x, cond)
                skips.append(x)
            else:
                x = layer(x)
        # Bottleneck with cross-attention
        x = self.unet.mid_block1(x, cond)
        b, c, h, w = x.shape
        x_flat = x.view(b, c, -1).permute(2, 0, 1)  # [h*w, batch, channels]
        attn_out, _ = self.attn(x_flat, context, context)
        x = (x + attn_out.permute(1, 2, 0).view(b, c, h, w)) / math.sqrt(2)
        x = self.unet.mid_block2(x, cond)
        # Decoder
        for layer in self.unet.decoder:
            if isinstance(layer, nn.ConvTranspose2d):
                x = layer(x)
                x = torch.cat([x, skips.pop()], dim=1)
            elif isinstance(layer, ResidualBlock):
                x = layer(x, cond)
            else:  # final 3x3 projection
                x = layer(x)
        return x
7. Evaluation Metrics
Common metrics for judging the quality of a generative model:
- FID (Fréchet Inception Distance): measures the distance between the distributions of generated and real images
- IS (Inception Score): assesses the diversity and quality of generated images
- Precision & Recall: measure sample quality and coverage, respectively
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

def evaluate_model(model, real_images, num_samples=5000, batch_size=32, device='cuda'):
    # Generate samples in batches
    fake_images = []
    generated = 0
    while generated < num_samples:
        cur = min(batch_size, num_samples - generated)
        fake_images.append(sample(model, image_size=real_images.shape[-1], batch_size=cur))
        generated += cur
    fake_images = torch.cat(fake_images, dim=0)[:num_samples]
    # FID (torchmetrics expects uint8 images by default; scale and cast accordingly beforehand)
    fid = FrechetInceptionDistance(feature=2048).to(device)
    fid.update(real_images.to(device), real=True)
    fid.update(fake_images.to(device), real=False)
    fid_score = fid.compute()
    # Inception Score
    inception = InceptionScore().to(device)
    inception.update(fake_images.to(device))
    is_score = inception.compute()
    return {'FID': fid_score.item(), 'IS': is_score[0].item()}
8. Optimization Techniques
8.1 Accelerating Sampling
- Knowledge distillation: train a student network to imitate many denoising steps at once
- Subsequence sampling: run only a subset of the timesteps (as in the DDIM sampler above)
- Adaptive step sizes: adjust the step size based on the prediction error
8.2 Improving Sample Quality
- Classifier guidance: use the gradients of a classifier to steer the generation process
- Dynamic thresholding: keep pixel values in a reasonable range (see the sketch after this list)
- Noise correction: adjust the noise prediction during sampling
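As an illustration of the dynamic-thresholding idea above, here is a minimal sketch in the style of Imagen's technique (not code from any specific library): the predicted x₀ is clamped to a per-sample percentile and rescaled before the next sampling step.

def dynamic_threshold(x0_pred, percentile=0.995):
    # Per-sample percentile of the absolute pixel values
    s = torch.quantile(x0_pred.abs().flatten(1), percentile, dim=1)
    # Never shrink below 1 so images already within [-1, 1] are left unchanged
    s = torch.clamp(s, min=1.0).view(-1, 1, 1, 1)
    # Clamp to [-s, s] and rescale back into [-1, 1]
    return torch.clamp(x0_pred, -s, s) / s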
9. Limitations and Future Directions
9.1 Current Limitations
- Sampling is slow (even with acceleration techniques)
- High-resolution generation remains challenging
- Precise control over the generated content is difficult
9.2 Future Directions
- More efficient architectures
- Better multimodal conditioning and control
- Combinations with other generative models (e.g., GANs)
- Video and 3D content generation
10. Conclusion
Diffusion models represent a major breakthrough in generative AI: a simple, physics-inspired idea that achieves astonishing generation quality. From DDPM to Stable Diffusion, the technology is evolving rapidly and keeps redefining what we expect from AI creativity. Understanding the underlying principles not only helps us use these tools better, it also lays the groundwork for building the next generation of generative models.
Appendix: End-to-End Flowchart
graph TD
A[Original image x0] --> B[Forward diffusion: add noise step by step]
B --> C[Pure noise xT]
C --> D[Reverse diffusion: denoise step by step]
D --> E[Generated image x0']
subgraph Training
B --> F[Noise-prediction network εθ]
F --> G["Minimize ||ε - εθ||²"]
end
subgraph Sampling
C --> H[Start from xT]
H --> I[Predict noise εθ]
I --> J[Compute x_{t-1}]
J --> K{t > 1?}
K -->|yes| I
K -->|no| E
end
References
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. NeurIPS.
- Song, J., Meng, C., & Ermon, S. (2021). Denoising Diffusion Implicit Models. ICLR.
- Rombach, R., et al. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. CVPR.
- Dhariwal, P., & Nichol, A. (2021). Diffusion Models Beat GANs on Image Synthesis. NeurIPS.
I hope this detailed article helps you build a thorough understanding of how diffusion models work and how they are implemented. For complete, production-quality code, see open-source projects such as the Hugging Face Diffusers library on GitHub.