DiffusionFastForward(一) | Diffusion Sandbox

本文介绍了如何基于GitHub上的DiffusionFastForward项目,使用PyTorch进行代码实现,涉及噪声添加、前向和反向传播步骤,展示了如何在图像处理中应用这些技术进行图像生成和恢复。
摘要由CSDN通过智能技术生成
! git clone https://github.com/mikonvergence/DiffusionFastForward
!pip install pytorch-lightning==1.9.3 diffusers einops kornia

本文章基于Github上的DiffusionFastForward,对其中的代码进行中文介绍

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import imageio#图像处理

mpl.rcParams['figure.figsize'] = (12, 8)#修改默认图形大小

img = torch.FloatTensor(imageio.imread('./DiffusionFastForward/imgs/hills_2.png')/255)#读取一张图片,并将像素值从0-255缩小至0-1
plt.imshow(img)

定义一个自动展示图片的函数,将[0,1]转换到[-1,1]再进行输出

def input_T(input):
    # [0,1] -> [-1,+1]
    return 2*input-1
    
def output_T(input):
    # [-1,+1] -> [0,1]
    return (input+1)/2

def show(input):
    plt.imshow(output_T(input).clip(0,1))
    
img_=input_T(img)
show(img_)

如何添加和控制噪声是通过预先设定的方差时间表来实现的。

定义方差时间表:

num_timesteps=1000
betas=torch.linspace(1e-4,2e-2,num_timesteps)

alphas=1-betas
alphas_sqrt=alphas.sqrt()
alphas_cumprod=torch.cumprod(alphas,0)
alphas_cumprod_sqrt=alphas_cumprod.sqrt()

利用torch.linspace()生成了一个范围为1e-4到2e-2,步数为1000的等差数列

torch.cumpord(alphas,0)计算了 alphas 按元素累积乘积,即第一个元素保持不变,第二个元素为第一个元素与第二个元素的乘积,第三个元素为前两个元素的乘积,依此类推。

前向过程:

def forward_step(t, condition_img, return_noise=False):
    """
        forward step: t-1 -> t
    """    
    assert t >= 0

    mean=alphas_sqrt[t]*condition_img    
    std=betas[t].sqrt()
      
    # sampling from N
    if not return_noise:
        return mean+std*torch.randn_like(img)
    else:
        noise=torch.randn_like(img)
        return mean+std*noise, noise
    
def forward_jump(t, condition_img, condition_idx=0, return_noise=False):
    """
        forward jump: 0 -> t
    """   
    assert t >= 0
    
    mean=alphas_cumprod_sqrt[t]*condition_img
    std=(1-alphas_cumprod[t]).sqrt()
      
    # sampling from N
    if not return_noise:
        return mean+std*torch.randn_like(img)
    else:
        noise=torch.randn_like(img)
        return mean+std*noise, noise

定义一个名为 `forward_step` 的函数,用于执行 t-1->t 的步骤。

1. `assert t >= 0`:断言语句,用于确保 t 的值大于等于 0,如果不满足则会引发 AssertionError。

2. `mean = alphas_sqrt[t] * condition_img`:计算当前时间步 t 的均值,其中 `alphas_sqrt` 是之前计算得到的系数序列,`condition_img` 是输入的条件图像。

3. `std = betas[t].sqrt()`:计算当前时间步 t 的标准差,其中 `betas` 是之前计算得到的系数序列。

4. `if not return_noise:`:判断是否需要返回噪声。如果 `return_noise` 为 False,则执行下面的代码块;否则,执行 else 语句块。

5. `return mean + std * torch.randn_like(img)`:如果不需要返回噪声,则按照标准正态分布生成一个与 `img` 相同大小的随机张量,并与均值和标准差相乘后的结果相加,最后返回这个结果。

这个函数看起来是用于在某种模型中执行前向传播的步骤,根据时间步 t、条件图像和噪声参数生成新的输出。

forward_jump 函数同理但执行 0->t 的步骤。

前向过程部分展示:

plt.figure(figsize=(12,8))
for idx in range(N):
    t_step=int(idx*(num_timesteps/N))#把一千步均匀分成了5个起点
    
    plt.subplot(N,1+M,1+(M+1)*idx)
    show(alphas_cumprod_sqrt[t_step]*img_)
    plt.title(r'$\mu_t=\sqrt{\bar{\alpha}_t}x_0$') if idx==0 else None
    plt.ylabel("t: {:.2f}".format(t_step/num_timesteps))
    plt.xticks([])
    plt.yticks([])
    
    for sample in range(M):
        x_t=forward_jump(t_step+sample,img_)
        
        plt.subplot(N,1+M,2+(1+M)*idx+sample)
        show(x_t)        
        plt.axis('off')
        
plt.tight_layout()

plit.subplot(N,1+M,1+(M+1)*idx)表示子图行数为N,列数为1+M,当前子图索引为1+(M+1)*idx。

plt.axis('off')用于隐藏坐标轴。具体地说,它会将当前图形的 x 轴和 y 轴的坐标轴线以及刻度标签都隐藏起来,使得图形只显示数据内容,而不显示坐标轴。

查看单步前向过程

plt.figure(figsize=(12,8))
for idx in range(N):
    t_step=int(idx*(num_timesteps/N))
    prev_img=forward_jump(max([0,t_step-1]),img_) # directly go to prev state
    
    plt.subplot(N,1+M,1+(M+1)*idx)
    show(alphas_sqrt[t_step]*prev_img)
    plt.title(r'$\mu_t=\sqrt{1-\beta_t}x_{t-1}$') if idx==0 else None
    plt.ylabel("t: {:.2f}".format(t_step/num_timesteps))
    plt.xticks([])
    plt.yticks([])
    
    for sample in range(M):
        plt.subplot(N,1+M,2+(1+M)*idx+sample)
        x_t=forward_step(t_step+sample,prev_img)
        show(x_t)        
        plt.axis('off')
plt.tight_layout()

反向过程:

查看加入的噪声

t_step=50

x_t,noise=forward_jump(t_step,img_,return_noise=True)

plt.subplot(1,3,1)
show(img_)
plt.title(r'$x_0$')
plt.axis('off')
plt.subplot(1,3,2)
show(x_t)
plt.title(r'$x_t$')
plt.axis('off')
plt.subplot(1,3,3)
show(noise)
plt.title(r'$\epsilon$')
plt.axis('off')

通过之前查看的噪声ε,和xt来获得x0(当然能获得平均绝对误差为0的x0)式(5)

x_0_pred=(x_t-(1-alphas_cumprod[t_step]).sqrt()*noise)/(alphas_cumprod_sqrt[t_step])

plt.subplot(1,3,1)
show(x_t)
plt.title('$x_t$ ($\ell_1$: {:.3f})'.format(F.l1_loss(x_t,img_)))
plt.axis('off')
plt.subplot(1,3,2)
show(x_0_pred)
plt.title('$x_0$ prediction ($\ell_1$: {:.3f})'.format(F.l1_loss(x_0_pred,img_)))
plt.axis('off') 
plt.subplot(1,3,3)
show(img_)
plt.title('$x_0$')
plt.axis('off')

根据x_0x_t来得到x_{t-1}的均值式(4),并将其与真实的xt-1步的均值(前向过程中去除加噪的步骤)比较

# estimate mean
mean_pred=x_0_pred*(alphas_cumprod_sqrt[t_step-1]*betas[t_step])/(1-alphas_cumprod[t_step]) + x_t*(alphas_sqrt[t_step]*(1-alphas_cumprod[t_step-1]))/(1-alphas_cumprod[t_step])

# let's compare it to ground truth mean of the previous step (requires knowledge of x_0)
mean_gt=alphas_cumprod_sqrt[t_step-1]*img_
plt.subplot(1,3,1)
show(x_t)
plt.title('$x_t$   ($\ell_1$: {:.3f})'.format(F.l1_loss(x_t,img_)))
plt.subplot(1,3,2)
show(mean_pred)
plt.title(r'$\tilde{\mu}_{t-1}$' + '  ($\ell_1$: {:.3f})'.format(F.l1_loss(mean_pred,img_)))
plt.subplot(1,3,3)
show(mean_gt)
plt.title(r'$\mu_{t-1}$' + '  ($\ell_1$: {:.3f})'.format(F.l1_loss(mean_gt,img_)))

这个代码的思路是从xt->x0,再由xt、x0->μt-1,由μt-1加上方差即得xt-1

将上面的步骤合并就是函数reverse_step()

def reverse_step(epsilon, x_t, t_step, return_noise=False):
    
    # estimate x_0 based on epsilon
    x_0_pred=(x_t-(1-alphas_cumprod[t_step]).sqrt()*epsilon)/(alphas_cumprod_sqrt[t_step])
    if t_step==0:
        sample=x_0_pred
        noise=torch.zeros_like(x_0_pred)
    else:
        # estimate mean
        mean_pred=x_0_pred*(alphas_cumprod_sqrt[t_step-1]*betas[t_step])/(1-alphas_cumprod[t_step]) + x_t*(alphas_sqrt[t_step]*(1-alphas_cumprod[t_step-1]))/(1-alphas_cumprod[t_step])

        # compute variance
        beta_pred=betas[t_step].sqrt() if t_step != 0 else 0

        sample=mean_pred+beta_pred*torch.randn_like(x_t)
        # this noise is only computed for simulation purposes (since x_0_pred is not known normally)
        #使用式(4)得到噪音
        noise=(sample-x_0_pred*alphas_cumprod_sqrt[t_step-1])/(1-alphas_cumprod[t_step-1]).sqrt()
    if return_noise:
        return sample, noise
    else:
        return sample
x_t,noise=forward_jump(1000-1,img_,return_noise=True)

state_imgs=[x_t.numpy()]
for t_step in reversed(range(1000)):
    x_t,noise=reverse_step(noise,x_t,t_step,return_noise=True)
    
    if t_step % 200 == 0:
        state_imgs.append(x_t.numpy())

可以通过以上代码来验证reverse函数,但需要注意的是,此时仍然未完成diffusion,甚至可以说仅仅只是冰山一角。其原因是,上述代码已知了图片img_然后利用img_得到x_t和noise,事实上由这两个可以直接得到img_,而不需要经过一个reverse过程,因此这里的代码只是形象的展示了一下diffusion的过程,未涉及到模型的训练过程。

  • 23
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值