image_context,、degra_context使用过程的Prompt和交叉注意力机制代码

更新:

  • 4.26 第一次整理博文,只简要指出代码整体结构功能,对代码细节未完全理解。
  • 5.2 发现交叉注意力部分理解完全有失偏颇,重新整理关于unet和crossattention理论和代码

整体代码

代码地址在

daclip-uir-main/universal-image-restoration/config/daclip-sde/models/modules/DenoisingUNet_arch.py

注意力机制相关代码在同级文件夹下attention.py。

该部分属于使用IR-SDE的复原处理中估计图像噪声得分的步骤。

【DA-CLIP】复原过程代码解读-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_60350022/article/details/137699012?spm=1001.2014.3001.5501这部分在DA-CLIP的论文中描述如下:

We use IR-SDE ( Luo et al. , 2023a ) as the base framework for image restoration. It adapts a U-Net architecture similar to DDPM ( Ho et al. , 2020 ) but removes all self-attention layers. To inject clean content embeddings into the diffusion process, we introduce a cross-attention ( Rombach et al. , 2022 ) mechanism to learn semantic guidance from pre-trained VLMs. Considering the varying input sizes in image restoration tasks and the increasing cost of applying attention to high-resolution features, we only use cross-attention in the bottom blocks of the U-Net for sample efficiency.

我们使用IR-SDE (Luo等人。作为图像恢复的基础框架。它采用了一种类似于DDPM的U-Net架构(Ho等人。但却删除了所有的自我注意层。为了在扩散过程中注入干净的内容嵌入,我们引入了交叉注意(Rombach等。从预先训练过的vlm中学习语义指导的机制。考虑到图像恢复任务中输入大小的变化,以及对高分辨率特征的注意应用成本的增加,我们只在U-Net的底部块中使用交叉注意来提高采样效率。

 在下面的代码中可以得到证实。

ConditionalUNet的forward 接收输入张量 xt、条件张量 cond、时间参数 time,以及可选的文本和图像上下文。

其中xtcondtext_context, image_context分别对应

noisy_tensor, LQ_tensor, degra_context, image_context

根据配置文件初始化

def forward(self, xt, cond, time, text_context=None, image_context=None):
    # 检查输入的时间参数是否为整数或浮点数,如果是,则将其转换为一个单元素张量,并移动到xt所在的设备
    if isinstance(time, int) or isinstance(time, float):
        time = torch.tensor([time]).to(xt.device)
    
    # X=noisy_tensor-LQ_tensor就是文章第一步添加的随机噪声,与LQ_tensor拼接,增加通道维度
    x = xt - cond
    x = torch.cat([x, cond], dim=1)
 
    # 获取输入张量的空间维度H和W
    H, W = x.shape[2:]
    # 检查并调整输入张量x的空间尺寸以匹配原始图像的尺寸
    x = self.check_image_size(x, H, W)
 
    # 应用初始卷积层
    x = self.init_conv(x)
    # 克隆x,用于后续操作
    x_ = x.clone()
 
    # 通过时间MLP处理时间参数
    t = self.time_mlp(time) 
    # 如果上下文维度大于0,并且使用degra上下文,且文本上下文不为空
    if self.context_dim > 0:
        if self.use_degra_context and text_context is not None:
            # 计算文本上下文的嵌入,将其与提示向量结合,并进行处理
            prompt_embedding = torch.softmax(self.text_mlp(text_context), dim=1) * self.prompt
            prompt_embedding = self.prompt_mlp(prompt_embedding)
            # 将处理后的文本上下文嵌入加到时间参数t上
            t = t + prompt_embedding
 
        # 如果使用图像上下文,且图像上下文不为空
        if self.use_image_context and image_context is not None:
            # 为图像上下文增加一个通道维度
            image_context = image_context.unsqueeze(1)
 
    # 存储下采样过程中的特征图
    h = []
    # 遍历下采样模块列表
    for b1, b2, attn, downsample in self.downs:
        # 应用第一个残差块和时间参数t
        x = b1(x, t)
        # 存储特征图
        h.append(x)
 
        # 应用第二个残差块和时间参数t
        x = b2(x, t)
        # 应用注意力机制,如果提供了图像上下文,则使用它
        x = attn(x, context=image_context)
        # 存储特征图
        h.append(x)
 
        # 应用下采样操作
        x = downsample(x)
 
    # 应用中间块1和时间参数t
    x = self.mid_block1(x, t)
    # 如果使用图像上下文,则应用注意力机制
    x = self.mid_attn(x, context=image_context) if self.use_image_context else x
    # 应用中间块2和时间参数t
    x = self.mid_block2(x, t)
 
    # 遍历上采样模块列表
    for b1, b2, attn, upsample in self.ups:
        # 从历史特征图中弹出并拼接特征,与当前特征图拼接
        x = torch.cat([x, h.pop()], dim=1)
        # 应用第一个残差块和时间参数t
        x = b1(x, t)
        
        # 再次从历史特征图中弹出并拼接特征,与当前特征图拼接
        x = torch.cat([x, h.pop()], dim=1)
        # 应用第二个残差块和时间参数t
        x = b2(x, t)
 
        # 应用注意力机制,如果提供了图像上下文,则使用它
        x = attn(x, context=image_context)
        # 应用上采样操作
        x = upsample(x)
 
    # 将原始输入xt与当前特征图x拼接,增加通道维度
    x = torch.cat([x, x_], dim=1)
 
    # 应用最终的残差块和时间参数t
    x = self.final_res_block(x, t)
    # 应用最终的卷积层
    x = self.final_conv(x)
 
    # 裁剪输出张量x,使其空间尺寸与原始输入图像的尺寸相匹配
    x = x[..., :H, :W].contiguous()
    
    # 返回处理后的输出张量x
    return x

1.预处理

 if isinstance(time, int) or isinstance(time, float):
        time = torch.tensor([time]).to(xt.device)
    
    # X=noisy_tensor-LQ_tensor就是app.py里添加的随机噪声,与LQ_tensor拼接,增加通道维度
    x = xt - cond
    x = torch.cat([x, cond], dim=1)
 
    # 获取输入张量的空间维度H和W
    H, W = x.shape[2:]
    # 检查并调整输入张量x的空间尺寸以匹配原始图像的尺寸
    x = self.check_image_size(x, H, W)
 
    # 应用初始卷积层
    x = self.init_conv(x)
    # 克隆x,用于后续操作
    x_ = x.clone()
 
    # 通过时间MLP处理时间参数
    t = self.time_mlp(time) 

self.init_conv = default_conv(in_nc * 2, nf, 7)
self.time_mlp = nn.Sequential(
            # 第一层是一个位置编码层,可能使用正弦和余弦函数来嵌入位置信息
            sinu_pos_emb,

            # 第二层是一个全连接层,将输入特征从 fourier_dim 维度映射到 time_dim 维度
            nn.Linear(fourier_dim, time_dim),

            # 第三层是 GELU 激活函数,它将非线性引入到模型中,有助于学习复杂的模式
            nn.GELU(),

            # 第四层是另一个全连接层,它将特征再次映射回 time_dim 维度
            # 这通常用于进一步提取特征和增强模型的非线性能力
            nn.Linear(time_dim, time_dim)
        )

2.prompt代码

 处理过程

        if self.context_dim > 0:
            if self.use_degra_context and text_context is not None:
                # 计算文本上下文的嵌入,将其与提示向量结合,并进行处理
                prompt_embedding = torch.softmax(self.text_mlp(text_context), dim=1) * self.prompt
                prompt_embedding = self.prompt_mlp(prompt_embedding)
                # 将处理后的文本上下文嵌入加到时间参数t上
                t = t + prompt_embedding

text_mlp

# 定义一个名为 text_mlp 的属性,它是一个由多个层组成的顺序模型,用于处理文本数据
            self.text_mlp = nn.Sequential(
                # 第一层是一个全连接层,将输入特征从 context_dim 维度映射到 time_dim 维度
                nn.Linear(context_dim, time_dim),

                # 第二层是一个非线性激活函数,这里用 NonLinearity() 表示,它可能是 nn.ReLU、nn.GELU 或其他激活函数
                # 这个非线性激活函数有助于模型捕捉和学习数据中的复杂关系
                NonLinearity(),  # 这里 NonLinearity() 应替换为具体的激活函数,例如 nn.ReLU 或 nn.GELU

                # 第三层是另一个全连接层,它将特征再次映射回 time_dim 维度
                # 这通常用于进一步提取特征和增强模型的非线性能力
                nn.Linear(time_dim, time_dim)
            )

使用了self.prompt 

self.prompt = nn.Parameter(
                # 调用 torch.rand 来生成一个随机初始化的张量
                # torch.rand(1, time_dim) 生成一个形状为 (1, time_dim) 的张量,其中 time_dim 是之前定义的维度
                # 张量中的每个元素都是从 [0, 1) 区间的均匀分布中随机抽取的
                torch.rand(1, time_dim)
            )

 self.prompt_mlp

self.prompt_mlp = nn.Linear(time_dim, time_dim)
            # 这个 MLP 只有一个全连接层,它将时间维度的特征映射回相同的时间维度
            # 这可能用于进一步处理或规范化提示数据的特征表示

3.Unet部分

下采样

      # 存储下采样过程中的特征图
        h = []
        # 遍历下采样模块列表
        for b1, b2, attn, downsample in self.downs:
            # 应用第一个残差块和时间参数t
            x = b1(x, t)
            # 存储特征图
            h.append(x)

            # 应用第二个残差块和时间参数t
            x = b2(x, t)
            # 应用注意力机制,如果提供了图像上下文,则使用它
            x = attn(x, context=image_context)
            # 存储特征图
            h.append(x)

            # 应用下采样操作
            x = downsample(x)

 这里需要在定义中查找self.downs的结构,这里结合self.ups一起看结构创建定义,这里self.depth深度为4.

网络结构

# 遍历网络深度
            for i in range(self.depth):
                # 输入和输出维度
                dim_in = nf * ch_mult[i]
                dim_out = nf * ch_mult[i + 1]

                # 输入和输出头部数量
                num_heads_in = dim_in // num_head_channels
                num_heads_out = dim_out // num_head_channels
                # 每个头部的输入维度
                dim_head_in = dim_in // num_heads_in

                # 如果使用图像上下文且上下文维度大于0
                if use_image_context and context_dim > 0:
                    # 这得看 i是否小于3才使用
                    att_down = LinearAttention(dim_in) if i < 3 else SpatialTransformer(dim_in, num_heads_in, dim_head,
                                                                                        depth=1,
                                                                                        context_dim=context_dim)
                    att_up = LinearAttention(dim_out) if i < 3 else SpatialTransformer(dim_out, num_heads_out, dim_head,
                                                                                       depth=1, context_dim=context_dim)
                else:
                    # 使用线性注意力机制
                    att_down = LinearAttention(dim_in)  # if i < 2 else Attention(dim_in)
                    att_up = LinearAttention(dim_out)  # if i < 2 else Attention(dim_out)

                # 下采样模块列表
                # self.downs 是一个用于存储处理块序列的列表,在类的初始化方法中定义
                self.downs.append(
                    nn.ModuleList([
                        # 创建第一个模块,一个自定义的块类(block_class),其输入和输出维度都是 dim_in,时间嵌入维度是 time_dim
                        block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),

                        # 创建第二个模块,与第一个模块相同
                        block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),

                        # 创建第三个模块,一个残差连接(Residual),它包含一个预归一化层(PreNorm)和注意力层(att_down)
                        Residual(PreNorm(dim_in, att_down)),

                        # 根据当前的索引 i 是否等于 (self.depth - 1),创建不同的模块
                        # 如果 i 不等于 (self.depth - 1),则创建一个 Downsample 模块,用于在 dim_in 和 dim_out 之间进行下采样
                        # 如果 i 等于 (self.depth - 1),则使用 default_conv 创建一个默认的二维卷积层,将 dim_in 维度的特征映射到 dim_out 维度
                        Downsample(dim_in, dim_out) if i != (self.depth - 1) else default_conv(dim_in, dim_out)
                    ])
                )

                # 上采样模块列表
                self.ups.insert(0, nn.ModuleList([
                    block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
                    block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_out, att_up)),
                    Upsample(dim_out, dim_in) if i != 0 else default_conv(dim_out, dim_in)
                ]))

 这里注意att_down在unet的最bottom处(i==3时)使用了交叉注意力。

SpatialTransformer定义在attention.py文件

class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                              in_channels,
                                              kernel_size=1,
                                              stride=1,
                                              padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
        return x + x_in

其中使用BasicTransformerBlock类作为self.transformer_blocks

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x

 该类的attn1 和attn12都使用了 CrossAttention类

 CrossAttention类定义如下:


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)

中间块

   # 应用中间块1和时间参数t
        x = self.mid_block1(x, t)
        # 如果使用图像上下文,则应用注意力机制
        x = self.mid_attn(x, context=image_context) if self.use_image_context else x
        # 应用中间块2和时间参数t
        x = self.mid_block2(x, t)
# 中间维度
            mid_dim = nf * ch_mult[-1]
            # 中间头部数量
            num_heads_mid = mid_dim // num_head_channels
            # 中间块1
            self.mid_block1 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)
            # 如果使用图像上下文且上下文维度大于0
            if use_image_context and context_dim > 0:
                # 使用空间变换器
                self.mid_attn = Residual(PreNorm(mid_dim, SpatialTransformer(mid_dim, num_heads_mid, dim_head, depth=1,
                                                                             context_dim=context_dim)))
            else:
                # 使用线性注意力机制
                self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
            # 中间块2
            self.mid_block2 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)

上采样

  # 遍历上采样模块列表
        for b1, b2, attn, upsample in self.ups:
            # 从历史特征图中弹出并拼接特征,与当前特征图拼接
            x = torch.cat([x, h.pop()], dim=1)
            # 应用第一个残差块和时间参数t
            x = b1(x, t)

            # 再次从历史特征图中弹出并拼接特征,与当前特征图拼接
            x = torch.cat([x, h.pop()], dim=1)
            # 应用第二个残差块和时间参数t
            x = b2(x, t)

            # 应用注意力机制,如果提供了图像上下文,则使用它
            x = attn(x, context=image_context)
            # 应用上采样操作
            x = upsample(x)

4.后处理

 # 将原始输入xt与当前特征图x拼接,增加通道维度
        x = torch.cat([x, x_], dim=1)

        # 应用最终的残差块和时间参数t
        x = self.final_res_block(x, t)
        # 应用最终的卷积层
        x = self.final_conv(x)

        # 裁剪输出张量x,使其空间尺寸与原始输入图像的尺寸相匹配
        x = x[..., :H, :W].contiguous()

        # 返回处理后的输出张量x
        return x

# 最终残差块
self.final_res_block = block_class(dim_in=nf * 2, dim_out=nf, time_emb_dim=time_dim)
# 最终卷积层
self.final_conv = nn.Conv2d(nf, out_nc, 3, 1, 1)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值