AIGC笔记--DDIM的简单实现

1--DDIM介绍

原论文:DENOISING DIFFUSION IMPLICIT MODELS

2--核心代码

# ddim的实现
def compute_alpha(beta, t):
    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0) # beta -> [1, beta]
    # 先通过cumprod计算累乘结果,即: alpha_(t)_hat = alpha_(t) * alpha_(t-1) * ... * alpha_1 * alpha_0
    # 再选取alpha_(t)_hat, 这里用索引t+1来选取
    a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
    return a

# ddim的实现, 参考: https://github.com/ermongroup/ddim/blob/main/functions/denoising.py
def generalized_steps(x, seq, model, b, **kwargs):
    with torch.no_grad():
        n = x.size(0) # batchsize
        seq_next = [-1] + list(seq[:-1]) # t-skip: [-1, 0, 10, 20, ..., 980], len: 100
        x0_preds = []
        xs = [x]
        for i, j in zip(reversed(seq), reversed(seq_next)): # i = t, j = t-skip
            t = (torch.ones(n) * i).to(x.device) # t
            next_t = (torch.ones(n) * j).to(x.device) # t-1
            at = compute_alpha(b, t.long()) # alpha_(t)_hat
            at_next = compute_alpha(b, next_t.long()) # alpha_(t-1)_hat
            xt = xs[-1].to('cuda') # 获取当前时间步的样本,即x_t
            et = model(xt, t) # 预测噪声
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() # 论文公式(12)中的 predicted x0
            x0_preds.append(x0_t.to('cpu')) # 记录当前时间步的 predicted x0
            c1 = (kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()) # 计算公式(12)中的标准差(\sigma)_(t)
            c2 = ((1 - at_next) - c1 ** 2).sqrt() # 论文公式(12)中 direction pointing to xt 的系数
            xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et # 根据公式(12)计算x_(t-1)
            xs.append(xt_next.to('cpu')) # 记录每一个时间步的x_(t-1)

    return xs, x0_preds # 保存了每一个时间步的结果

3--完整代码

DDIM_Demo

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值