DDPM Code Explained with a Worked Example (Appendix: the Model from the Paper)

DDPM Code Walkthrough

A Simple Example

1. Initializing the Dataset

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torch

# Generate an S-curve dataset with 10^4 points and noise level 0.1
s_curve,_ = make_s_curve(10**4,noise=0.1)
# Keep only the X and Z coordinates (columns 0 and 2) and scale down by 10 to make the data more compact
s_curve = s_curve[:,[0,2]]/10.0

print("shape of s:",np.shape(s_curve))  # => (10000, 2)

# Transpose the data so each row holds one coordinate dimension, the layout matplotlib expects
data = s_curve.T

fig,ax = plt.subplots()
# Draw a scatter plot on the axes: blue points with white edges
ax.scatter(*data,color='blue',edgecolor='white');
# Hide the axes
ax.axis('off')

# Convert the numpy array s_curve to a PyTorch tensor with dtype float
dataset = torch.Tensor(s_curve).float()

[Figure: scatter plot of the S-curve dataset]

2. Setting the Hyperparameters

Hyperparameter definitions:

  1. num_steps: $T$
  2. betas: $\beta$
  3. alphas: $\alpha = 1 - \beta$
  4. alphas_prod: $\bar{\alpha} = \prod_{i}^{T}\alpha_i$
  5. alphas_bar_sqrt: $\sqrt{\bar{\alpha}}$
  6. one_minus_alphas_bar_log: $\log_e(1 - \bar{\alpha})$
  7. one_minus_alphas_bar_sqrt: $\sqrt{1 - \bar{\alpha}}$

# Total number of steps in the diffusion process
num_steps = 100

# Generate an evenly spaced sequence from -6 to 6 as the raw beta values
betas = torch.linspace(-6, 6, num_steps)
# Squash the betas into (0, 1) with a sigmoid, then rescale and shift them
# so they lie in (1e-5, 0.5e-2), a common range for controlling the diffusion process
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
# print(betas.detach().numpy())

# alphas is the fraction of signal preserved at each step: 1 minus betas
alphas = 1 - betas
# The cumulative product of the alphas gives alpha_bar
alphas_prod = torch.cumprod(alphas, 0)    # cumulative product of alphas along dimension 0
# alphas_prod with an initial value of 1 prepended (same shape); this is not used later
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)
# Square root of alphas_prod, i.e. sqrt(alpha_bar), used in later computations
alphas_bar_sqrt = torch.sqrt(alphas_prod)
# Natural log of (1 - alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
# Square root of (1 - alphas_prod), used in later computations
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)

# Make sure all the alpha-related tensors share the same shape; later computations rely on this
assert alphas.shape == alphas_prod.shape == alphas_prod_p.shape == \
       alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == \
       one_minus_alphas_bar_sqrt.shape
# Print the shape of betas to confirm its length matches num_steps
print("all the same shape", betas.shape)     # => all the same shape torch.Size([100])

3. The Forward Diffusion Function

$$x_t = \sqrt{\bar{\alpha}_t}\cdot x_0 + \sqrt{1 - \bar{\alpha}_t}\cdot z$$

# Compute a sample x_t at an arbitrary timestep t, based on x_0 and the reparameterization trick
def q_x(x_0, t):
    """Obtain x[t] at any timestep t directly from x[0]"""
    noise = torch.randn_like(x_0)   # sample noise from a standard normal distribution
    alphas_t = alphas_bar_sqrt[t]    # sqrt(alpha_bar) at timestep t
    alphas_1_m_t = one_minus_alphas_bar_sqrt[t]    # sqrt(1 - alpha_bar) at timestep t
    return (alphas_t * x_0 + alphas_1_m_t * noise)    # add noise on top of x[0]

4. Plotting the Forward Diffusion

# Number of snapshots to display
num_shows = 20
# Create a 2x10 grid of subplots, one per snapshot
fig,axs = plt.subplots(2,10,figsize=(28,3))
# Set the text color to black
plt.rc('text',color='black')

# The dataset holds 10000 points, each with two coordinates
# Plot the noised data every 5 steps across the 100 steps
for i in range(num_shows):
    # Compute this snapshot's position in the subplot grid
    # j is the row index, k is the column index
    j = i//10  # integer division by 10 gives the row index
    k = i%10   # remainder modulo 10 gives the column index
    
    # i*num_steps//num_shows samples every 5 steps, giving the timestep t
    one_t = i*num_steps//num_shows
    
    q_i = q_x(dataset,torch.tensor([one_t]))  # draw a sample at timestep t
    # Plot the sampled data as a scatter plot in the corresponding subplot:
    # red points with white edges
    axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')
    # Hide the subplot's axes
    axs[j,k].set_axis_off()
    # A raw LaTeX string renders the title as a math expression
    axs[j,k].set_title(r'$q(\mathbf{x}_{'+str(one_t)+'})$')

[Figure: scatter plots of the noised data every 5 steps, from t = 0 to t = 95]

5. The Neural Network Model

import torch
import torch.nn as nn

class MLPDiffusion(nn.Module):
    def __init__(self, n_steps, num_units=128):
        super(MLPDiffusion, self).__init__()
        
        self.linears = nn.ModuleList(
            [
                nn.Linear(2, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, 2),
            ]
        )
        # Each embedding layer maps the timestep t to a num_units-dimensional vector, encoding the timestep
        # Three Embedding layers can learn a richer representation: each can capture different aspects of the timestep
        self.step_embeddings = nn.ModuleList(
            [
                nn.Embedding(n_steps,num_units),
                nn.Embedding(n_steps,num_units),
                nn.Embedding(n_steps,num_units),
            ]
        )
    def forward(self,x,t):
        for idx, embedding_layer in enumerate(self.step_embeddings):
            # Encode the timestep
            t_embedding = embedding_layer(t)
            # Pass the data through a linear layer
            x = self.linears[2*idx](x)
            # Fuse the timestep embedding with the data
            x += t_embedding
            # Pass the data through the activation layer
            x = self.linears[2*idx+1](x)
            
        x = self.linears[-1](x)
        
        return x
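
A minimal smoke test (an addition, not in the original) confirms the expected shapes: the network maps a batch of 2-D points plus a batch of timesteps to a batch of 2-D noise predictions.

# Smoke test (illustrative): the model maps (batch, 2) points to (batch, 2) noise predictions.
net = MLPDiffusion(n_steps=100)
x = torch.randn(16, 2)              # a batch of 16 two-dimensional points
t = torch.randint(0, 100, (16,))    # one timestep per sample
print(net(x, t).shape)              # => torch.Size([16, 2])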

6. The Loss Function

$$L_{simple}(\theta) = E_{t, x_0, \epsilon}\left[\left\| \epsilon - \epsilon_\theta\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon,\ t\right)\right\|^2\right]$$

def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):
    """
        Compute the diffusion model's loss.

        Arguments:
        - model: the diffusion model; takes noised data and a timestep, outputs the predicted noise.
        - x_0: the original data, of shape [batch_size, 2].
        - alphas_bar_sqrt: square root of the cumulative product of the alphas, used to scale the signal.
        - one_minus_alphas_bar_sqrt: square root of (1 - cumulative product of the alphas), used to scale the noise.
        - n_steps: total number of steps in the diffusion process.

        Returns:
        - the loss value, measuring the gap between the model's prediction and the true noise.
    """
    # Number of samples in the batch
    batch_size = x_0.shape[0]
    
    # Draw random timesteps t for half the batch
    t = torch.randint(0, n_steps, size=(batch_size//2,))
    # For symmetry, give the other half the mirrored timesteps n_steps-1-t
    t = torch.cat([t, n_steps-1-t], dim=0)
    # Add a dimension so t broadcasts against x_0
    t = t.unsqueeze(-1)
    
    # Coefficient on x_0 in the formula
    a = alphas_bar_sqrt[t]
    
    # Coefficient on epsilon in the formula
    aml = one_minus_alphas_bar_sqrt[t]
    
    # Sample random noise epsilon with the same shape as x_0 from a standard normal
    e = torch.randn_like(x_0)
    
    # Construct the model's input
    x = x_0 * a + e * aml
    
    # Feed it to the model to get the predicted noise at timestep t
    output = model(x, t.squeeze(-1))
    
    # Mean squared error against the true noise
    return (e - output).square().mean()
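
A quick illustrative call (an addition, not part of the original training code): for an untrained model the loss should sit near 1, roughly the variance of the target noise.

# Illustrative usage, not part of the original training code.
net = MLPDiffusion(num_steps)
loss = diffusion_loss_fn(net, dataset[:128], alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps)
print(loss.item())    # roughly 1 for an untrained model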

7. The Reverse Diffusion Sampling Functions

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z$$

def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):
    """
        Recover the data sequence from pure noise using the diffusion model.

        Arguments:
        - model: the diffusion model, used to predict the noise at each step.
        - shape: shape of the data to recover, e.g. (batch_size, num_features).
        - n_steps: total number of steps in the diffusion process.
        - betas: the per-step beta values controlling the noise level.
        - one_minus_alphas_bar_sqrt: square root of (1 - alphas_prod), used to scale the noise.

        Returns:
        - x_seq: a list containing the recovered data at every step.
    """
    # Start from pure noise drawn from a standard normal distribution
    cur_x = torch.randn(shape)
    x_seq = [cur_x]
    for i in reversed(range(n_steps)):    # iterate backwards, from x_T down to x_0
        cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt)
        x_seq.append(cur_x)
    return x_seq

def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):
    """
        Sample the reconstruction at timestep t from the given noised data.

        Arguments:
        - model: the diffusion model, used to predict the noise at timestep t.
        - x: the noised data at the current timestep.
        - t: the current timestep.
        - betas: the per-step beta values controlling the noise level.
        - one_minus_alphas_bar_sqrt: square root of (1 - alphas_prod), used to scale the noise.

        Returns:
        - sample: the reconstructed sample at timestep t.
    """
    # Wrap the timestep in a tensor
    t = torch.tensor([t])
    # Coefficient on epsilon_theta in the formula
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
    # Predict the noise with the model
    eps_theta = model(x,t)
    # Mean of the distribution at the previous timestep (note 1 - betas[t] equals alphas[t])
    mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))
    
    # Draw a random sample from a standard normal distribution
    z = torch.randn_like(x)
    # Standard deviation of the distribution at the previous timestep
    sigma_t = betas[t].sqrt()
    # Data at the previous timestep
    sample = mean + sigma_t * z
    
    return (sample)
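
An illustrative shape check (an addition): running the reverse chain, even with an untrained model, yields n_steps + 1 snapshots, from x_T down to x_0.

# Shape check (illustrative): x_seq holds n_steps + 1 snapshots.
net = MLPDiffusion(num_steps)
seq = p_sample_loop(net, (64, 2), num_steps, betas, one_minus_alphas_bar_sqrt)
print(len(seq), seq[0].shape)    # => 101 torch.Size([64, 2])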


8. Training the Model

# Set a random seed for reproducibility
seed = 1234
torch.manual_seed(seed)

class EMA():
    """An exponential-moving-average parameter smoother"""
    def __init__(self, mu=0.01):
        # Smoothing coefficient
        self.mu = mu
        # Stores the smoothed parameters
        self.shadow = {}
        
    def register(self,name,val):
        # Register a parameter's initial value
        self.shadow[name] = val.clone()
        
    def __call__(self,name,x):
        # Update the smoothed value according to the smoothing coefficient
        assert name in self.shadow
        new_average = self.mu * x + (1.0-self.mu)*self.shadow[name]
        self.shadow[name] = new_average.clone()
        return new_average
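
Note that EMA is defined here but never invoked by the training loop below. A minimal sketch (an addition, not the author's method) of how it could be wired in:

# Illustrative sketch only; the original training loop never uses EMA.
ema = EMA(mu=0.01)
demo_model = MLPDiffusion(num_steps)
for name, param in demo_model.named_parameters():
    ema.register(name, param.data)
# ...then, after each optimizer.step(), refresh the smoothed copy:
for name, param in demo_model.named_parameters():
    ema(name, param.data)    # new shadow = mu * param + (1 - mu) * shadow
# The smoothed weights in ema.shadow can be loaded for evaluation.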
    
print('Training model...')
# Batch size
batch_size = 128
# Create the data loader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Number of training epochs
num_epoch = 4000
# Set matplotlib's text color
plt.rc('text',color='blue')

# Instantiate the MLPDiffusion model; num_steps was defined above as 100
model = MLPDiffusion(num_steps)
# Instantiate the optimizer
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

for t in range(num_epoch):
    # Report training progress
    process = t / num_epoch * 100
    if t % 50 == 0:
        print(f'{int(process)}%')
        
    # Loop over the batches within each epoch
    for idx,batch_x in enumerate(dataloader):
        # Compute the diffusion loss
        loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)
        # Clear the gradients from the previous step
        optimizer.zero_grad()
        # Backpropagate to compute the gradients of the current loss
        loss.backward()
        # Clip the gradients to prevent them from exploding
        torch.nn.utils.clip_grad_norm_(model.parameters(),1.)
        # Update the model parameters
        optimizer.step()
    
    # Every 100 epochs, plot the generated samples to check progress
    if(t%100==0):
        # Print the current loss
        print(loss)
        # Generate a sample sequence with the model
        x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)
        # Plot the sample sequence
        fig,axs = plt.subplots(1,10,figsize=(28,3))
        for i in range(1,11):
            cur_x = x_seq[i*10].detach()
            axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');
            axs[i-1].set_axis_off();
            axs[i-1].set_title(r'$q(\mathbf{x}_{'+str(i*10)+'})$')
9. Rendering the Forward Diffusion Process

# Import the required libraries
import io
from PIL import Image

# Initialize an empty list to hold the generated frames
imgs = []

# Loop 100 times, generating one frame per timestep
for i in range(100):
    # Clear the current figure, ready for the next iteration
    plt.clf()
    
    # q_x (defined above) produces the noised data at timestep i
    # torch.tensor([i]) wraps the integer i in a PyTorch tensor
    q_i = q_x(dataset, torch.tensor([i]))
    
    # Draw a scatter plot in the freshly cleared figure
    # q_i[:, 0] and q_i[:, 1] are the x and y coordinates
    # color sets the point color, edgecolor the edge color, s the point size
    plt.scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white', s=5)
    
    # Hide the axes
    plt.axis('off')
    
    # Create an in-memory byte buffer to hold the image data
    img_buf = io.BytesIO()
    
    # Save the current figure into the buffer in PNG format
    plt.savefig(img_buf, format='png')
    
    # Open the image from the buffer with PIL
    img = Image.open(img_buf)
    
    # Append it to the imgs list
    imgs.append(img)

10. Rendering the Reverse Diffusion Process

# Initialize an empty list to hold the frames of the reverse process
reverse = []

# Loop 100 times, generating one frame per step
for i in range(100):
    # Clear the current figure, ready for the next frame
    plt.clf()
    
    # x_seq holds every snapshot from the last sampling run
    # Take the i-th element and detach() it from the computation graph
    cur_x = x_seq[i].detach()
    
    # Draw a scatter plot with matplotlib
    # cur_x[:, 0] and cur_x[:, 1] are the x and y coordinates of the points
    # color sets the point color, edgecolor the edge color, s the point size
    plt.scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white', s=5)
    
    # Hide the axes for a cleaner image
    plt.axis('off')
    
    # Create an in-memory byte buffer to hold the image data
    img_buf = io.BytesIO()
    
    # Save the current figure into the buffer in PNG format
    plt.savefig(img_buf, format='png')
    
    # Open the image from the buffer with PIL
    img = Image.open(img_buf)
    
    # Append it to the reverse list
    reverse.append(img)

11. Combining the Forward and Reverse Process Frames

imgs = imgs + reverse

12. Saving the Forward and Reverse Processes as a GIF Animation

# imgs[0] is the first frame in the list of generated images
# The save method writes it to disk, here as a GIF
imgs[0].save(
    "diffusion.gif",         # output file name
    format='GIF',            # output format
    append_images=imgs[1:],  # append the remaining frames after the first
    save_all=True,           # tell PIL to save every frame, producing an animation
    duration=100,            # duration of each frame in milliseconds
    loop=0                   # number of loops; 0 means loop forever
)

Appendix: The Neural Network Model from the Paper

Overall Structure of the Model

[Figure: overall structure of the model]

Choosing the SiLU Activation and the Normalization Layer

class SiLU(nn.Module):
    # The SiLU activation function
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)


def get_norm(norm, num_channels, num_groups):
    if norm == "in":
        return nn.InstanceNorm2d(num_channels, affine=True)
    elif norm == "bn":
        return nn.BatchNorm2d(num_channels)
    elif norm == "gn":
        return nn.GroupNorm(num_groups, num_channels)
    elif norm is None:
        return nn.Identity()
    else:
        raise ValueError("unknown normalization type")

The Positional Encoding Layer

$$PE_{(pos, 2i)} = \sin\!\left(pos / 10000^{2i/d_{model}}\right), \qquad PE_{(pos, 2i+1)} = \cos\!\left(pos / 10000^{2i/d_{model}}\right)$$

class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # outer product of x * self.scale and emb
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
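
A shape check (an addition): given a batch of timesteps, the layer returns one dim-dimensional sinusoidal embedding per timestep (the math module is imported in the full listing below).

# Illustrative shape check: one `dim`-dimensional embedding per timestep.
pe = PositionalEmbedding(dim=128)
t = torch.arange(16).float()    # a batch of 16 timesteps
print(pe(t).shape)              # => torch.Size([16, 128])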

The Attention Block

class AttentionBlock(nn.Module):
    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()

        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        # A 1x1 convolution that triples the channel count, producing the query (Q), key (K) and value (V)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        # Triple the channels, then split along the channel dimension into the Q, K and V matrices
        q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        # torch.bmm treats its two inputs as batches of matrices and multiplies them pairwise
        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention = torch.softmax(dot_products, dim=-1)
        out = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x
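
A shape check (an addition): self-attention plus the residual connection preserves the input's shape.

# Illustrative shape check: the block's output matches its input shape.
attn = AttentionBlock(in_channels=64)
x = torch.randn(2, 64, 16, 16)    # (batch, channels, height, width)
print(attn(x).shape)              # => torch.Size([2, 64, 16, 16])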

The Residual Block

[Figure: structure of the residual block]

class ResidualBlock(nn.Module):
    def __init__(
        self, 
        in_channels, 
        out_channels, 
        dropout, 
        time_emb_dim=None, 
        num_classes=None, 
        activation=SiLU(),  # SiLU activation by default
        norm="gn",          # group normalization by default
        num_groups=32, 
        use_attention=False,
    ):
        super().__init__()
        
        # Activation function
        self.activation = activation
        # Normalization layers
        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        # nn.Identity() is a special module implementing the identity function: its output equals its input
        self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)

    def forward(self, x, time_emb=None, y=None):
        # Normalization + activation + convolution
        out = self.activation(self.norm_1(x))
        out = self.conv_1(out)

        # Project the time embedding with a linear layer and add it along the channels
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            # activation + linear layer
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        # Embed the class label y and add it along the channels
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            out += self.class_bias(y)[:, :, None, None]
        # Normalization + activation
        out = self.activation(self.norm_2(out))
        # Second convolution + residual connection
        out = self.conv_2(out) + self.residual_connection(x)
        # Finally apply attention
        out = self.attention(out)
        return out
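
An illustrative call (an addition): with time conditioning, the block keeps the spatial size while mapping in_channels to out_channels.

# Illustrative usage: maps 64 -> 128 channels at unchanged spatial size.
block = ResidualBlock(64, 128, dropout=0.1, time_emb_dim=512)
x = torch.randn(2, 64, 16, 16)
time_emb = torch.randn(2, 512)
print(block(x, time_emb).shape)    # => torch.Size([2, 128, 16, 16])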

The Downsampling Layer

class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # Downsample with a 3x3 convolution of stride 2
        self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)

    def forward(self, x, time_emb, y):
        # Check that the height is an even number of pixels
        if x.shape[2] % 2 == 1:
            raise ValueError("downsampling tensor height should be even")
        # Check that the width is an even number of pixels
        if x.shape[3] % 2 == 1:
            raise ValueError("downsampling tensor width should be even")

        return self.downsample(x)

The Upsampling Layer

class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            # Upsample with nearest-neighbour interpolation
            nn.Upsample(scale_factor=2, mode="nearest"),
            # Convolution layer
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
        )

    def forward(self, x, time_emb, y):
        return self.upsample(x)

The U-Net

class UNet(nn.Module):
    def __init__(
        self, 
        img_channels, 
        base_channels=128, 
        channel_mults=(1, 2, 4, 8),
        num_res_blocks=3, 
        time_emb_dim=128 * 4, 
        time_emb_scale=1.0, 
        num_classes=None, 
        activation=SiLU(),
        dropout=0.1, 
        attention_resolutions=(1,), 
        norm="gn", 
        num_groups=32, 
        initial_pad=0,
    ):
        super().__init__()
        # Activation function to use, usually SiLU
        self.activation = activation
        # Whether to pad the input
        self.initial_pad = initial_pad
        # Number of classes to distinguish
        self.num_classes = num_classes

        # MLP that encodes the timestep t
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None

        # First convolution applied to the input image
        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        # self.downs stores all the modules used on the downsampling path
        # self.ups stores all the modules used on the upsampling path
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        # channels records the channel count after each module
        # now_channels is an intermediate variable holding the current channel count
        channels = [base_channels]
        now_channels = base_channels
        for i, mult in enumerate(channel_mults):
            # Channel count after this downsampling stage
            out_channels = base_channels * mult
            # Add the residual (and optionally attention) blocks for this stage
            for _ in range(num_res_blocks):
                self.downs.append(
                    ResidualBlock(
                        now_channels, out_channels, dropout,
                        time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                        norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                    )
                )
                now_channels = out_channels
                channels.append(now_channels)

            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        # Feature extraction in the middle: two residual blocks
        self.mid = nn.ModuleList(
            [
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=True,
                ),
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=False,
                ),
            ]
        )

        # Upsampling path, fusing features from the skip connections
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels, out_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels

            if i != 0:
                self.ups.append(Upsample(now_channels))

        assert len(channels) == 0

        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

    def forward(self, x, time=None, y=None):
        # Pad the input if requested
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        # Encode the timestep through the time MLP
        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but time is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None

        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")

        # First convolution applied to the input image
        x = self.init_conv(x)

        # skips stores the intermediate activations of the downsampling path
        skips = [x]
        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)

        # Feature extraction in the middle
        for layer in self.mid:
            x = layer(x, time_emb, y)

        # Upsample, fusing the skip-connection features
        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        # Final normalization, activation and output convolution
        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)

        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x
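
A minimal smoke test (an addition, with deliberately small illustrative sizes): the U-Net returns a noise prediction with the same shape as its input batch.

# Illustrative smoke test with a small configuration.
unet = UNet(img_channels=3, base_channels=32, channel_mults=(1, 2),
            num_res_blocks=1, time_emb_dim=128, num_groups=8)
x = torch.randn(2, 3, 32, 32)              # a batch of noised RGB images
t = torch.randint(0, 1000, (2,)).float()   # one timestep per image
print(unet(x, t).shape)                    # => torch.Size([2, 3, 32, 32])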

The Complete Model Code

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SiLU(nn.Module):
    # The SiLU activation function
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)


def get_norm(norm, num_channels, num_groups):
    if norm == "in":
        return nn.InstanceNorm2d(num_channels, affine=True)
    elif norm == "bn":
        return nn.BatchNorm2d(num_channels)
    elif norm == "gn":
        return nn.GroupNorm(num_groups, num_channels)
    elif norm is None:
        return nn.Identity()
    else:
        raise ValueError("unknown normalization type")


class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # outer product of x * self.scale and emb
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)

    def forward(self, x, time_emb, y):
        if x.shape[2] % 2 == 1:
            raise ValueError("downsampling tensor height should be even")
        if x.shape[3] % 2 == 1:
            raise ValueError("downsampling tensor width should be even")

        return self.downsample(x)


class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
        )

    def forward(self, x, time_emb, y):
        return self.upsample(x)


class AttentionBlock(nn.Module):
    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()

        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention = torch.softmax(dot_products, dim=-1)
        out = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x


class ResidualBlock(nn.Module):
    def __init__(
            self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=SiLU(),
            norm="gn", num_groups=32, use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection = nn.Conv2d(in_channels, out_channels,
                                             1) if in_channels != out_channels else nn.Identity()
        self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)

    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        # First convolution
        out = self.conv_1(out)

        # Project the time embedding with a linear layer and add it along the channels
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        # Embed the class label y and add it along the channels
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        # Second convolution + residual connection
        out = self.conv_2(out) + self.residual_connection(x)
        # Finally apply attention
        out = self.attention(out)
        return out


class UNet(nn.Module):
    def __init__(
            self, img_channels, base_channels=128, channel_mults=(1, 2, 4, 8),
            num_res_blocks=3, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=SiLU(),
            dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
    ):
        super().__init__()
        # Activation function to use, usually SiLU
        self.activation = activation
        # Whether to pad the input
        self.initial_pad = initial_pad
        # Number of classes to distinguish
        self.num_classes = num_classes

        # MLP that encodes the timestep t
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None

        # First convolution applied to the input image
        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        # self.downs stores the downsampling path: ResidualBlocks extract features,
        # then Downsample halves the feature map's height and width
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        # channels records the channel count after each module
        # now_channels is an intermediate variable holding the current channel count
        channels = [base_channels]
        now_channels = base_channels
        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks):
                self.downs.append(
                    ResidualBlock(
                        now_channels, out_channels, dropout,
                        time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                        norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                    )
                )
                now_channels = out_channels
                channels.append(now_channels)

            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        # Feature extraction in the middle, integrating the features
        self.mid = nn.ModuleList(
            [
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=True,
                ),
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=False,
                ),
            ]
        )

        # Upsampling path, fusing features from the skip connections
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult

            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels, out_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels

            if i != 0:
                self.ups.append(Upsample(now_channels))

        assert len(channels) == 0

        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

    def forward(self, x, time=None, y=None):
        # Pad the input if requested
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        # Encode the timestep through the time MLP
        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but time is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None

        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")

        # First convolution applied to the input image
        x = self.init_conv(x)

        # skips stores the intermediate activations of the downsampling path
        skips = [x]
        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)

        # Feature extraction in the middle
        for layer in self.mid:
            x = layer(x, time_emb, y)

        # Upsample, fusing the skip-connection features
        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        # Final normalization, activation and output convolution
        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)

        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x
