Building a Diffusion Model from Scratch (1)

This tutorial follows the course Diffusion models代码解读: 入门与实战 (Diffusion Models Code Walkthrough: From Basics to Practice).

# 1. Network helpers
# Define the helper functions

# Imports used throughout this post
import math
from functools import partial

import torch
from torch import nn, einsum
from einops import rearrange

# exists takes an argument x and checks whether it is None. The `is` operator compares object
# identity (i.e. whether both names point to the same object in memory), which is the idiomatic
# way to test for None.
def exists(x):
    return x is not None

# default returns val when it exists (i.e. is not None), and otherwise falls back to d.
# The fallback is handled specially: if d is callable (e.g. a function, or an object with
# __call__), it is called and its result is returned; otherwise d itself is returned.
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d
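A couple of quick checks make the fallback behaviour of default concrete (a usage sketch, not part of the original tutorial code):

# Usage sketch: only None counts as "missing", and a callable fallback is invoked
assert exists(0) and not exists(None)
assert default(5, 8) == 5                # val exists, so it wins
assert default(None, 8) == 8             # fall back to the plain value
assert default(None, lambda: 8) == 8     # a callable fallback is called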

# Residual wraps a module/function and adds a skip connection around it
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

# Upsampling: a transposed convolution that doubles the spatial resolution
def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, stride=2, padding=1)

# Downsampling: a strided convolution that halves the spatial resolution
# (note: a regular Conv2d here, not a ConvTranspose2d)
def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, stride=2, padding=1)
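As a sanity check, the two layers halve and double the spatial resolution while keeping the channel count (a usage sketch, assuming a 64-channel 32x32 feature map):

# Shape check (usage sketch)
x = torch.randn(1, 64, 32, 32)
print(Downsample(64)(x).shape)  # torch.Size([1, 64, 16, 16])
print(Upsample(64)(x).shape)    # torch.Size([1, 64, 64, 64])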

# 2. Position embeddings
# The network parameters are shared across timesteps, so the timestep t is encoded with
# sinusoidal position embeddings. This lets the network know which noise level (timestep)
# it is operating at for each image in the batch.

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device  # device (CPU or GPU) of the input tensor, so everything created below lives on the same device
        half_dim = self.dim // 2  # half of the dimensions are sines, the other half cosines
        embeddings = math.log(10000) / (half_dim - 1)  # scaling factor that keeps the frequencies in a reasonable range
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)  # 0..half_dim-1 times the negative scale, exponentiated: a geometrically decaying frequency sequence
        embeddings = time[:, None] * embeddings[None, :]  # outer product of timesteps and frequencies -> (batch, half_dim)
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)  # concatenate the sine and cosine embeddings -> (batch, dim)
        # Why take sin and cos of the time-based values? The sine/cosine pair encodes both absolute
        # and relative positions and captures dependencies between positions, which helps with long
        # sequences.
        return embeddings
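A quick shape check (usage sketch, assuming an embedding dimension of 32 and a batch of 8 timesteps): each timestep is mapped to a dim-dimensional vector.

# Shape check (usage sketch)
t = torch.randint(0, 1000, (8,))            # 8 random timesteps
emb = SinusoidalPositionEmbeddings(32)(t)
print(emb.shape)                            # torch.Size([8, 32])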

Using a U-Net to predict the noise lets the model learn the augmented data distribution better (data augmentation: applying a series of transforms to the original dataset to enrich it and improve training); the downside is that training cost and model complexity increase.
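To make the network's role concrete: given a noisy image and its timestep, the U-Net built below outputs a tensor of the same shape as the image, interpreted as the predicted noise. Below is a minimal sketch of the training step, assuming a forward-process helper q_sample that is not defined in this part of the tutorial:

import torch.nn.functional as F

def p_losses(model, x_start, t, q_sample):
    # q_sample(x_start, t, noise) is an assumed helper that adds noise to x_start
    # according to the forward diffusion process (covered in a later part)
    noise = torch.randn_like(x_start)          # the noise the network must recover
    x_noisy = q_sample(x_start, t, noise)      # corrupt the clean image
    predicted_noise = model(x_noisy, t)        # same shape as x_start
    return F.mse_loss(predicted_noise, noise)  # MSE between true and predicted noise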

# 3. ResNet/ConvNeXT block
# Core building block of the U-Net
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x
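A quick shape check of the block (usage sketch): the 3x3 convolution, group norm and SiLU keep the spatial size and only change the channel count; the optional scale_shift pair applies a per-channel affine modulation.

# Shape check (usage sketch)
x = torch.randn(2, 64, 32, 32)
print(Block(64, 128)(x).shape)  # torch.Size([2, 128, 32, 32])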

class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h  # broadcast the time embedding over the spatial dimensions

        h = self.block2(h)
        return h + self.res_conv(x)
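Usage sketch (assuming a time-embedding dimension of 128): the 1x1 res_conv projects the input to dim_out whenever the channel counts differ, so the skip connection always matches.

# Shape check with time conditioning (usage sketch)
x = torch.randn(2, 64, 32, 32)
t_emb = torch.randn(2, 128)
print(ResnetBlock(64, 96, time_emb_dim=128)(x, t_emb).shape)  # torch.Size([2, 96, 32, 32])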
# 4. Attention module, inserted between the convolutional blocks
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, kernel_size=1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )  # qkv holds the q, k and v tensors; map applies the rearrange above to each of them:
           # b = batch size, h = number of heads, c = channels per head, x / y = spatial dimensions
        q = q * self.scale
        sim = einsum("b h d i, b h d j -> b h i j", q, k)  # einsum multiplies element-wise and sums over shared indices: for each (b, h) it takes the dot product over d between position i of q and position j of k, giving a (b, h, i, j) similarity map
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()  # subtract the row-wise max for numerical stability of the softmax
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)
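Shape check (usage sketch): attention preserves the input shape, because to_out projects the concatenated heads back to dim.

# Shape check (usage sketch)
x = torch.randn(2, 64, 16, 16)
print(Attention(64)(x).shape)  # torch.Size([2, 64, 16, 16])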

# Full attention is quadratic in the number of pixels, so linear attention is recommended here to keep the cost manageable
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, kernel_size=1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q.softmax(dim=-2)  # normalize q over the feature dimension
        k = k.softmax(dim=-1)  # normalize k over the spatial dimension (standard linear-attention formulation)
        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)  # aggregate k and v first: a (d, e) context instead of an (n, n) attention map

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
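Shape check (usage sketch): the output shape matches full attention, but the (d, e) context tensor is independent of the number of pixels, which is what makes the cost linear in h*w instead of quadratic.

# Shape check (usage sketch)
x = torch.randn(2, 64, 16, 16)
print(LinearAttention(64)(x).shape)  # torch.Size([2, 64, 16, 16])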
# 5. Group normalization, applied before the attention layer
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)
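This is how the pieces are composed in the U-Net below: the attention layer is wrapped in PreNorm (normalize first) and then in Residual (add a skip connection). A quick sketch:

# Composition sketch, exactly as used inside the U-Net
x = torch.randn(2, 64, 16, 16)
layer = Residual(PreNorm(64, LinearAttention(64)))
print(layer(x).shape)  # torch.Size([2, 64, 16, 16])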

Conditional U-Net
(1) A convolution is applied to the batch of noisy images, and position embeddings are computed for the noise levels (timesteps)
(2) Downsampling path: 2 ResNet blocks + group norm + attention + residual connection + downsample
(3) Middle: ResNet block + attention + ResNet block
(4) Upsampling path: 2 ResNet blocks + group norm + attention + residual connection + upsample
(5) Final ResNet block followed by a convolution to the output channels

# 6. Conditional U-Net

class Unet(nn.Module):
    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            resnet_block_groups=8,
            use_convnext=False,  # ConvNextBlock is not defined in this part, so default to the ResnetBlock above
            convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)
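Finally, a usage sketch to check that the model maps a batch of noisy images plus their timesteps to a tensor of the same shape as the input (dim=32 and three channels are chosen here just as an example):

# End-to-end shape check (usage sketch)
model = Unet(dim=32, use_convnext=False)   # ConvNextBlock is not defined in this post
x = torch.randn(4, 3, 64, 64)              # batch of noisy images
t = torch.randint(0, 1000, (4,))           # one timestep per image
print(model(x, t).shape)                   # torch.Size([4, 3, 64, 64])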