DDIM模型代码解析(二)

在这一部分中,我们主要讲解Diffusion类的代码。代码在runners/diffusion.py中。

class Diffusion(object):
    def __init__(self, args, config, device=None):
        ...

    def train(self):
        ...

    def sample(self):
        ...

    def sample_fid(self, model):
        ...

    def sample_sequence(self, model):
        ...

    def sample_interpolation(self, model):
        ...

    def sample_image(self, x, model, last=True):
        ...

    def test(self):
        pass

Diffusion类初始化

在初始化函数中除了获取基本的一些信息作为自身的属性保存下来外,还计算了论文中的 \beta_t、\bar{\alpha}_{t-1}、\bar{\alpha}_t\beta_t

在 DDPM 的论文中,选择的方差为 \sigma_t^2=\beta_t、\sigma_t^2 = \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t 两种,在代码中也能找到对应的运算,对应配置文件中的fixedlarge和fixedsmall。

由于神经网络的输出是在正负无穷之间的,所以我们经过exp运算转换到正数区间,也就是方差所在区间上,因此我们将预测方差转换为预测方差的对数。

经过初始化之后我们可以得到如下的属性:

  • args
  • config
  • device
  • model_var_type
  • betas
  • num_timesteps
  • logvar
class Diffusion(object):
    def __init__(self, args, config, device=None):
        self.args = args  # 基本上与设定的命令行传入的参数一致, 会多一些中间得到的有用参数
        self.config = config  # 这里就是对应数据集config中的配置文件, 不是yaml格式了, 是argparse.Namespace格式
        if device is None:  # 如果没有指定device, 则自动选择device
            device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )
        self.device = device

        self.model_var_type = config.model.var_type  # 模型的方差类型选择
        betas = get_beta_schedule(  # 得到t=0~t=T时的\beta
            beta_schedule=config.diffusion.beta_schedule,  # 选择\beta是按照什么规律变化的
            beta_start=config.diffusion.beta_start,  # 在t=0时的\beta
            beta_end=config.diffusion.beta_end,  # 在t=T时的\beta
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,  # 扩散步数
        )
        betas = self.betas = torch.from_numpy(betas).float().to(self.device)
        self.num_timesteps = betas.shape[0]  # 扩散步数

        alphas = 1.0 - betas  # 得到t=0~t=T时的\alpha
        alphas_cumprod = alphas.cumprod(dim=0)  # 得到t=0~t=T时的\bar{\alpha}_s
        alphas_cumprod_prev = torch.cat(  # 得到t=0~t=T时的\bar{\alpha}_{s-1}
            [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
        )
        posterior_variance = (  # 后验方差: DDPM中的\tilde{\beta}_t
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        if self.model_var_type == "fixedlarge":  # 按上界算: DDPM中方差为\beta_t
            self.logvar = betas.log()  # 转为预测方差的对数, (-inf, inf), 经过exp运算回到正数的var
            # torch.cat(
            # [posterior_variance[1:2], betas[1:]], dim=0).log()
        elif self.model_var_type == "fixedsmall":  # 按下界算: DDPM中方差\tilde{\beta}_t
            self.logvar = posterior_variance.clamp(min=1e-20).log()  # 进行截断防止0处为+inf

    ...

其中获取 \large \beta_t 的函数为get_beta_schedule:

代码提供了5中获取beta序列的方法。

  def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    def sigmoid(x):  # 定义一个sigmoid函数
        return 1 / (np.exp(-x) + 1)

    if beta_schedule == "quad":  # 二次方增长
        betas = (
            np.linspace(
                beta_start ** 0.5,
                beta_end ** 0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":  # 线性增长
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "const":  # 常数(t=T时的\beta)
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "sigmoid":  # sigmoid增长
        betas = np.linspace(-6, 6, num_diffusion_timesteps)
        betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)  # 报错
    assert betas.shape == (num_diffusion_timesteps,)
    return betas

Diffusion类的train函数

train函数分为四部分:

  • 准备工作:基本的数据集操作、数据加载器操作、模型构建、优化器选择
  • EMA设定
  • 继续训练设定
  • 开始训练

准备工作

这里都是很常规的一些操作。有关数据集加载、模型结构会在后面的几部分单独拿出来解析。

class Diffusion(object):
    ...    

    def train(self):
        args, config = self.args, self.config
        tb_logger = self.config.tb_logger  # 获取tensorboard的SummaryWriter
        dataset, test_dataset = get_dataset(args, config)  # 获取测试与训练数据集
        train_loader = data.DataLoader(
            dataset,
            batch_size=config.training.batch_size,
            shuffle=True,
            num_workers=config.data.num_workers,
        )
        model = Model(config)  # 根据config实例化U-Net模型

        model = model.to(self.device)  # 将模型送到device上
        model = torch.nn.DataParallel(model)

        optimizer = get_optimizer(self.config, model.parameters())  # 设置模型参数使用的优化器

    ...

最后获取优化器是通过下面的函数实现的,函数在functions/__init__.py中:

def get_optimizer(config, parameters):  # 根据config文件内容, 为模型参数选择优化器
    if config.optim.optimizer == 'Adam':
        return optim.Adam(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay,
                          betas=(config.optim.beta1, 0.999), amsgrad=config.optim.amsgrad,
                          eps=config.optim.eps)
    elif config.optim.optimizer == 'RMSProp':
        return optim.RMSprop(parameters, lr=config.optim.lr, weight_decay=config.optim.weight_decay)
    elif config.optim.optimizer == 'SGD':
        return optim.SGD(parameters, lr=config.optim.lr, momentum=0.9)
    else:
        raise NotImplementedError(
            'Optimizer {} not understood.'.format(config.optim.optimizer))

EMA设定

有关EMA的内容也放在后面的部分中进行解析。

class Diffusion(object):
    ...

    def train(self):
        ...

        if self.config.model.ema:  # 如果模型需要EMA
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)  # 传输模型, 初始化EMA
        else:
            ema_helper = None

        ...

    ...

继续训练设定

主要操作就是把模型参数、优化器参数、上一次的epoch数和step数加载到现在的训练中。

class Diffusion(object):
    ...

    def train(self):
        ...

        start_epoch, step = 0, 0
        if self.args.resume_training:
            states = torch.load(os.path.join(self.args.log_path, "ckpt.pth"))  # 加载训练相关参数
            model.load_state_dict(states[0])  # 将模型参数传入网络中

            states[1]["param_groups"][0]["eps"] = self.config.optim.eps
            optimizer.load_state_dict(states[1])  # 将优化器参数传入优化器中
            start_epoch = states[2]  # 开始epoch数
            step = states[3]  # 开始步数
            if self.config.model.ema:  # 如果设定EMA,还要加在EMA参数
                ema_helper.load_state_dict(states[4])
    
        ...
    ...

开始训练

class Diffusion(object):
    ...

    def train(self):
        ...

        for epoch in range(start_epoch, self.config.training.n_epochs):  # 开始按epoch训练
            data_start = time.time()  # 数据开始时间
            data_time = 0  # 数据时间
            for i, (x, y) in enumerate(train_loader):  # 读取数据
                n = x.size(0)  # 图像个数
                data_time += time.time() - data_start  # 读取数据所用时间
                model.train()  # 设置模型为训练模式
                step += 1  # 步数加一

                x = x.to(self.device)  # 将图像送到device上
                x = data_transform(self.config, x)  # 对图像数据进行变换
                e = torch.randn_like(x)  # 得到与图像形状一致的噪声
                b = self.betas  # 得到\beta参数

                # antithetic sampling 对偶采样
                t = torch.randint(
                    low=0, high=self.num_timesteps, size=(n // 2 + 1,)
                ).to(self.device)
                t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
                loss = loss_registry[config.model.type](model, x, t, e, b)  # 计算出损失大小

                tb_logger.add_scalar("loss", loss, global_step=step)  # 在tensorboard中显示损失大小

                logging.info(  # 显示训练损失log日志
                    f"step: {step}, loss: {loss.item()}, data time: {data_time / (i+1)}"
                )

                optimizer.zero_grad()
                loss.backward()  # 反向传播

                try:  # 梯度裁剪
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.optim.grad_clip
                    )
                except Exception:
                    pass
                optimizer.step()  # 更新参数

                if self.config.model.ema:  # 如果采用EMA
                    ema_helper.update(model)  # 对模型参数进行更新

                if step % self.config.training.snapshot_freq == 0 or step == 1:  # 定期存储模型等相关参数
                    states = [
                        model.state_dict(),
                        optimizer.state_dict(),
                        epoch,
                        step,
                    ]
                    if self.config.model.ema:
                        states.append(ema_helper.state_dict())

                    torch.save(
                        states,
                        os.path.join(self.args.log_path, "ckpt_{}.pth".format(step)),
                    )
                    torch.save(states, os.path.join(self.args.log_path, "ckpt.pth"))

                data_start = time.time()  # 数据开始时间更新

    ...

其中对于损失大小的计算在functions/losses.py中:

def noise_estimation_loss(model,
                          x0: torch.Tensor,  # 原始图像
                          t: torch.LongTensor,  # 时刻t
                          e: torch.Tensor,  # 高斯噪声
                          b: torch.Tensor, keepdim=False):  # \beta参数
    a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()  # 得到加噪后的x图像
    output = model(x, t.float())  # 将加噪后的图像以及时间t送入网络得到输出——预测的噪声
    if keepdim:
        return (e - output).square().sum(dim=(1, 2, 3))  # 计算实际噪声与预测噪声的偏差
    else:
        return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)


loss_registry = {
    'simple': noise_estimation_loss,
}

后续将循环按epoch进行训练。

  • 4
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值