nanodiffusion代码逐行理解之diffusion

一、diffusion创建

# Build the diffusion wrapper from parsed CLI args. Every argument is passed by
# keyword (same evaluation order as a positional call); ema_start is pinned to
# 2000 here instead of the constructor default of 5000.
diffusion = GaussianDiffusion(
    model=model,
    img_size=args.img_size,
    img_channels=args.img_channels,
    num_classes=args.num_classes,
    betas=betas,
    ema_decay=args.ema_decay,
    ema_update_rate=args.ema_update_rate,
    ema_start=2000,
    loss_type=args.loss_type,
)

二、GaussianDiffusion定义

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from copy import deepcopy

from ddpm.ema import EMA
from ddpm.utils import extract

class GaussianDiffusion(nn.Module):
    """DDPM-style Gaussian diffusion wrapped around a noise-prediction model.

    Calling the module (``forward``) draws one random timestep per image,
    noises the batch with ``perturb_x`` and returns the l1/l2 loss between the
    true noise and the model's estimate. ``sample`` and
    ``sample_diffusion_sequence`` run the reverse process from pure noise.

    Args:
        model (nn.Module): network called as ``model(x, t, y)`` that estimates
            the noise contained in ``x`` at timestep ``t``.
        img_size (tuple): image size ``(H, W)``.
        img_channels (int): number of image channels.
        num_classes: number of classes; stored on the instance but not used
            by this class itself.
        betas (np.ndarray): noise schedule; ``len(betas)`` is the number of
            diffusion steps.
        loss_type (str): ``"l1"`` or ``"l2"``.
        ema_decay (float): decay handed to ``EMA``.
        ema_start (int): before this many steps, ``update_ema`` copies the
            model weights into ``ema_model`` instead of averaging them.
        ema_update_rate (int): ``update_ema`` only acts every this many calls.

    Raises:
        ValueError: if ``loss_type`` is not ``"l1"`` or ``"l2"``.
    """

    def __init__(
        self,
        model,
        img_size,
        img_channels,
        num_classes,
        betas,
        loss_type="l2",
        ema_decay=0.9999,
        ema_start=5000,
        ema_update_rate=1,
    ):
        super().__init__()

        self.model = model
        # Independent copy of the network; its weights are maintained by
        # update_ema() and it is the default network used for sampling.
        self.ema_model = deepcopy(model)

        self.ema = EMA(ema_decay)
        self.ema_decay = ema_decay
        self.ema_start = ema_start
        self.ema_update_rate = ema_update_rate
        self.step = 0

        self.img_size = img_size
        self.img_channels = img_channels
        self.num_classes = num_classes

        if loss_type not in ["l1", "l2"]:
            raise ValueError("__init__() got unknown loss type")

        self.loss_type = loss_type
        self.num_timesteps = len(betas)

        alphas = 1.0 - betas
        # alphas_cumprod[t] = alpha_bar_t = alpha_0 * alpha_1 * ... * alpha_t
        alphas_cumprod = np.cumprod(alphas)

        to_torch = partial(torch.tensor, dtype=torch.float32)

        # Schedule-derived constants are registered as (non-trainable) buffers
        # so they follow the module across devices.
        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas", to_torch(alphas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))

        # Coefficients of the closed-form forward process used by perturb_x.
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod)))
        self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas)))

        # Coefficients of the reverse step used by remove_noise / sampling.
        self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod)))
        self.register_buffer("sigma", to_torch(np.sqrt(betas)))

    def update_ema(self):
        """Advance the step counter and refresh ``ema_model``.

        Only every ``ema_update_rate``-th call does any work. Before
        ``ema_start`` steps the EMA model is simply overwritten with the current
        weights; afterwards ``self.ema.update_model_average`` folds the current
        weights into it (presumably an exponential moving average with
        ``ema_decay`` — see ``ddpm.ema.EMA``).
        """
        self.step += 1
        if self.step % self.ema_update_rate == 0:
            if self.step < self.ema_start:
                self.ema_model.load_state_dict(self.model.state_dict())
            else:
                self.ema.update_model_average(self.ema_model, self.model)

    @torch.no_grad()
    def remove_noise(self, x, t, y, use_ema=True):
        """Return the denoised estimate of x_{t-1} from x_t (no sampling noise added).

        Computes ``(x - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)``
        where ``eps`` is predicted by ``ema_model`` when ``use_ema`` is true,
        otherwise by ``model``.
        """
        net = self.ema_model if use_ema else self.model
        eps = net(x, t, y)
        return (
            (x - extract(self.remove_noise_coeff, t, x.shape) * eps) *
            extract(self.reciprocal_sqrt_alphas, t, x.shape)
        )

    @torch.no_grad()
    def _denoise_step(self, x, t, y, batch_size, device, use_ema):
        """Run one reverse step x_t -> x_{t-1} for the integer timestep ``t``."""
        # Same timestep for every sample in the batch.
        t_batch = torch.tensor([t], device=device).repeat(batch_size)
        x = self.remove_noise(x, t_batch, y, use_ema)

        # Fresh noise is added on every step except the last, so the final
        # output is the model's denoised estimate.
        if t > 0:
            x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
        return x

    @torch.no_grad()
    def sample(self, batch_size, device, y=None, use_ema=True):
        """Generate ``batch_size`` images by running the full reverse process.

        Args:
            batch_size (int): number of images to generate.
            device: device on which sampling runs.
            y: optional conditioning passed to the model; must have length
                ``batch_size`` when given.
            use_ema (bool): sample with ``ema_model`` instead of ``model``.

        Returns:
            Tensor of shape ``(batch_size, img_channels, *img_size)`` on CPU.

        Raises:
            ValueError: if ``y`` is given and ``len(y) != batch_size``.
        """
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)

        for t in range(self.num_timesteps - 1, -1, -1):
            x = self._denoise_step(x, t, y, batch_size, device, use_ema)

        return x.cpu().detach()

    @torch.no_grad()
    def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):
        """Like ``sample`` but yield the (CPU) batch after every reverse step.

        Yields ``num_timesteps`` tensors, from timestep ``num_timesteps - 1``
        down to 0; the last one equals what ``sample`` would return.

        Raises:
            ValueError: if ``y`` is given and ``len(y) != batch_size``.
        """
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
        for t in range(self.num_timesteps - 1, -1, -1):
            x = self._denoise_step(x, t, y, batch_size, device, use_ema)
            yield x.cpu().detach()

    def perturb_x(self, x, t, noise):
        """Noise clean images ``x`` straight to timestep ``t``.

        Returns ``sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise``.
        """
        return (
            extract(self.sqrt_alphas_cumprod, t, x.shape) * x +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )

    def get_losses(self, x, t, y):
        """Return the l1/l2 loss between injected noise and the model's estimate.

        Raises:
            ValueError: if ``self.loss_type`` is neither ``"l1"`` nor ``"l2"``
                (can only happen if it was changed after construction).
        """
        noise = torch.randn_like(x)

        perturbed_x = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)
        else:
            # Guard against a mutated loss_type instead of an UnboundLocalError.
            raise ValueError("unknown loss type")

        return loss

    def forward(self, x, y=None):
        """Return the training loss for a batch of images.

        Args:
            x: batch of shape ``(b, c, H, W)`` with ``(H, W) == self.img_size``.
            y: optional conditioning forwarded to the model.

        Raises:
            ValueError: if the image height or width does not match ``img_size``.
        """
        b, c, h, w = x.shape
        device = x.device

        if h != self.img_size[0]:
            raise ValueError("image height does not match diffusion parameters")
        # img_size is (H, W): the width must be compared against index 1.
        if w != self.img_size[1]:
            raise ValueError("image width does not match diffusion parameters")

        # One uniformly random timestep in [0, num_timesteps) per image.
        t = torch.randint(0, self.num_timesteps, (b,), device=device)
        return self.get_losses(x, t, y)


def generate_cosine_schedule(T, s=0.008, max_beta=0.999):
    """Return a cosine beta schedule of length ``T``.

    With ``f(t) = cos(((t / T + s) / (1 + s)) * pi / 2) ** 2`` the cumulative
    product of alphas is ``alpha_bar(t) = f(t) / f(0)``, and each beta is
    ``beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1)`` for ``t = 1..T``. Every
    beta is clipped to ``max_beta``; this matters at the end of the schedule
    because ``alpha_bar(T)`` is ~0 (cos(pi/2) ** 2), which would give beta ~1.

    Args:
        T (int): number of diffusion steps (must be > 0).
        s (float): offset added to ``t / T`` inside the cosine.
        max_beta (float): upper clip for every beta; defaults to 0.999, the
            value that was previously hard-coded.

    Returns:
        np.ndarray: float64 array of shape ``(T,)``.
    """
    def f(t):
        return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2

    f0 = f(0)
    alphas_bar = [f(t) / f0 for t in range(T + 1)]

    return np.array(
        [min(1 - alphas_bar[t] / alphas_bar[t - 1], max_beta) for t in range(1, T + 1)]
    )


def generate_linear_schedule(T, low, high):
    """Return ``T`` betas evenly spaced from ``low`` to ``high`` (both inclusive)."""
    schedule = np.linspace(start=low, stop=high, num=T)
    return schedule

三、代码理解

Input:
x: (N, img_channels, *img_size)
y: (N,) 条件标签,可以为 None
Output:
标量损失张量(scalar loss tensor)
Args:
model (nn.Module): 估计高斯噪声的模型
img_size (tuple): (H, W)
img_channels (int): 图像通道数
num_classes: 类别数(GaussianDiffusion 内部只保存,不直接使用)
betas (np.ndarray): diffusion betas 数组,其长度即扩散步数 T
loss_type (string): 损失类型,"l1" 或 "l2"
ema_decay (float): 模型权重指数滑动平均(EMA)的衰减系数
ema_start (int): 开始真正做 EMA 之前的步数;在此之前 EMA 模型直接复制当前模型权重
ema_update_rate (int): 每隔多少步更新一次 EMA

def __init__(self, model, img_size, img_channels, num_classes, betas, loss_type="l2", ema_decay=0.9999, ema_start=5000, ema_update_rate=1):

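初始化时根据 betas 预先计算并注册以下缓冲区($t$ 为时间步):
$$\alpha_t = 1-\beta_t,\qquad \bar\alpha_t=\prod_{s\le t}\alpha_s$$
sqrt_alphas_cumprod $=\sqrt{\bar\alpha_t}$,sqrt_one_minus_alphas_cumprod $=\sqrt{1-\bar\alpha_t}$,reciprocal_sqrt_alphas $=1/\sqrt{\alpha_t}$,remove_noise_coeff $=\beta_t/\sqrt{1-\bar\alpha_t}$,sigma $=\sqrt{\beta_t}$。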
np.cumprod返回数组沿指定轴的累计积。
a = [a1, a2, a3, a4, a5]
np.cumprod(a) = array([a1, a1·a2, a1·a2·a3, a1·a2·a3·a4, a1·a2·a3·a4·a5]),因此 alphas_cumprod[t] 就是 $\bar\alpha_t=\prod_{s\le t}\alpha_s$。

def remove_noise(self, x, t, y, use_ema=True):

# x_{t-1} estimate = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t, y)) / sqrt(alpha_t)
(x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) *extract(self.reciprocal_sqrt_alphas, t, x.shape)

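这个函数用模型预测的噪声 $\epsilon_\theta(x_t,t,y)$ 去掉第 $t$ 步加入的噪声,由 $x_t$ 得到 $x_{t-1}$ 的估计(采样时若 $t>0$,还会再加上 $\sigma_t z$,$z\sim\mathcal N(0,I)$):
$$x_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t,t,y)\right)$$
这个函数里调用了 extract 函数,其功能是:按时间步 t 从一维系数张量中取出对应的系数,并 reshape 成可以与 x 广播相乘的形状。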

def extract(a, t, x_shape):
    """Pick one coefficient per batch element and shape it for broadcasting.

    For a 1-D ``a`` this returns ``a[t[i]]`` for every ``i``, reshaped to
    ``(len(t), 1, ..., 1)`` with ``len(x_shape) - 1`` trailing singleton
    dims, so the result broadcasts against a tensor of shape ``x_shape``.
    """
    batch, *_ = t.shape
    singleton_dims = [1] * (len(x_shape) - 1)
    picked = torch.gather(a, -1, t)
    return picked.reshape(batch, *singleton_dims)

a: Tensor:(1000,)
t: Tensor:(128,)
x_shape: torch.Size([128, 1, 28, 28])
最终返回的是Tensor:(128,1,1,1)
模型定义在初始化函数中,模型调用定义在forward函数中。

def sample(self, batch_size, device, y=None, use_ema=True):

    def sample(self, batch_size, device, y=None, use_ema=True):
        """Generate batch_size images by running the reverse process from pure noise.

        Starts from x ~ N(0, I) of shape (batch_size, img_channels, *img_size)
        and applies remove_noise for t = num_timesteps - 1 down to 0, adding
        sigma_t * z (z ~ N(0, I)) after every step except the last.
        Returns the result on CPU.

        Raises:
            ValueError: if y is given and its length differs from batch_size.
        """
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
        
        # Walk the timesteps backwards: T-1, T-2, ..., 0.
        for t in range(self.num_timesteps - 1, -1, -1):
            # Same timestep for every sample in the batch.
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            # No fresh noise on the final step, so the output is the denoised estimate.
            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
        
        return x.cpu().detach()

def perturb_x(self, x, t, noise):

前向加噪:利用闭式表达式,直接由干净图像 $x_0$ 一步得到任意时间步 $t$ 的带噪图像 $x_t$。

        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise:
        # jump straight from the clean image x to timestep t in one step.
        return (
            extract(self.sqrt_alphas_cumprod, t, x.shape) * x +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )   

$$x_t=\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon,\qquad \epsilon\sim\mathcal N(0,I)$$

def get_losses(self, x, t, y):

计算真实加入的噪声与模型估计的噪声之间的损失(L1 或 L2/MSE)。

        # Fresh standard-normal noise with the same shape as the image batch.
        noise = torch.randn_like(x)

        # Noise the images to timestep t, then ask the model to recover that noise.
        perturbed_x = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        # Compare the estimate with the true noise using L1 or MSE loss.
        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)

        return loss

def forward(self, x, y=None):

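前向函数很简单:先检查输入图像尺寸,再为 batch 中的 b 个样本各随机采样一个时间步 $t\in[0,T)$,然后计算对应的噪声损失。注意:nanodiffusion 原实现中宽度检查写成了 `w != self.img_size[0]`。对非正方形图像,这会误报尺寸合法的输入,也会漏报宽度等于 H 的输入;正确写法应与 `self.img_size[1]` 比较。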

    def forward(self, x, y=None):
        """Return the training loss for a batch of images.

        Validates that x is (b, c, H, W) with (H, W) == self.img_size, draws one
        timestep per image uniformly from [0, num_timesteps) and delegates to
        get_losses.

        Raises:
            ValueError: if the image height or width does not match img_size.
        """
        b, c, h, w = x.shape
        device = x.device

        if h != self.img_size[0]:
            raise ValueError("image height does not match diffusion parameters")
        # img_size is (H, W): the width must be compared against index 1.
        if w != self.img_size[1]:
            raise ValueError("image width does not match diffusion parameters")
        
        t = torch.randint(0, self.num_timesteps, (b,), device=device)
        return self.get_losses(x, t, y)

def generate_cosine_schedule(T, s=0.008)和def generate_linear_schedule(T, low, high):

这两个函数是生成 betas 的两种方法:余弦调度(cosine schedule)和线性调度(linear schedule)。余弦调度得到的 betas 随 t 递增,并被截断在 0.999 以内;线性调度在 low < high 时同样是从小到大排列的。

  • 20
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
引用\[1\]中提到了一个名为respace.py的文件中的一个类,该类继承自GaussianDiffusion,并覆盖了父类的一些函数。这个类实现了一个可以跳过基本扩散过程中的步骤的扩散过程。\[1\] 引用\[2\]中提到了DDPM和IDDPM的学习,以及本次学习的另一种重要的扩散模型。这个模型的采样速度比DDPM快很多,扩散过程不依赖于马尔科夫链。这个模型被称为Denoising diffusion implicit models,是在ICLR 2021上提出的。\[2\] 引用\[3\]中提到了DDPM和DDIM的比较。DDPM的加噪和去噪过程都基于马尔科夫链,导致步数较多。而DDIM的训练过程和DDPM相同,可以重用DDPM的权重和代码。只需要重新编写一个采样的代码,就可以享受到采样步数减少的好处。DDIM的采样过程是确定的。此外,引用\[3\]还提到了其他一些概率模型,如扩散模型和分数模型。采样过程可以是基于郎之万或对逆扩散过程进行建模。\[3\] 根据以上引用内容,diffusion逐行实现的具体细节没有被提及。但可以根据引用\[1\]中的信息推测,diffusion逐行实现可能涉及到对GaussianDiffusion类的函数进行覆盖和修改,以实现跳过基本扩散过程中的步骤。 #### 引用[.reference_title] - *1* *2* *3* [DDIM原理及代码(Denoising diffusion implicit models)](https://blog.csdn.net/weixin_43850253/article/details/128413786)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down28v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值