Diffusion Models: Formula Derivation, Principle Analysis, and Code (Part 2)

Continued from the previous part, Diffusion Models: Formula Derivation, Principle Analysis, and Code (Part 1).

We do not yet know the form of $p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)$. The first diffusion-model paper models it as a Gaussian as well. The justification comes from thermodynamics and the theory of diffusion processes: when each forward step adds only a small amount of noise, the reverse step has the same functional form as the forward step. We will not go into it here. What we need to determine is the mean and variance of this Gaussian that minimize the loss. The original paper writes $p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)$ as:

$$p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)=\mathcal{N}\left(\mathbf{x}_{t-1};\boldsymbol{\mu}_\theta(\mathbf{x}_t,t),\boldsymbol{\Sigma}_\theta(\mathbf{x}_t,t)\right)$$

Now look at the second term of the loss, $\sum_{t=2}^T D_{KL}\left(q(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)\right)$. For convenience, write each summand as $L_t$.

Assume, as DDPM does, that the variance is fixed and isotropic: $\boldsymbol{\Sigma}_\theta(\mathbf{x}_t,t)=\sigma_t^2\mathbf{I}$. Under that assumption, the KL divergence between the two Gaussians reduces to the squared L2 distance between their means, scaled by $\frac{1}{2\sigma_t^2}$, plus a constant $C$ that does not depend on $\theta$. This is a standard result and easy to derive, so we skip the derivation here. Expanding the second term gives:

$$L_t=\mathbb{E}_{\mathbf{x}_0,\boldsymbol\epsilon}\left[\frac{1}{2\sigma_t^2}\left\|\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t,\mathbf{x}_0)-\boldsymbol{\mu}_\theta(\mathbf{x}_t,t)\right\|^2\right]+C$$

The posterior mean is $\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t,\mathbf{x}_0)=\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\mathbf{x}_t+\frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\mathbf{x}_0$, and we want to write it without $\mathbf{x}_0$. From the forward process we have $\mathbf{x}_t=\sqrt{\bar\alpha_t}\,\mathbf{x}_0+\sqrt{1-\bar\alpha_t}\,\boldsymbol\epsilon_t$, that is, $\mathbf{x}_0=\frac{1}{\sqrt{\bar\alpha_t}}\left(\mathbf{x}_t-\sqrt{1-\bar\alpha_t}\,\boldsymbol\epsilon_t\right)$. Substituting this and using $\beta_t=1-\alpha_t$ gives

$$\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t,\mathbf{x}_0)=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\boldsymbol\epsilon_t\right)$$

Here $\boldsymbol\epsilon_t$ is the concrete sample of $\boldsymbol\epsilon\sim\mathcal{N}(\mathbf{0},\mathbf{I})$ that the forward process actually added at step $t$. Since $\mathbf{x}_t$ is available as the network's input, we can parameterize $\boldsymbol{\mu}_\theta$ in exactly the same form and let the network predict only the noise:

$$\boldsymbol{\mu}_\theta(\mathbf{x}_t,t)=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\boldsymbol\epsilon_\theta(\mathbf{x}_t,t)\right)$$

The loss then becomes:

$$
\begin{aligned}
L_t&=\mathbb{E}_{\mathbf{x}_0,\boldsymbol\epsilon}\left[\frac{1}{2\sigma_t^2}\left\|\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t,\mathbf{x}_0)-\boldsymbol{\mu}_\theta(\mathbf{x}_t,t)\right\|^2\right]+C\\
&=\mathbb{E}_{\mathbf{x}_0,\boldsymbol\epsilon}\left[\frac{1}{2\sigma_t^2}\left\|\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\boldsymbol\epsilon_t\right)-\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\boldsymbol\epsilon_\theta(\mathbf{x}_t,t)\right)\right\|^2\right]+C\\
&=\mathbb{E}_{\mathbf{x}_0,\boldsymbol\epsilon}\left[\frac{(1-\alpha_t)^2}{2\sigma_t^2\,\alpha_t(1-\bar\alpha_t)}\left\|\boldsymbol\epsilon_t-\boldsymbol\epsilon_\theta(\mathbf{x}_t,t)\right\|^2\right]+C
\end{aligned}
$$
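The two algebraic steps above can be sanity-checked numerically:

- the KL between equal-variance Gaussians reduces to a scaled squared distance between the means;
- substituting $\mathbf{x}_0$ turns $\tilde{\boldsymbol\mu}_t$ into its $\boldsymbol\epsilon$ form.

This is a minimal sketch under a few assumptions: a linear $\beta$ schedule with $T = 300$ and the fixed variance $\sigma_t^2=\beta_t$. The random tensors standing in for $\mathbf{x}_0$, $\boldsymbol\epsilon_t$ and $\boldsymbol\epsilon_\theta$ are made-up test data, not model outputs.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
T = 300
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)  # assumed linear schedule
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

t = 150  # 0-based step index >= 1, so alphas_bar[t - 1] exists
a_t, ab_t, ab_prev, b_t = alphas[t], alphas_bar[t], alphas_bar[t - 1], betas[t]
sigma2 = b_t  # assumption: fixed variance sigma_t^2 = beta_t

x0 = torch.randn(3, 8, 8, dtype=torch.float64)
eps = torch.randn_like(x0)        # true noise eps_t
eps_theta = torch.randn_like(x0)  # stand-in for a network prediction
xt = ab_t.sqrt() * x0 + (1 - ab_t).sqrt() * eps  # closed-form forward process

# Posterior mean written with x_0 ...
mu_tilde_x0 = (a_t.sqrt() * (1 - ab_prev) / (1 - ab_t)) * xt + (ab_prev.sqrt() * b_t / (1 - ab_t)) * x0
# ... and rewritten with eps_t only
mu_tilde_eps = (xt - (1 - a_t) / (1 - ab_t).sqrt() * eps) / a_t.sqrt()
# epsilon-parameterized model mean
mu_theta = (xt - (1 - a_t) / (1 - ab_t).sqrt() * eps_theta) / a_t.sqrt()

# KL between two Gaussians with the same variance, summed over all dimensions
std = sigma2.sqrt()
kl = kl_divergence(Normal(mu_tilde_eps, std), Normal(mu_theta, std)).sum()
kl_means = ((mu_tilde_eps - mu_theta) ** 2).sum() / (2 * sigma2)
kl_eps = (1 - a_t) ** 2 / (2 * sigma2 * a_t * (1 - ab_t)) * ((eps - eps_theta) ** 2).sum()

print(torch.allclose(mu_tilde_x0, mu_tilde_eps),
      torch.allclose(kl, kl_means),
      torch.allclose(kl_means, kl_eps))  # expected: True True True
```

When both Gaussians have the same variance, the constant $C$ is exactly zero. That is why `kl` and `kl_means` agree without any offset.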

Here $\boldsymbol\epsilon_\theta(\mathbf{x}_t,t)$ is the quantity predicted by the neural network: $\theta$ denotes the network's parameters, and $\boldsymbol\epsilon_\theta(\mathbf{x}_t,t)$ is its output. In actual training, then, we only need to predict the noise that was added at step $t$ and compute the L2 loss against the true noise $\boldsymbol\epsilon_t$. Driving this loss down reduces $L_t$ and so serves the original goal of maximizing $\log p(\mathbf{x}_0)$ (more precisely, its variational lower bound).

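One detail: in DDPM, $\boldsymbol{\Sigma}_\theta$ is not learned but fixed to $\sigma_t^2\mathbf{I}$, so it can be absorbed into the constant coefficient in front. The paper tries both $\sigma_t^2=\beta_t$ and $\sigma_t^2=\tilde\beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t$; the sampling code below uses the latter as `posterior_variance`. In *Improved Denoising Diffusion Probabilistic Models*, OpenAI changed this setting and made the variance learnable: the network predicts an interpolation between $\beta_t$ and $\tilde\beta_t$. The form of $L_t$ changes somewhat as a result, but the core idea is the same. Interested readers can try it themselves.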

In the 2020 paper *Denoising Diffusion Probabilistic Models*, the authors found experimentally that simply dropping the weighting coefficient in front of $L_t$ works better in training (it gave better sample quality):

$$L_t^{\text{simple}}=\mathbb{E}_{\mathbf{x}_0,\boldsymbol\epsilon}\left[\left\|\boldsymbol\epsilon_t-\boldsymbol\epsilon_\theta(\mathbf{x}_t,t)\right\|^2\right]$$
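To see what dropping the weight changes, the sketch below evaluates the discarded coefficient $\frac{(1-\alpha_t)^2}{2\sigma_t^2\alpha_t(1-\bar\alpha_t)}$ at several steps. It assumes the linear schedule used in the code section ($T = 300$, $\beta$ from 1e-4 to 0.02) and the choice $\sigma_t^2=\beta_t$.

The weight is about 0.5 at the first step and around 0.01 for most later steps. Dropping it therefore relatively down-weights the small-$t$ terms. The DDPM paper describes this as letting the network focus on the harder denoising tasks at larger $t$.

```python
import torch

T = 300
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)  # linear schedule as in the code section
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)
sigma2 = betas  # assumption: fixed variance sigma_t^2 = beta_t

# Weight in front of ||eps_t - eps_theta||^2 in the unsimplified L_t
weights = (1 - alphas) ** 2 / (2 * sigma2 * alphas * (1 - alphas_bar))

for t in (0, 1, 10, 100, 299):  # 0-based indices, i.e. t = 1, 2, 11, 101, 300
    print(f"t={t + 1:3d}  weight={weights[t].item():.4f}")
```

At the first step $1-\bar\alpha_1=\beta_1$, so the weight simplifies to $\frac{1}{2(1-\beta_1)}\approx 0.5$ regardless of the exact schedule.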
Plugging this back into the full objective from Part 1 gives the decomposition below. The symbol $\equiv$ means equal after taking the expectation over $q$, and the last line uses the unweighted form of each $L_t$. The first KL term contains no trainable parameters, because $q$ is fixed and $p(\mathbf{x}_T)=\mathcal{N}(\mathbf{0},\mathbf{I})$, so it is a constant:

$$
\begin{aligned}
\log\frac{q(\mathbf{x}_{1:T}\mid\mathbf{x}_0)}{p_\theta(\mathbf{x}_{0:T})}
&=\log\frac{q(\mathbf{x}_T\mid\mathbf{x}_0)}{p(\mathbf{x}_T)}+\sum_{t=2}^{T}\log\frac{q(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0)}{p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)}-\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)\\
&\equiv D_{KL}\left(q(\mathbf{x}_T\mid\mathbf{x}_0)\,\|\,p(\mathbf{x}_T)\right)+\sum_{t=2}^{T}D_{KL}\left(q(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)\right)-\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)\\
&=\text{constant}+\sum_{t=2}^{T}\mathbb{E}_{\mathbf{x}_0,\boldsymbol\epsilon}\left[\left\|\boldsymbol\epsilon_t-\boldsymbol\epsilon_\theta(\mathbf{x}_t,t)\right\|^2\right]-\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)
\end{aligned}
$$
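The claim that the first KL term is a constant can be checked directly. $q(\mathbf{x}_T\mid\mathbf{x}_0)=\mathcal{N}(\sqrt{\bar\alpha_T}\,\mathbf{x}_0,(1-\bar\alpha_T)\mathbf{I})$ depends only on the noise schedule and the data, and $p(\mathbf{x}_T)$ is a standard normal. No network parameter therefore appears anywhere.

This is a minimal sketch, assuming the same linear schedule and a random stand-in image scaled to $[-1, 1]$.

```python
import torch
from torch.distributions import Normal, kl_divergence

T = 300
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)  # assumed linear schedule
alpha_bar_T = torch.cumprod(1.0 - betas, dim=0)[-1]

x0 = torch.rand(3, 64, 64, dtype=torch.float64) * 2 - 1  # stand-in image in [-1, 1]
q_xT = Normal(alpha_bar_T.sqrt() * x0, (1 - alpha_bar_T).sqrt())
p_xT = Normal(torch.zeros_like(x0), torch.ones_like(x0))

# L_T: only the schedule and x0 enter this computation, no trainable parameters
L_T = kl_divergence(q_xT, p_xT).sum()

# Closed form per dimension: 0.5 * (s^2 + m^2 - 1 - log s^2)
s2 = 1 - alpha_bar_T
closed = 0.5 * (s2 + alpha_bar_T * x0 ** 2 - 1 - torch.log(s2)).sum()
print(alpha_bar_T.item(), L_T.item(), closed.item())
```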
That leaves the last term, $-\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)$. This term is a likelihood of the image $\mathbf{x}_0$ itself rather than a noise-prediction loss, and computing it needs $\mathbf{x}_1$.

The paper models this step with an independent discrete decoder (the paper's term). The decoder is built from the Gaussian $\mathcal{N}(\mathbf{x}_0;\boldsymbol\mu_\theta(\mathbf{x}_1,1),\sigma_1^2\mathbf{I})$, using the same $\boldsymbol\mu_\theta$ as above. Pixel values are integers from 0 to 255 rescaled linearly to $[-1,1]$. The density is therefore integrated over the bin of width $2/255$ around each value, with the two outermost bins extended to $\pm\infty$:

$$
\begin{aligned}
p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)&=\prod_{i=1}^{D}\int_{\delta_-(x_0^i)}^{\delta_+(x_0^i)}\mathcal{N}\left(x;\mu_\theta^i(\mathbf{x}_1,1),\sigma_1^2\right)dx\\
\delta_+(x)&=\begin{cases}\infty & \text{if } x=1\\ x+\frac{1}{255} & \text{if } x<1\end{cases}
\qquad
\delta_-(x)=\begin{cases}-\infty & \text{if } x=-1\\ x-\frac{1}{255} & \text{if } x>-1\end{cases}
\end{aligned}
$$

Here $D$ is the data dimensionality and the superscript $i$ indexes a single coordinate.
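The discrete decoder can be written out directly. This is a minimal sketch under these assumptions:

- pixel values are already scaled to $[-1, 1]$ in steps of $2/255$;
- $\sigma_1$ is a fixed float;
- a given `mean` tensor stands in for $\boldsymbol\mu_\theta(\mathbf{x}_1, 1)$.

The function name `discretized_gaussian_log_prob` is hypothetical. It returns per-coordinate log-probabilities, whose sum is $\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)$.

```python
import torch
from torch.distributions import Normal

def discretized_gaussian_log_prob(x0, mean, sigma):
    """Per-coordinate log of the integral of N(x; mean, sigma^2) over the bin around x0.

    x0:    data scaled to [-1, 1] with 256 levels (spacing 2/255)
    mean:  stand-in for mu_theta(x_1, 1), same shape as x0 (hypothetical input)
    sigma: positive float playing the role of sigma_1
    """
    dist = Normal(mean, sigma)
    # delta_+(x): +inf at the top level (CDF = 1), otherwise x + 1/255
    upper = torch.where(x0 >= 1.0, torch.ones_like(x0), dist.cdf(x0 + 1.0 / 255))
    # delta_-(x): -inf at the bottom level (CDF = 0), otherwise x - 1/255
    lower = torch.where(x0 <= -1.0, torch.zeros_like(x0), dist.cdf(x0 - 1.0 / 255))
    # clamp avoids log(0) for bins far from the mean
    return torch.log((upper - lower).clamp(min=1e-12))

# Usage: the 256 bins tile the real line, so their probabilities sum to 1
levels = torch.linspace(-1.0, 1.0, 256, dtype=torch.float64)
log_p = discretized_gaussian_log_prob(levels, torch.full_like(levels, 0.3), 0.1)
print(log_p.exp().sum().item())       # close to 1.0
print(levels[log_p.argmax()].item())  # the level whose bin contains the mean 0.3
```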
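In the simplified version, however, the authors fold this term into $L_t$ above as the case $t=1$. They approximate the integral by the Gaussian density times the bin width, ignoring $\sigma_1^2$ and edge effects. They also replace the sum over $t$ with an expectation over $t$ drawn uniformly from $\{1,\dots,T\}$, which only rescales the loss. The final optimization objective is therefore:

$$\mathcal{L}=\text{constant}+\mathbb{E}_{t,\mathbf{x}_0,\boldsymbol\epsilon}\left[\left\|\boldsymbol\epsilon_t-\boldsymbol\epsilon_\theta(\mathbf{x}_t,t)\right\|^2\right]$$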
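The final objective maps onto a short training step:

1. take $\mathbf{x}_0$;
2. draw $t$ uniformly and draw $\boldsymbol\epsilon$;
3. form $\mathbf{x}_t$ in closed form;
4. regress $\boldsymbol\epsilon_\theta(\mathbf{x}_t, t)$ onto $\boldsymbol\epsilon$ with an MSE loss.

This is a minimal self-contained sketch. `TinyEpsModel` is a hypothetical stand-in for the U-Net in section 7: a single 1×1 convolution over $\mathbf{x}_t$ plus a channel holding $t/T$. The data are random tensors, not images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 300
betas = torch.linspace(1e-4, 0.02, T)            # assumed linear schedule
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

class TinyEpsModel(nn.Module):
    """Hypothetical toy stand-in for eps_theta(x_t, t): a 1x1 conv over [x_t, t/T]."""
    def __init__(self, channels=3):
        super().__init__()
        self.conv = nn.Conv2d(channels + 1, channels, kernel_size=1)

    def forward(self, x_t, t):
        # broadcast the normalized time step as an extra input channel
        t_map = (t.float() / T).view(-1, 1, 1, 1).expand(-1, 1, *x_t.shape[2:])
        return self.conv(torch.cat([x_t, t_map], dim=1))

def simple_loss(model, x0):
    t = torch.randint(0, T, (x0.shape[0],))       # uniform time step (0-based)
    eps = torch.randn_like(x0)                    # eps ~ N(0, I)
    ab = alphas_bar[t].view(-1, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps  # closed-form q(x_t | x_0)
    return F.mse_loss(model(x_t, t), eps)         # mean of ||eps - eps_theta||^2

@torch.no_grad()
def eval_loss(model, x0, n_batches=20):
    torch.manual_seed(123)  # identical t / eps draws on every call
    return sum(simple_loss(model, x0).item() for _ in range(n_batches)) / n_batches

torch.manual_seed(0)
model = TinyEpsModel()
x0 = torch.rand(16, 3, 8, 8) * 2 - 1              # random stand-in batch in [-1, 1]
before = eval_loss(model, x0)

opt = torch.optim.Adam(model.parameters(), lr=1e-2)
for step in range(300):
    opt.zero_grad()
    loss = simple_loss(model, x0)
    loss.backward()
    opt.step()

after = eval_loss(model, x0)
print(f"loss before: {before:.3f}  after: {after:.3f}")
```

This sketch uses `F.mse_loss` to match the derivation. The `get_loss` function in section 4 below uses an L1 loss instead.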


1. Forward diffusion process

```python
# forward process
import torch
import torch.nn.functional as F

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    # reshape to (batch, 1, 1, ...) so the value broadcasts over the image dimensions
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it (closed form of q(x_t | x_0)) and the noise used
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + std * noise
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
```

2. Precompute $\alpha$, $\beta$ and the other parameters

```python
# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas                                                  # alpha_t = 1 - beta_t
alphas_cumprod = torch.cumprod(alphas, dim=0)                        # alpha_bar_t
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)  # alpha_bar_{t-1}, first entry 1
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)                         # 1 / sqrt(alpha_t), used in mu_theta
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)                     # sqrt(alpha_bar_t), used in q(x_t | x_0)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)      # sqrt(1 - alpha_bar_t)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)  # beta_tilde_t
```
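The precomputed `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` are what let `forward_diffusion_sample` jump straight from $\mathbf{x}_0$ to $\mathbf{x}_t$. Without them, it would have to apply $t$ single steps $\mathbf{x}_s=\sqrt{\alpha_s}\,\mathbf{x}_{s-1}+\sqrt{\beta_s}\,\boldsymbol\epsilon_s$.

The sketch below compares the two routes empirically on many copies of one scalar "pixel", assuming the same linear schedule. The closed form predicts mean $\sqrt{\bar\alpha_t}\,x_0$ and variance $1-\bar\alpha_t$ for both.

```python
import torch

torch.manual_seed(0)
T = 300
betas = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

t = 120  # 0-based step index to compare at
x0 = torch.full((200_000,), 0.5, dtype=torch.float64)  # many copies of one pixel value

# Step by step: x_s = sqrt(alpha_s) * x_{s-1} + sqrt(beta_s) * eps_s, for s = 0..t
x = x0.clone()
for s in range(t + 1):
    x = alphas[s].sqrt() * x + betas[s].sqrt() * torch.randn_like(x)

# Closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
x_direct = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * torch.randn_like(x0)

print(x.mean().item(), x_direct.mean().item(), (alphas_cumprod[t].sqrt() * 0.5).item())
print(x.var().item(), x_direct.var().item(), (1 - alphas_cumprod[t]).item())
```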

3. Load a dataset and test the forward process

```python
# test on the Stanford Cars dataset
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader

IMG_SIZE = 64     # assumed value; not shown in this excerpt
BATCH_SIZE = 128  # assumed value; not shown in this excerpt


def load_transformed_dataset():
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),  # Scales data into [0,1]
        transforms.Lambda(lambda t: (t * 2) - 1),  # Scale between [-1, 1]
    ]
    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.StanfordCars(root=".", download=True,
                                              transform=data_transform)
    test = torchvision.datasets.StanfordCars(root=".", download=True,
                                             transform=data_transform, split='test')
    return torch.utils.data.ConcatDataset([train, test])


def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t.clamp(-1, 1) + 1) / 2),  # [-1, 1] -> [0, 1]
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))


data = load_transformed_dataset()
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Simulate forward diffusion (optional, for visualization only)
image = next(iter(dataloader))[0]

plt.figure(figsize=(15, 15))
num_images = 10
stepsize = int(T / num_images)

for idx in range(0, T, stepsize):
    t = torch.Tensor([idx]).type(torch.int64)
    plt.subplot(1, num_images + 1, int(idx / stepsize) + 1)
    # noise the original image directly to step idx; do not overwrite `image`
    img, noise = forward_diffusion_sample(image, t)
    show_tensor_image(img)
plt.show()
```

4. Define the loss function

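```python
# get loss (`device` is the global set in the training section)
def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)  # noise x_0 to step t
    noise_pred = model(x_noisy, t)                              # predict the added noise
    # The derivation gives an L2 loss; this code uses L1 instead.
    # F.mse_loss(noise, noise_pred) would match the objective exactly.
    return F.l1_loss(noise, noise_pred)
```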

5. Sampling (prediction). Generating an image requires iterating the reverse process step by step, so it is computationally expensive.

```python
@torch.no_grad()
def sample_timestep(x, t):
    """
    Calls the model to predict the noise in the image and returns
    the denoised image.
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction): this is mu_theta(x_t, t)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    if t == 0:
        # last step: return the mean without adding noise
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


@torch.no_grad()
def sample_plot_image():
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15, 15))
    num_images = 10
    stepsize = int(T / num_images)
    # Iterate from t = T - 1 down to t = 0,
    # but only plot a few of the intermediate images
    for i in range(0, T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize == 0:
            plt.subplot(1, int(num_images), int(i / stepsize) + 1)
            show_tensor_image(img.detach().cpu())
    plt.show()


import os
from torchvision.utils import save_image

@torch.no_grad()
def save_sampled_image(epoch):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    os.makedirs('./results', exist_ok=True)

    num_images = 10
    stepsize = int(T / num_images)
    # Iterate from t = T - 1 down to t = 0,
    # but only save a few of the intermediate images
    for i in range(0, T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize == 0:
            image = img[0] if img.dim() == 4 else img  # first image of the batch
            image = (image.clamp(-1, 1) + 1) / 2        # [-1, 1] -> [0, 1]
            save_image(image, f'./results/{epoch}_{i}.jpg')
```

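6. Training. The test (sampling) code is placed before this section because predicting $x_0$ itself requires iteratively sampling the reverse process.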

```python
from torch.optim import Adam

if __name__ == "__main__":
    from implementation2.model import SimpleUnet  # the model from section 7

    model = SimpleUnet()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = Adam(model.parameters(), lr=0.001)
    epochs = 100  # Try more!

    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            optimizer.zero_grad()

            t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
            loss = get_loss(model, batch[0], t)
            loss.backward()
            optimizer.step()

            # Every 5 epochs, report the loss and test the current model
            if epoch % 5 == 0 and step == 0:
                print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
                # sample_plot_image()
```

7. Model (a Transformer-style sinusoidal position embedding encodes the time step; a U-Net processes the image)

```python
from torch import nn
import math
import torch


class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            # input = upsampled features concatenated with the skip connection
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)  # doubles H and W
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)  # halves H and W
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions: (B, C) -> (B, C, 1, 1)
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        # first half sin, second half cos
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """

    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1
        time_emb_dim = 32

        # Time embedding: integer step -> sinusoidal features -> MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1],
                                          time_emb_dim)
                                    for i in range(len(down_channels) - 1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1],
                                        time_emb_dim, up=True)
                                  for i in range(len(up_channels) - 1)])

        # out_dim = 1 is used as the kernel size: a 1x1 conv back to 3 channels
        self.output = nn.Conv2d(up_channels[-1], 3, out_dim)

    def forward(self, x, timestep):
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)  # keep for the skip connection
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)
```
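The time-embedding module maps each integer time step to a vector. The first half holds $\sin(t\,\omega_k)$ and the second half $\cos(t\,\omega_k)$, with frequencies $\omega_k = 10000^{-k/(d/2-1)}$ for $k = 0,\dots,d/2-1$.

Below is a standalone functional version of the same computation, written as a sketch for experimenting outside the model; the name `sinusoidal_embedding` is hypothetical. It makes two properties easy to check:

- at $t = 0$ the sine half is all zeros and the cosine half all ones;
- nearby time steps get closer vectors than distant ones.

```python
import math
import torch

def sinusoidal_embedding(time: torch.Tensor, dim: int) -> torch.Tensor:
    """Same computation as SinusoidalPositionEmbeddings.forward, as a plain function."""
    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)  # from 1 down to 1/10000
    args = time[:, None].float() * freqs[None, :]                           # (batch, half_dim)
    return torch.cat((args.sin(), args.cos()), dim=-1)                      # (batch, dim)

t = torch.tensor([0, 1, 2, 150])
emb = sinusoidal_embedding(t, 32)
print(emb.shape)  # torch.Size([4, 32])
print(torch.dist(emb[1], emb[2]).item(), torch.dist(emb[1], emb[3]).item())
```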



