diffusion 之 cifar/mnist 数据集


代码出处:https://github.com/abarankab/DDPM

wandb的问题解决方法:

step1: 按照这个https://blog.csdn.net/weixin_43164054/article/details/124156206一步步走 step2: 修改project_name=“cifar”,然后执行python train_cifar.py 若出现报错"wandb: ERROR It appears that you do not have permission to access the requested resource.",参看这个https://blog.csdn.net/weixin_43835996/article/details/126955917

cifar10数据集

配置好wandb,按照github上的源代码
将DDPM/scripts/train_mnist.py中的entity='treaptofun'去掉

 run = wandb.init(
                project=args.project_name,
                
                config=vars(args),
                name=args.run_name,
            )
            # entity='treaptofun',

然后就可以正常进行训练了

mnist数据集

对于mnist数据集需要修改如下两个文件

ddpm/script_utils.py

line 90:img_channel=1,因为cifar图片为3通道,而mnist图片为1通道
line 101: initial_pad=2, 是因为cifar数据集的图片大小为32,为2的指数倍,降采样过程中除以2的话一直能整除;而mnist的图片大小为28,所以要padding为32,即设置initial_pad=2
line 120:cifar10 的图片大小为3232, mnist的图片大小为2828,

import argparse
import torchvision
import torch.nn.functional as F

from .unet import UNet
from .diffusion import (
    GaussianDiffusion,
    generate_linear_schedule,
    generate_cosine_schedule,
)


def cycle(dl):
    """
    https://github.com/lucidrains/denoising-diffusion-pytorch/
    """
    while True:
        for data in dl:
            yield data

def get_transform():
    class RescaleChannels(object):
        def __call__(self, sample):
            return 2 * sample - 1

    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        RescaleChannels(),
    ])


def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")


def add_dict_to_argparser(parser, default_dict):
    """
    https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py
    """
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)


def diffusion_defaults():
    defaults = dict(
        num_timesteps=1000,
        schedule="linear",
        loss_type="l2",
        use_labels=False,

        base_channels=128,
        channel_mults=(1, 2, 2, 2),
        num_res_blocks=2,
        time_emb_dim=128 * 4,
        norm="gn",
        dropout=0.1,
        activation="silu",
        attention_resolutions=(1,),

        ema_decay=0.9999,
        ema_update_rate=1,
    )

    return defaults


def get_diffusion_from_args(args):
    activations = {
        "relu": F.relu,
        "mish": F.mish,
        "silu": F.silu,
    }
    # base_channels=128
    model = UNet(
        img_channels=1,

        base_channels=args.base_channels,
        channel_mults=args.channel_mults,
        time_emb_dim=args.time_emb_dim,
        norm=args.norm,
        dropout=args.dropout,
        activation=activations[args.activation],
        attention_resolutions=args.attention_resolutions,

        num_classes=None if not args.use_labels else 10,
        initial_pad=2,
    )
    # line102  在cifar中为initial_pad=0,  

    if args.schedule == "cosine":
        betas = generate_cosine_schedule(args.num_timesteps)
    else:
        betas = generate_linear_schedule(
            args.num_timesteps,
            args.schedule_low * 1000 / args.num_timesteps,
            args.schedule_high * 1000 / args.num_timesteps,
        )

    # 本py文件共修改了3处:line 90 ; line 101 ;line 120.
    # model, (32, 32), 3, 10,    
    # cifar10 的图片大小为32*32,3channel, mnist的图片大小为28*28,1channel
    
    
    diffusion = GaussianDiffusion(
        model, (28, 28), 1, 10,
        betas,
        ema_decay=args.ema_decay,
        ema_update_rate=args.ema_update_rate,
        ema_start=2000,
        loss_type=args.loss_type,
    )

    return diffusion

scripts/train_mnist.py

把entity=‘treaptofun’,给去掉

import argparse
import datetime
import torch
import wandb

from torch.utils.data import DataLoader
from torchvision import datasets
from ddpm import script_utils


def main():
    args = create_argparser().parse_args()
    device = args.device

    try:
        diffusion = script_utils.get_diffusion_from_args(args).to(device)
        optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)


        # 接着上次中断保存的参数继续训练
        if args.model_checkpoint is not None:
            diffusion.load_state_dict(torch.load(args.model_checkpoint))
        if args.optim_checkpoint is not None:
            optimizer.load_state_dict(torch.load(args.optim_checkpoint))

        if args.log_to_wandb:
            if args.project_name is None:
                raise ValueError("args.log_to_wandb set to True but args.project_name is None")

            # wandb.init(project="ddpm_cifar")

            run = wandb.init(
                project=args.project_name,
                
                config=vars(args),
                name=args.run_name,
            )
            # entity='treaptofun',

            wandb.watch(diffusion)

        batch_size = args.batch_size

        train_dataset = datasets.MNIST(
            root='../dataset/mnist/mnist_train',
            train=True,
            download=True,
            transform=script_utils.get_transform(),
        )

        test_dataset = datasets.MNIST(
            root='../dataset/mnist/mnist_test',
            train=False,
            download=True,
            transform=script_utils.get_transform(),
        )

        train_loader = script_utils.cycle(DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=2,
        ))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)
        
        acc_train_loss = 0

        for iteration in range(1, args.iterations + 1):
            diffusion.train()

            x, y = next(train_loader)
            x = x.to(device)
            y = y.to(device)

            if args.use_labels:
                loss = diffusion(x, y)
            else:
                loss = diffusion(x)

            acc_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            diffusion.update_ema()
            
            if iteration % args.log_rate == 0:
                test_loss = 0
                with torch.no_grad():
                    diffusion.eval()
                    for x, y in test_loader:
                        x = x.to(device)
                        y = y.to(device)

                        if args.use_labels:
                            loss = diffusion(x, y)
                        else:
                            loss = diffusion(x)

                        test_loss += loss.item()
                
                if args.use_labels:
                    samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
                else:
                    samples = diffusion.sample(10, device)
                
                samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()

                test_loss /= len(test_loader)
                acc_train_loss /= args.log_rate

                wandb.log({
                    "test_loss": test_loss,
                    "train_loss": acc_train_loss,
                    "samples": [wandb.Image(sample) for sample in samples],
                })

                acc_train_loss = 0
            
            if iteration % args.checkpoint_rate == 0:
                model_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-model.pth"
                optim_filename = f"{args.log_dir}/{args.project_name}-{args.run_name}-iteration-{iteration}-optim.pth"

                torch.save(diffusion.state_dict(), model_filename)
                torch.save(optimizer.state_dict(), optim_filename)
        
        if args.log_to_wandb:
            run.finish()
    except KeyboardInterrupt:
        if args.log_to_wandb:
            run.finish()
        print("Keyboard interrupt, run finished early")


def create_argparser():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=128,
        iterations=80000,

        log_to_wandb=True,
        log_rate=1000,
        checkpoint_rate=1000,
        log_dir="./ddpm_logs_mnist",
        project_name="mnist",
        run_name=run_name,

        model_checkpoint=None,
        optim_checkpoint=None,

        schedule_low=1e-4,
        schedule_high=0.02,

        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()

命令行执行的训练命令:
python train.py
命令行执行的采样命令
python sample_images.py --model_path "your model path" --save_dir "your save img path" --schedule cosine

展示采样结果

import matplotlib.pyplot as plt
import numpy as np
import os

def show(num_imgs, dir_path):
    ''' 
    num_imgs: 要展示的图片的张数
    dir_path:图片的路径
    '''
    img_names=os.listdir (dir_path)
    img_names.sort(key=lambda x:int(x.split('.')[0]))

    plt.figure(figsize=(20,5)) # 画布大小
    N=2
    M=10
    #形成NxM大小的画布
    for i in range(num_imgs):#有张图片
        path = dir_path + img_names[i]
        img = plt.imread(path)
        plt.subplot(N,M,i+1)#表示第i张图片,下标只能从1开始,不能从0,
        plt.imshow(img)
        plt.title(img_names[i],color='black')
        #下面两行是消除每张图片自己单独的横纵坐标,不然每张图片会有单独的横纵坐标,影响美观
        plt.xticks([])
        plt.yticks([])
    plt.show()

print("mnist generation results:")
show(20, './scripts/save_dir_mnist/')  # 模型训练出来的保存的结果

这里的名字只是预测出来的图片的序号,并不是预测的label!
在这里插入图片描述

无label的训练和采样过程

训练过程:

def get_losses(self, x, t, y):
    noise = torch.randn_like(x)

    perturbed_x = self.perturb_x(x, t, noise)
    estimated_noise = self.model(perturbed_x, t, y)    # 输入到Model的是加噪后的图片
    # 这个model预测出来的噪声是每个像素点位置上的噪声!!!
    # 因为这个model的output的形状和x是一样的,[batch, img_channel, h, w]

    if self.loss_type == "l1":
        loss = F.l1_loss(estimated_noise, noise)
    elif self.loss_type == "l2":
        loss = F.mse_loss(estimated_noise, noise)

    return loss
  • x: (batch_size, img_channel, h, w)
  • t: (batch_size, )
    在区间[0, num_timesteps]里面随机生成b个时间,扩散过程并不是逐步进行的,t 是一个大小为batch的张量

说明一下:这个t是怎么加入到图片x中的
t最开始为(batch_size,)形状的张量,它经过linear,变成了(batch_size, img_channel),然后经过扩维,变成(batch_size, img_channel, 1, 1),经过广播机制就可以加入 x: (batch_size, img_channel, h, w)里面,即x的同一个channel上的所有像素点加入的t的值是一样的。

  • perturb_x: 根据公式 x t = α t ˉ . x 0 + 1 − α t ˉ . z x_t = \sqrt{\bar{\alpha_t}}.x_0 + \sqrt{1 - \bar{\alpha_t}}.z xt=αtˉ .x0+1αtˉ .z x 0 x_0 x0进行加噪,perturbed_x的形状为(batch_size, img_channel, h, w)
  • model(perturbed_x, t, y) : 输入加噪后的图片,和对应的时间t,model预测出来的是加入的噪声,通过对perturbed_x进行卷积,激活,降采样,上采样一通操作,最终model输出的形状仍为(batch_size, img_channel, h, w),model的output就是预测加入的噪声。那这里预测的噪声就是预测出来的是加入到每个像素点位置上的噪声!
  • 用l1或l2损失函数来计算损失。

采样过程

    @torch.no_grad()
    def sample(self, batch_size, device, y=None, use_ema=True):
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
        
        for t in range(self.num_timesteps - 1, -1, -1): # 从T=[t-1]到T=[0]
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
        
        return x.cpu().detach()

x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z \mathbf{x}_{t-1}=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)+\sigma_{t} \mathbf{z} xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtz

  • x: 随机生成噪声作为初始值,batch_size就是你想生成的图片的张数,比如你想产生1k张图片
  • t_batch: 就是说对x的去噪是批处理进行的,我们的目的是 x T , x T − 1 , x T − 2 . . . x 1 , x 0 x_T, x_{T-1},x_{T-2}...x_1,x_0 xT,xT1,xT2...x1,x0, 因为x是有batch_size个,t_batch就是让这batch_size张图片同时去噪
  • remove_noise: 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) \frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \boldsymbol{\epsilon}_{\theta}\left(\mathbf{x}_{t}, t\right)\right) αt 1(xt1αˉt 1αtϵθ(xt,t))
  • if t>0: 就是加上一个随机噪声 σ t z \sigma_{t} \mathbf{z} σtz,为什么在采样的过程中还要加上一个随机噪声呢?为了模拟布朗运动的随机性,当t=0时,说明已经到了 x 0 x_0 x0了,即最后一步得到原图了,对于原图就不需要再加噪声了!

有条件的训练和采样

训练

有条件的训练过程就是把标签y 也加入到图片中去进行训练

self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

out += self.class_bias(y)[:, :, None, None]

y:就是标签,通过nn.Embedding可以把y表示成[batch_size, out_channels],再通过[:, :, None, None]来进行扩维,将y变为[batch_size, out_channels, 1, 1], 然后加入到经过各种操作处理后的x中(也即加入到out中)

这里的对y的操作很像对时间t的操作

采样

        if args.use_labels:
            for label in range(10):
                # 这个就是假设每一类的数量都是一样的,所以在生成标签的时候,每一类的标签y的数量是一样的
                # 比如我们想生成1k个图片,label一共有10种,所以每一类有100张
                y = torch.ones(args.num_images // 10, dtype=torch.long, device=device) * label
                samples = diffusion.sample(args.num_images // 10, device, y=y)

                for image_id in range(len(samples)):
                    image = ((samples[image_id] + 1) / 2).clip(0, 1)
                    torchvision.utils.save_image(image, f"{args.save_dir}/{label}-{image_id}.png")

采样的过程就是去噪的过程,这个去除的噪声的大小就是用我们训练好的模型预测出来的噪声,因为对于有条件的生成,我们在训练的过程中是加入了label的,所以在生成的时候我们也可以加入label,来指定噪声图片一步步去噪得到 x 0 x_0 x0,那这个 x 0 x_0 x0就更有可能属于指定的label的类别。

有条件生成和无条件生的对比

假设:原始的训练数据集中有猫,狗,猪,三类,这三类的占比分别为0.2. 0.3 0.5

  • 有条件生成:
    我们可以指定生成哪一类,比如生成1k张图片,我们指定label=猫,那生成的1k张图片大约999+张都是猫

  • 无条件生成:
    不能指定生成哪一类,比如生成1k张图片,这1k张图片大约有200张是猫,300张是狗,500张是猪

  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值