diffusion models 扩散模型公式推导,原理分析与代码(二)

接上一节diffusion models 扩散模型公式推导,原理分析与代码(一)

我们还不知道 p θ ( x t − 1 ∣ x t ) p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right) pθ(xt1xt)是什么形式,扩散模型的第一篇文章给出其同样也服从某个高斯分布,这个好像是从热动力学那里得到证明的,不做深入解释,我们现在要求解的就是其服从的分布的均值和方差是什么,才能够满足将损失函数最小化的要求,原文中给出的 p θ ( x t − 1 ∣ x t ) p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right) pθ(xt1xt)的形式为:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right), \mathbf{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

来看损失函数的第二项 ∑ t = 2 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) \sum_{t=2}^T D_{K L}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right) t=2TDKL(q(xt1xt,x0)pθ(xt1xt)),为了方便,用 L t L_t Lt表示,两个高斯分布计算的KL散度为两个分布均值的L2损失(前面有个系数),这个已经被证明过了,并且很容易推导出来,在这里就不推了,我们将第二项的散度展开之后应该是:
L t = E x 0 , ϵ [ 1 2 Σ θ ( x t , t ) 2 ∥ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∥ 2 ] \begin{aligned} L_t&=\mathbb{E}_{\mathbf{x}_0,{\epsilon}}\left[\frac{1}{2 \mathbf{\Sigma}_\theta\left(\mathbf{x}_t, t\right)^2}\left\|\tilde{\mu}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)-\mu_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] \\ \end{aligned} Lt=Ex0,ϵ[2Σθ(xt,t)21μ~t(xt,x0)μθ(xt,t)2]
对于 μ ~ ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \tilde{\mu}\left(\mathbf{x}_t, \mathbf{x}_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0 μ~(xt,x0)=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0,我们将它表示成只有 x t \mathbf{x}_t xt的形式,根据前向过程推导的 x t = α ˉ t x 0 + 1 − α ˉ t ϵ \mathbf{x}_t=\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol\epsilon xt=αˉt x0+1αˉt ϵ,带入可以得到 μ ~ ( x t , x 0 ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) \tilde{\mu}\left(\mathbf{x}_t, \mathbf{x}_0\right)=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol\epsilon_t\right) μ~(xt,x0)=αt 1(xt1αˉt 1αtϵt),相应地, μ θ ( x t , t ) \mu_\theta\left(\mathbf{x}_t, t\right) μθ(xt,t)可以表示为 μ θ ( x t , t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) \mu_\theta\left(\mathbf{x}_t, t\right)=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right)\right) μθ(xt,t)=αt 1(xt1αˉt 1αtϵθ(xt,t)),其中 ϵ t \boldsymbol\epsilon_t ϵt表示前向过程的 t t t时刻添加的 ϵ ∼ N ( 0 , 1 ) \epsilon \sim \mathcal{N}(0, 1) ϵN(0,1)的具体噪声,也就是实际的采样值。上式变为:
L t = E x 0 , ϵ [ 1 2 Σ θ ( x t , t ) 2 ∥ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∥ 2 ] = E x 0 , ϵ [ 1 2 ∥ Σ θ ∥ 2 2 ∥ 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ t ) − 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) ∥ 2 ] = E x 0 , ϵ [ ( 1 − α t ) 2 2 α t ( 1 − α ˉ t ) ∥ Σ θ ∥ 2 2 ∥ ϵ t − ϵ θ ( x t , t ) ∥ 2 ] \begin{aligned} L_t&=\mathbb{E}_{\mathbf{x}_0,{\boldsymbol\epsilon}}\left[\frac{1}{2 \mathbf{\Sigma}_\theta\left(\mathbf{x}_t, t\right)^2}\left\|\tilde{\mu}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)-\mu_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] \\ & =\mathbb{E}_{\mathbf{x}_0,{\boldsymbol\epsilon}}\left[\frac{1}{2\left\|\boldsymbol{\Sigma}_\theta\right\|_2^2}\left\|\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_t\right)-\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right)\right\|^2\right] \\ & =\mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\frac{\left(1-\alpha_t\right)^2}{2 \alpha_t\left(1-\bar{\alpha}_t\right)\left\|\boldsymbol{\Sigma}_\theta\right\|_2^2}\left\|\boldsymbol{\epsilon}_t-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] \\ \end{aligned} Lt=Ex0,ϵ[2Σθ(xt,t)21μ~t(xt,x0)μθ(xt,t)2]=Ex0,ϵ[2Σθ221 αt 1(xt1αˉt 1αtϵt)αt 1(xt1αˉt 1αtϵθ(xt,t)) 2]=Ex0,ϵ[2αt(1αˉt)Σθ22(1αt)2ϵtϵθ(xt,t)2]

ϵ θ ( x t , t ) \boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right) ϵθ(xt,t)表示要用神经网络预测的值,具体来说, θ \theta θ本身作为网络的参数, ϵ θ ( x t , t ) \boldsymbol\epsilon_\theta\left(\mathbf{x}_t, t\right) ϵθ(xt,t)作为网络的预测值(输出),所以在实际训练时,我们只需要预测在不同的 t t t时刻所添加的噪声,并与真实的噪声 ϵ t \boldsymbol\epsilon_t ϵt计算L2损失,就可以不断地减小 L t L_t Lt,从而达到一开始最大化 log ⁡ p ( x 0 ) \log p(\mathbf{x}_0) logp(x0)的目标。

这里有一个地方, Σ θ \boldsymbol{\Sigma}_\theta Σθ被设置为固定值,所以它可以提出到前面的常数项中,openAI在《Improved Denoising Diffusion Probabilistic Models》的文章中对这一设定进行了修改,将其变成与参数有关的值,因此 L t L_t Lt公式有一些改动,但是本质思想不变,感兴趣的可以自己试验一下。

在2020年的《Denoising Diffusion Probabilistic Models》这篇文章中,作者在实验中发现,对于 L t L_t Lt的优化如果直接省去前面的权重项会更有利于训练:
L t = E x 0 , ϵ [ ∥ ϵ t − ϵ θ ( x t , t ) ∥ 2 ] \begin{aligned} L_t =\mathbb{E}_{\mathbf{x}_0, \epsilon}\left[\left\|\boldsymbol{\epsilon}_t-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] \\ \end{aligned} Lt=Ex0,ϵ[ϵtϵθ(xt,t)2]
接下来,我们终于可以返回到原来的优化目标:
log ⁡ ( q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ) = log ⁡ ( q ( x T ∣ x 0 ) p ( x T ) ) + ∑ t = 2 T log ⁡ ( q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) ) − log ⁡ ( p θ ( x 0 ∣ x 1 ) ) ≡ D K L ( q ( x T ∣ x 0 ) ∥ p ( x T ) ) + ∑ t = 2 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) − log ⁡ ( p θ ( x 0 ∣ x 1 ) ) = constant + E x 0 , ϵ [ ∥ ϵ t − ϵ θ ( x t , t ) ∥ 2 ] − log ⁡ ( p θ ( x 0 ∣ x 1 ) ) \begin{aligned} \log \left(\frac{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{0: T}\right)}\right) &= \log \left(\frac{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}{p\left(\mathbf{x}_T\right)}\right)+\sum_{t=2}^T \log \left(\frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}\right)-\log \left(p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right) \\ & \equiv D_{K L}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)+\sum_{t=2}^T D_{K L}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)-\log (p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right) )\\ &= \text{constant} + \mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\left\|\boldsymbol{\epsilon}_t-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] -\log \left(p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right) \end{aligned} log(pθ(x0:T)q(x1:Tx0))=log(p(xT)q(xTx0))+t=2Tlog(pθ(xt1xt)q(xt1xt,x0))log(pθ(x0x1))DKL(q(xTx0)p(xT))+t=2TDKL(q(xt1xt,x0)pθ(xt1xt))log(pθ(x0x1))=constant+Ex0,ϵ[ϵtϵθ(xt,t)2]log(pθ(x0x1))
还剩下后面这一项 − log ⁡ ( p θ ( x 0 ∣ x 1 ) ) -\log \left(p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right) log(pθ(x0x1)),论文中用了另外的一个神经网络(原文称encoder)来预测 t 0 t_0 t0时刻图像而非噪声(预测 t 0 t_0 t0需要 t 1 t_1 t1的知识):
p θ ( x 0 ∣ x 1 ) = ∏ i = 1 D ∫ δ − ( x 0 i ) δ + ( x 0 i ) N ( x ; μ θ i ( x 1 , 1 ) , σ 1 2 ) d x δ + ( x ) = { ∞  if  x = 1 x + 1 255  if  x < 1 δ − ( x ) = { − ∞  if  x = − 1 x − 1 255  if  x > − 1 \begin{aligned} p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right) & =\prod_{i=1}^D \int_{\delta_{-}\left(x_0^i\right)}^{\delta_{+}\left(x_0^i\right)} \mathcal{N}\left(x ; \mu_\theta^i\left(\mathbf{x}_1, 1\right), \sigma_1^2\right) d x \\ \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)= \begin{cases}-\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1\end{cases} \right. \end{aligned} pθ(x0x1)δ+(x)=i=1Dδ(x0i)δ+(x0i)N(x;μθi(x1,1),σ12)dx={x+2551 if x=1 if x<1δ(x)={x2551 if x=1 if x>1
但是在简化的版本中,作者将上式包括在了前面的 L t L_t Lt中,对应 t = 1 t=1 t=1,所以最终的优化目标为:
L = constant + E x 0 , ϵ [ ∥ ϵ t − ϵ θ ( x t , t ) ∥ 2 ] \mathcal{L}=\text{constant} + \mathbb{E}_{\mathbf{x}_0, \boldsymbol\epsilon}\left[\left\|\boldsymbol{\epsilon}_t-\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] L=constant+Ex0,ϵ[ϵtϵθ(xt,t)2]
也就是说,我们只需要预测每个时刻添加的噪声就可以了。

实现代码

1. 前向扩散过程

# forward process
import torch.nn.functional as F

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)

2. 提前计算 α \alpha α β \beta β等参数

# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

3. 加载一个数据集并测试一下前向过程

# test on the car dataset
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np

IMG_SIZE = 64
BATCH_SIZE = 16

def load_transformed_dataset():
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
    ]
    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.StanfordCars(root=".", download=True,
                                         transform=data_transform)

    test = torchvision.datasets.StanfordCars(root=".", download=True,
                                         transform=data_transform, split='test')
    return torch.utils.data.ConcatDataset([train, test])

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))


data = load_transformed_dataset()
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Simulate forward diffusion 可忽略
image = next(iter(dataloader))[0]

plt.figure(figsize=(15,15))
plt.axis('off')
num_images = 10
stepsize = int(T/num_images)

for idx in range(0, T, stepsize):
    t = torch.Tensor([idx]).type(torch.int64)
    plt.subplot(1, num_images+1, int(idx/stepsize) + 1)
    image, noise = forward_diffusion_sample(image, t)
    show_tensor_image(image)
plt.show()

4. 定义损失函数

# get loss
def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    return F.l1_loss(noise, noise_pred)

5. 采样(预测)阶段,它的预测需要不断地迭代反向过程,所以很消耗计算量

@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
            x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


@torch.no_grad()
def sample_plot_image():
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15, 15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T / num_images)
    # 从 T = 200的时刻开始迭代,直到迭代到 t = 0时刻
    # 但是绘图的时候只绘制几张
    for i in range(0, T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize == 0:
            plt.subplot(1, int(num_images), int(i / stepsize) + 1)
            show_tensor_image(img.detach().cpu())
    plt.show()

from torchvision.utils import save_image
@torch.no_grad()
def save_sampled_image(epoch):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)

    num_images = 10
    stepsize = int(T / num_images)
    # 从 T = 200的时刻开始迭代,直到迭代到 t = 0时刻
    # 但是绘图的时候只绘制几张
    for i in range(0, T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize == 0:
            trans = transforms.Lambda(lambda t: (t + 1) / 2)
            if len(img.shape) == 4:
                image = img[0, :, :, :]
            image = trans(image)
            save_image(image, './results/' + str(epoch) + '_' + str(i) + '.jpg')

6. 训练过程,这里把测试写在上面是因为预测 x 0 x_0 x0本身就需要对反向过程迭代采样

from torch.optim import Adam
    from implementation2.model import SimpleUnet

    model = SimpleUnet()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = Adam(model.parameters(), lr=0.001)
    epochs = 100  # Try more!

    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            optimizer.zero_grad()

            t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
            loss = get_loss(model, batch[0], t)
            loss.backward()
            optimizer.step()

            # 每隔5个epoch测试一下当前的模型
            if epoch % 5 == 0 and step == 0:
                print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
                # sample_plot_image()
                save_sampled_image(epoch)

7. 模型(transformer用来编码时间信息,u-net用来编码图像)

from torch import nn
import math
import torch

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t, ):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """

    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1], \
                                          time_emb_dim) \
                                    for i in range(len(down_channels) - 1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1], \
                                        time_emb_dim, up=True) \
                                  for i in range(len(up_channels) - 1)])

        self.output = nn.Conv2d(up_channels[-1], 3, out_dim)

    def forward(self, x, timestep):
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)

这两篇只是一些最基础的理论,个人感觉了解事情的来龙去脉还是很重要的,了解事物的出发点是什么,后面不管是利用模型也好还是利用模型的思想也好,都有助于在思考问题时更加深刻与灵活。

当年GAN刚兴起的时候,YouTube上有个热门评论称GAN是过去20年来深度学习中最酷的想法,但是后面的研究逐渐发现了GAN存在的很多问题,这项思想逐渐变得不像它刚刚兴起时那样完美,不知道扩散模型后面会不会也走相同的路线😗,但是有一个问题已经开始显现了,那就是对大模型和计算资源的渴求,预测过程的独特性要求对很多时间步骤迭代采样,计算量很大,不知道后面会怎么发展。

完结撒花🍀🍀🍀🍀

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值