Diffusion Model Basics (2): The Sampling Process

This article describes how to predict clean data from noisy inputs with a diffusion model (a variant built on BasicUNet), focusing on the iterative strategy of the sampling process and its code implementation. It also shows how a mixing factor improves predictions under heavy noise, and ends with a complete experiment on the MNIST dataset.

After studying the textbook 扩散模型从原理到实战 (published by 异步社区, epubit.com), this post digests and summarizes the material it covers.

Building on the BasicUNet model (see 扩散模型基础(一):基于BasicUNet的扩散模型算法搭建 on CSDN), we improve how the diffusion model predicts from noisy data.

Sampling is the key to generating predictions with a diffusion model, and it is what makes good results possible under heavy noise: it is the process of searching from x_T toward the best x_0. Instead of denoising in a single shot, the model is applied over many iterations, each one continuing from the best prediction so far, so that after T iterations we arrive at a reasonably good result.
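
Written as an update rule (my own notation, but it matches the code below): at step i, with mix factor α = 1/(n_steps − i), each iteration moves the sample part of the way toward the model's current estimate x̂0 of the clean image:

x ← (1 − α)·x + α·x̂0

Early steps therefore take small, cautious moves, while the final step (α = 1) adopts the prediction outright.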

The code, using 5 iterations as an example:

n_steps = 5  # number of iterations
x = torch.rand(8, 1, 28, 28).to(device) # start from pure random noise
step_history = [x.detach().cpu()]
pred_output_history = []

for i in range(n_steps):
    with torch.no_grad(): # no need to track gradients during inference
        pred = net(x) # Predict the denoised x0
    pred_output_history.append(pred.detach().cpu()) # Store model output for plotting
    mix_factor = 1/(n_steps - i) # How much we move towards the prediction
    x = x*(1-mix_factor) + pred*mix_factor # Move part of the way there
    step_history.append(x.detach().cpu()) # Store step for plotting

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap='Greys')
plt.show()

Code walk-through:

  • n_steps: the number of iterations.
  • x = torch.rand(8, 1, 28, 28): 8 completely random single-channel images of 28*28 pixels, matching the MNIST format.
  • The loop uses mix_factor (a mixing factor) to blend the current x with the model's prediction rather than replacing it outright, which avoids throwing away structural information that has already been recovered (see the small sketch after this list).
  • After n_steps iterations, x holds the final prediction.
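
To make the schedule concrete, here is a small sketch (not from the book) that just prints the mix factor used at each of the 5 steps; note that the last step has mix_factor = 1, i.e., the output is taken entirely from the model's prediction:

n_steps = 5
for i in range(n_steps):
    print(f'step {i}: mix_factor = 1/{n_steps - i} = {1 / (n_steps - i):.2f}')
# step 0: mix_factor = 1/5 = 0.20
# step 1: mix_factor = 1/4 = 0.25
# step 2: mix_factor = 1/3 = 0.33
# step 3: mix_factor = 1/2 = 0.50
# step 4: mix_factor = 1/1 = 1.00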

When the results are poor, try more iterations (e.g., n_steps = 50) and adjust the model configuration, learning rate, optimizer, and so on, to get better output.
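
As a sketch of what such a longer run might look like (this variant is my own; the post itself only shows the 5-step version), the same loop with n_steps = 50, keeping only the final sample for display:

n_steps = 50
x = torch.rand(8, 1, 28, 28).to(device)  # start again from pure noise
for i in range(n_steps):
    with torch.no_grad():
        pred = net(x)                    # predict the denoised x0
    mix_factor = 1 / (n_steps - i)       # move a little further each step
    x = x * (1 - mix_factor) + pred * mix_factor

plt.imshow(torchvision.utils.make_grid(x.detach().cpu())[0].clip(0, 1), cmap='Greys')
plt.show()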

The complete, debugged source code follows:

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel # not used in this script; kept from the original setup
from matplotlib import pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())

train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
print(torchvision.utils.make_grid(x)[0].shape)

plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
plt.show()

def corrupt(x, amount):
  """Corrupt the input `x` by mixing it with noise according to `amount`"""

  #print(amount)

  noise = torch.rand_like(x)

  #print(noise)

  amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works

  #print(amount)

  return x*(1-amount) + noise*amount
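
# A quick illustration of the broadcasting above (values are hypothetical):
#   amount                   -> shape (B,),         e.g. tensor([0.0, 0.5, 1.0])
#   amount.view(-1, 1, 1, 1) -> shape (B, 1, 1, 1)
# which broadcasts against x of shape (B, 1, 28, 28), so every image in the
# batch gets its own corruption level.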

# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')

# Adding noise
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)

# Plotting the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys')
plt.show()


class BasicUNet(nn.Module):
  """A minimal UNet implementation."""

  def __init__(self, in_channels=1, out_channels=1):
    super().__init__()
    self.down_layers = torch.nn.ModuleList([
      nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
      nn.Conv2d(32, 64, kernel_size=5, padding=2),
      nn.Conv2d(64, 64, kernel_size=5, padding=2),
    ])
    self.up_layers = torch.nn.ModuleList([
      nn.Conv2d(64, 64, kernel_size=5, padding=2),
      nn.Conv2d(64, 32, kernel_size=5, padding=2),
      nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
    ])
    self.act = nn.SiLU()  # The activation function
    self.downscale = nn.MaxPool2d(2)
    self.upscale = nn.Upsample(scale_factor=2)

  def forward(self, x):
    h = []
    for i, l in enumerate(self.down_layers):
      x = self.act(l(x))  # Through the layer and the activation function
      if i < 2:  # For all but the third (final) down layer:
        h.append(x)  # Storing output for skip connection
        x = self.downscale(x)  # Downscale ready for the next layer

    for i, l in enumerate(self.up_layers):
      if i > 0:  # For all except the first up layer
        x = self.upscale(x)  # Upscale
        x += h.pop()  # Fetching stored output (skip connection)
      x = self.act(l(x))  # Through the layer and the activation function

    return x
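
# Optional sanity check (a sketch, not part of the original listing): the UNet
# should preserve the input shape, since 28 is divisible by 4 (two 2x downscales
# followed by two 2x upscales).
# print(BasicUNet()(torch.rand(8, 1, 28, 28)).shape)  # -> torch.Size([8, 1, 28, 28])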

# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 10

# Create the network
net = BasicUNet()
net.to(device)

# Our loss function
loss_fn = nn.MSELoss()  # mean squared error loss

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3) # the optimizer (Adam) is the algorithm that updates the parameters

# Keeping a record of the losses for later viewing
losses = [] # record of loss values

# The training loop
for epoch in range(n_epochs):

    for x, y in train_dataloader:

        # Get some data and prepare the corrupted version
        x = x.to(device) # Data on the GPU
        noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
        noisy_x = corrupt(x, noise_amount) # Create our noisy x

        # Get the model prediction
        pred = net(noisy_x)

        # Calculate the loss
        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?

        # Backprop and update the params
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the loss values for this epoch:
    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1)
plt.show()


#### Sampling process ############

n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device) # Start from random
step_history = [x.detach().cpu()]
pred_output_history = []

for i in range(n_steps):
    with torch.no_grad(): # No need to track gradients during inference
        pred = net(x) # Predict the denoised x0
    pred_output_history.append(pred.detach().cpu()) # Store model output for plotting
    mix_factor = 1/(n_steps - i) # How much we move towards the prediction
    x = x*(1-mix_factor) + pred*mix_factor # Move part of the way there
    step_history.append(x.detach().cpu()) # Store step for plotting

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0,0].set_title('x (model input)')
axs[0,1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap='Greys')
plt.show()
