DDPM Core Code Walkthrough (1)

All the code has been uploaded to GitHub: duhanyue349/diffusion_model_learned_ddpm_main (source code of a basic diffusion-model framework).

The directory structure can be seen in the repository.

train_cifar.py contains the complete training pipeline for the diffusion model.

If wandb is not installed, set log_to_wandb=False in create_argparser().

1. Loading the model arguments (args)

A create_argparser() function builds the command-line parser:

args = create_argparser().parse_args()
def create_argparser():
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=128,
        iterations=800000,

        log_to_wandb=True,
        log_rate=10,
        checkpoint_rate=10,
        log_dir="~/ddpm_logs",
        project_name='ddpm',
        run_name=run_name,

        model_checkpoint=None,
        optim_checkpoint=None,

        schedule_low=1e-4,
        schedule_high=0.02,

        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser

 

defaults holds the basic training parameters. defaults.update(script_utils.diffusion_defaults()) merges in the model-related parameters; diffusion_defaults() simply returns a dict:
def diffusion_defaults():
    defaults = dict(
        num_timesteps=1000,
        schedule="linear",
        loss_type="l2",
        use_labels=False,

        base_channels=128,
        channel_mults=(1, 2, 2, 2),
        num_res_blocks=2,
        time_emb_dim=128 * 4,
        norm="gn",
        dropout=0.1,
        activation="silu",
        attention_resolutions=(1,),
        schedule_low=1e-4,
        schedule_high=0.02,

        ema_decay=0.9999,
        ema_update_rate=1,
    )

    return defaults
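After the update, defaults contains both groups of keys in one flat dict, and parse_args() later turns them into a single argparse.Namespace. A toy illustration of the merge (values are examples only):

defaults = dict(learning_rate=2e-4, batch_size=128)
defaults.update(dict(num_timesteps=1000, schedule="linear"))
print(sorted(defaults))  # ['batch_size', 'learning_rate', 'num_timesteps', 'schedule']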
The parser is then instantiated:
parser = argparse.ArgumentParser()
script_utils.add_dict_to_argparser(parser, defaults) registers the arguments on the parser:
# This function takes the arguments as a dict and registers one --key flag per key-value pair
def add_dict_to_argparser(parser, default_dict):
    """
    https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py
    """
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)
The basic steps for building a parser are:
# parser = argparse.ArgumentParser()  creates the command-line parser
# parser.add_argument()               registers the arguments
# args = parser.parse_args()          parses the command line
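Since every default is registered as a --key flag, any of them can be overridden at launch time, for example (hypothetical invocation):

python train_cifar.py --batch_size 64 --schedule cosine --log_to_wandb False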

2. Building the diffusion model

diffusion = script_utils.get_diffusion_from_args(args).to(device)

This calls the get_diffusion_from_args() function to build the model; its input is the args namespace parsed above.

def get_diffusion_from_args(args):
    activations = {
        "relu": F.relu,
        "mish": F.mish,
        "silu": F.silu,
    }

    model = UNet(
        img_channels=3,

        base_channels=args.base_channels,
        channel_mults=args.channel_mults,
        time_emb_dim=args.time_emb_dim,
        norm=args.norm,
        dropout=args.dropout,
        activation=activations[args.activation],
        attention_resolutions=args.attention_resolutions,

        num_classes=None if not args.use_labels else 10,
        initial_pad=0,
    )

    if args.schedule == "cosine":
        betas = generate_cosine_schedule(args.num_timesteps)
    else:
        betas = generate_linear_schedule(
            args.num_timesteps,
            args.schedule_low * 1000 / args.num_timesteps,
            args.schedule_high * 1000 / args.num_timesteps,
        )

    diffusion = GaussianDiffusion(
        model, (32, 32), 3, 10,
        betas,
        ema_decay=args.ema_decay,
        ema_update_rate=args.ema_update_rate,
        ema_start=2000,
        loss_type=args.loss_type,
    )

    return diffusion

It returns a GaussianDiffusion instance, to which model (the UNet), betas, loss_type and so on are passed.

betas can be defined in two ways, cosine or linear; note that the linear endpoints are scaled by 1000 / num_timesteps, so that the overall noise level stays comparable when num_timesteps differs from the standard 1000:

betas = generate_cosine_schedule(args.num_timesteps)
betas = generate_linear_schedule(
            args.num_timesteps,
            args.schedule_low * 1000 / args.num_timesteps,
            args.schedule_high * 1000 / args.num_timesteps,
        )
The GaussianDiffusion class is defined in diffusion.py:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from copy import deepcopy

from .ema import EMA
from .utils import extract

class GaussianDiffusion(nn.Module):
    __doc__ = r"""Gaussian Diffusion model. Forwarding through the module returns diffusion reversal scalar loss tensor.

    Input:
        x: tensor of shape (N, img_channels, *img_size)
        y: tensor of shape (N)
    Output:
        scalar loss tensor
    Args:
        model (nn.Module): model which estimates diffusion noise
        img_size (tuple): image size tuple (H, W)
        img_channels (int): number of image channels
        betas (np.ndarray): numpy array of diffusion betas
        loss_type (string): loss type, "l1" or "l2"
        ema_decay (float): model weights exponential moving average decay
        ema_start (int): number of steps before EMA
        ema_update_rate (int): number of steps before each EMA update
    """
    def __init__(
        self,
        model,
        img_size,
        img_channels,
        num_classes,
        betas,
        loss_type="l2",
        ema_decay=0.9999,
        ema_start=5000,
        ema_update_rate=1,
    ):
        super().__init__()

        self.model = model
        self.ema_model = deepcopy(model)

        self.ema = EMA(ema_decay)
        self.ema_decay = ema_decay
        self.ema_start = ema_start
        self.ema_update_rate = ema_update_rate
        self.step = 0

        self.img_size = img_size
        self.img_channels = img_channels
        self.num_classes = num_classes

        if loss_type not in ["l1", "l2"]:
            raise ValueError("__init__() got unknown loss type")

        self.loss_type = loss_type
        self.num_timesteps = len(betas)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas)          # alpha_bar_t = prod_{s<=t} alpha_s

        to_torch = partial(torch.tensor, dtype=torch.float32)

        # Every schedule-derived constant is registered as a buffer: it moves with
        # .to(device) and is saved in state_dict, but is not a trainable parameter.
        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas", to_torch(alphas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))

        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))                # sqrt(alpha_bar_t)
        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod)))  # sqrt(1 - alpha_bar_t)
        self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas)))                 # 1 / sqrt(alpha_t)

        self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod)))     # beta_t / sqrt(1 - alpha_bar_t)
        self.register_buffer("sigma", to_torch(np.sqrt(betas)))                                       # reverse-step noise std

    def update_ema(self):
        self.step += 1
        if self.step % self.ema_update_rate == 0:
            if self.step < self.ema_start:
                self.ema_model.load_state_dict(self.model.state_dict())
            else:
                self.ema.update_model_average(self.ema_model, self.model)

    @torch.no_grad()
    def remove_noise(self, x, t, y, use_ema=True):
        # One reverse-step mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) / sqrt(alpha_t)
        if use_ema:
            return (
                (x - extract(self.remove_noise_coeff, t, x.shape) * self.ema_model(x, t, y)) *
                extract(self.reciprocal_sqrt_alphas, t, x.shape)
            )
        else:
            return (
                (x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) *
                extract(self.reciprocal_sqrt_alphas, t, x.shape)
            )

    @torch.no_grad()
    def sample(self, batch_size, device, y=None, use_ema=True):
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
        
        for t in range(self.num_timesteps - 1, -1, -1):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
        
        return x.cpu().detach()

    @torch.no_grad()
    def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
        diffusion_sequence = [x.cpu().detach()]
        
        for t in range(self.num_timesteps - 1, -1, -1):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
            
            diffusion_sequence.append(x.cpu().detach())
        
        return diffusion_sequence

    def perturb_x(self, x, t, noise):
        # Closed-form forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        return (
            extract(self.sqrt_alphas_cumprod, t, x.shape) * x +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )   

    def get_losses(self, x, t, y):
        # Epsilon-prediction objective: regress the model output against the injected noise
        noise = torch.randn_like(x)

        perturbed_x = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)

        return loss

    def forward(self, x, y=None):
        b, c, h, w = x.shape
        device = x.device

        if h != self.img_size[0]:
            raise ValueError("image height does not match diffusion parameters")
        if w != self.img_size[1]:
            raise ValueError("image width does not match diffusion parameters")
        
        t = torch.randint(0, self.num_timesteps, (b,), device=device)  # one random timestep per image
        return self.get_losses(x, t, y)


def generate_cosine_schedule(T, s=0.008):
    def f(t, T):
        return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2
    
    alphas = []
    f0 = f(0, T)

    for t in range(T + 1):
        alphas.append(f(t, T) / f0)
    
    betas = []

    for t in range(1, T + 1):
        betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999))
    
    return np.array(betas)


def generate_linear_schedule(T, low, high):
    return np.linspace(low, high, T)
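A quick sanity check of the two schedules (a standalone sketch, assuming the two generator functions above are importable):

import numpy as np

T = 1000
lin = generate_linear_schedule(T, 1e-4 * 1000 / T, 0.02 * 1000 / T)
cos = generate_cosine_schedule(T)

print(lin.shape, lin[0], lin[-1])  # (1000,) 0.0001 0.02
print(cos.shape, cos.max())        # (1000,) <= 0.999 because of the clipping
print(np.cumprod(1 - lin)[-1])     # alpha_bar_T is ~4e-5, so x_T is close to pure noise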

3. Optimizer and datasets

optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)
# The original code's `if args.model_checkpoint is not None:` and `if args.log_to_wandb:` blocks are omitted here

batch_size = args.batch_size

train_dataset = datasets.CIFAR10(
    root='./cifar_train',
    train=True,
    download=True,
    transform=script_utils.get_transform(),
)

test_dataset = datasets.CIFAR10(
    root='./cifar_test',
    train=False,
    download=True,
    transform=script_utils.get_transform(),
)

train_loader = script_utils.cycle(DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
))
test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)

A cycle() helper is used here so the data can be loaded in an endless loop; later it is consumed with next(): x, y = next(train_loader).

from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as T

class RescaleChannels:
    def __call__(self, sample):
        return 2 * sample - 1

'''`return 2 * sample - 1` applies a linear transform to the input sample:
each pixel value is multiplied by 2 and shifted by -1, mapping [0, 1] to [-1, 1].
Why normalize like this? In neural-network training it brings several benefits:
- numerical stability: some activations (e.g. tanh) behave better with inputs near [-1, 1];
- faster convergence: normalized data reduces vanishing/exploding gradients;
- standardization: data from different sources ends up on a comparable scale.
So for an image whose pixels lie in [0, 1], RescaleChannels maps them to [-1, 1],
which is usually the range a diffusion model is trained in.'''
def get_transform():
    return T.Compose([
        T.ToTensor(),
        RescaleChannels(),
    ])
train_dataset = datasets.CIFAR10(
    root='./cifar_train',
    train=True,
    download=True,
    transform=get_transform(),
)
def cycle(dl):
    """
    https://github.com/lucidrains/denoising-diffusion-pytorch/
    """
    while True:
        for data in dl:
            yield data

# cycle is an infinite generator: it lets the data loader `dl` be iterated endlessly.
# Whenever `dl` is fully traversed, iteration simply starts over, producing an
# unbounded stream of batches. This lets a fixed-iteration training loop run as if
# the dataset were infinite, even when the dataset itself is small.
train_loader = cycle(DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
))
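Both helpers are easy to verify in isolation; a toy sketch (not part of the repo):

import torch
from torch.utils.data import DataLoader, TensorDataset

# RescaleChannels maps [0, 1] -> [-1, 1]
print(RescaleChannels()(torch.tensor([0.0, 0.5, 1.0])))  # tensor([-1., 0., 1.])

# cycle() restarts the loader whenever it is exhausted
toy = TensorDataset(torch.arange(4.0), torch.zeros(4))
loader = cycle(DataLoader(toy, batch_size=2))
for _ in range(5):           # more draws than one epoch (2 batches)
    x, y = next(loader)      # never raises StopIteration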

4. The training loop

acc_train_loss = 0

for iteration in range(1, args.iterations + 1):
    diffusion.train()

    x, y = next(train_loader)
    x = x.to(device)
    y = y.to(device)

    if args.use_labels:
        loss = diffusion(x, y)
    else:
        loss = diffusion(x)

    acc_train_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    diffusion.update_ema()
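diffusion.update_ema() keeps a smoothed copy of the weights that is used for sampling. The EMA class imported in diffusion.py is not shown in this post; a minimal sketch of the usual implementation (as in lucidrains' denoising-diffusion-pytorch, which this code follows) is:

class EMA:
    def __init__(self, decay):
        self.decay = decay

    def update_average(self, old, new):
        if old is None:
            return new
        # exponential moving average: keep `decay` of the old weights
        return old * self.decay + (1 - self.decay) * new

    def update_model_average(self, ema_model, current_model):
        # shift every EMA parameter a small step toward the current parameters
        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
            old, new = ema_params.data, current_params.data
            ema_params.data = self.update_average(old, new)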

5. Evaluation and checkpointing

    if iteration % args.log_rate == 0:
        test_loss = 0
        with torch.no_grad():
            diffusion.eval()
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                if args.use_labels:
                    loss = diffusion(x, y)
                else:
                    loss = diffusion(x)

                test_loss += loss.item()

        if args.use_labels:
            samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
        else:
            samples = diffusion.sample(10, device)

        # Convert the sampled tensors back into displayable images:
        #   (samples + 1) / 2    maps [-1, 1] back to [0, 1]
        #   .clip(0, 1)          guards against float error pushing values out of range
        #   .permute(0, 2, 3, 1) reorders (N, C, H, W) -> (N, H, W, C) for display/saving
        samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()

        test_loss /= len(test_loader)
        acc_train_loss /= args.log_rate

        wandb.log({
            "test_loss": test_loss,
            "train_loss": acc_train_loss,
            "samples": [wandb.Image(sample) for sample in samples],
        })

        acc_train_loss = 0

    if iteration % args.checkpoint_rate == 0:
        model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
        optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"

        # Derive the directory paths
        log_dir1 = os.path.dirname(model_filename)
        log_dir2 = os.path.dirname(optim_filename)

        # Create the directories if they do not exist
        os.makedirs(log_dir1, exist_ok=True)
        os.makedirs(log_dir2, exist_ok=True)

        # Save the model and optimizer state under the full paths
        torch.save(diffusion.state_dict(), model_filename)
        torch.save(optimizer.state_dict(), optim_filename)
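The saved .pth files can be reused in two ways: resume training through the --model_checkpoint / --optim_checkpoint flags (see the full listing in section 6), or load the diffusion model for standalone sampling. A hypothetical inference snippet (the checkpoint path is illustrative):

diffusion.load_state_dict(torch.load("ddpm_logs/model.pth"))  # illustrative path
samples = diffusion.sample(16, device)  # (16, 3, 32, 32), values roughly in [-1, 1]
images = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()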

6. The complete train_cifar.py source

 

import argparse
import datetime
import torch
import wandb
from torch.utils.data import DataLoader
from torchvision import datasets
from ddpm import script_utils
import os

os.environ["WANDB_API_KEY"] = "b9171ddb0a1638d8cca0425e41c8a9d789281515"
os.environ["WANDB_MODE"] = "online"

wandb.login(key="b9171ddb0a1638d8cca0425e41c8a9d789281515")
def main():
    args = create_argparser().parse_args()
    device = args.device

    try:
        diffusion = script_utils.get_diffusion_from_args(args).to(device)
        optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)

        if args.model_checkpoint is not None:
            diffusion.load_state_dict(torch.load(args.model_checkpoint))
        if args.optim_checkpoint is not None:
            optimizer.load_state_dict(torch.load(args.optim_checkpoint))

        if args.log_to_wandb:
            if args.project_name is None:
                raise ValueError("args.log_to_wandb set to True but args.project_name is None")

            run = wandb.init(
                project=args.project_name,
                config=vars(args),
                name=args.run_name,
            )
            wandb.watch(diffusion)

        batch_size = args.batch_size

        train_dataset = datasets.CIFAR10(
            root='./cifar_train',
            train=True,
            download=True,
            transform=script_utils.get_transform(),
        )

        test_dataset = datasets.CIFAR10(
            root='./cifar_test',
            train=False,
            download=True,
            transform=script_utils.get_transform(),
        )

        train_loader = script_utils.cycle(DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=2,
        ))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)

        acc_train_loss = 0

        for iteration in range(1, args.iterations + 1):
            diffusion.train()

            x, y = next(train_loader)
            x = x.to(device)
            y = y.to(device)

            if args.use_labels:
                loss = diffusion(x, y)
            else:
                loss = diffusion(x)

            acc_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            diffusion.update_ema()

            if iteration % args.log_rate == 0:
                test_loss = 0
                with torch.no_grad():
                    diffusion.eval()
                    for x, y in test_loader:
                        x = x.to(device)
                        y = y.to(device)

                        if args.use_labels:
                            loss = diffusion(x, y)
                        else:
                            loss = diffusion(x)

                        test_loss += loss.item()

                if args.use_labels:
                    samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
                else:
                    samples = diffusion.sample(10, device)
                # Convert the sampled tensors back into displayable images:
                #   (samples + 1) / 2    maps [-1, 1] back to [0, 1]
                #   .clip(0, 1)          guards against float error pushing values out of range
                #   .permute(0, 2, 3, 1) reorders (N, C, H, W) -> (N, H, W, C) for display/saving
                samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()

                test_loss /= len(test_loader)
                acc_train_loss /= args.log_rate

                wandb.log({
                    "test_loss": test_loss,
                    "train_loss": acc_train_loss,
                    "samples": [wandb.Image(sample) for sample in samples],
                })

                acc_train_loss = 0

            if iteration % args.checkpoint_rate == 0:
                model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
                optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"

                # Derive the directory paths
                log_dir1 = os.path.dirname(model_filename)
                log_dir2 = os.path.dirname(optim_filename)

                # Create the directories if they do not exist
                os.makedirs(log_dir1, exist_ok=True)
                os.makedirs(log_dir2, exist_ok=True)

                # Save the model and optimizer state under the full paths
                torch.save(diffusion.state_dict(), model_filename)
                torch.save(optimizer.state_dict(), optim_filename)

        if args.log_to_wandb:
            run.finish()
    except KeyboardInterrupt:
        if args.log_to_wandb:
            run.finish()
        print("Keyboard interrupt, run finished early")


def create_argparser():
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=128,
        iterations=800000,

        log_to_wandb=True,
        log_rate=10,
        checkpoint_rate=10,
        log_dir="~/ddpm_logs",
        project_name='ddpm',
        run_name=run_name,

        model_checkpoint=None,
        optim_checkpoint=None,

        schedule_low=1e-4,
        schedule_high=0.02,

        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()
