图像修复领域-CVPR2024-Improving Image Restoration through Removing Degradations in Textual Representations

图像修复领域-CVPR2024-Improving Image Restoration through Removing Degradations in Textual Representations

论文链接:Improving Image Restoration through Removing Degradations in Textual Representations(来自于CVPR2024)

基本思想

Improving Image Restoration through Removing Degradations in Textual Representations 提出的主要思想是通过在文本层面来消除退化信息,生成文本层面修复后的图像,然后用生成的图像来辅助图片层面的修复。这个我理解的论文的大致思想。
本论文还有涉及到两个重要的模型,分别是CLIP(图像和文本在特征空间上映射的模型),Stable Diffusion(稳定扩散模型,用来将文本信息作为指导生成不含噪声的图像)这两个模型可以查看我的前两篇文章(里面提供相应论文查看)。

CLIP:CLIP原理

Stable Diffusion:Stable Diffusion原理(stable-diffusion-v2)

在这里插入图片描述

模型训练架构图

模型的总体架构中,需要训练四个模块,分别是Image-to-Text、Textual Restoration、Restoration Network、Dynamic Aggregation。其他模块都是固定参数的预训练模型,用来完成上诉四个模型的训练

CLIP Image Encoder :一个图像到文本的编码器,使用CLIP的预训练模型,能够很好的将图像特征映射到文本特征空间。

Image-to-Text :这个模型的作用就是将特征真正的映射到文字的特征,作者也对比了映射到显示文本(使用BLIPv2[53]将退化图像转换为图像标题)和隐式文本的对特征图还原的效果。

Textual Restoration :模型主要完成退化信息的删除,生成不含退化信息的文本信息,最后输入Diffusion Unet 模型来指导图像去噪

Diffusion Unet : 基于Unet的扩散模型进行图像去噪,原来就是Stable Diffusion中正向扩散,逆向去噪那一套,这里使用的指导信息是Textual Restoration去除退化信息后文本信息

Guidance Generation :这个就是(a)中训练好的两个模型权重拿来直接使用,用来生成一个文本层面删除退化信息图片来辅助图像层面的恢复。

Dynamic Aggregation :主要是在多层面的进行特征匹配和特征融合,将匹配到特征输入到Restoration Network中来辅助图像恢复

Restoration Network : 基于卷积和Transformer的图像恢复模块,通过transformer将Dynamic Aggregation匹配出来的特征融入到图像恢复中。

在这里插入图片描述

上面只是对的各个模型的作用做出粗略的解释,下面将从源码层面剖析具体的训练过程和模型内部是如何实现的。

Image-to-Text

每个词都会经过两个独立的映射网络:

  • mapping_{i}:用于处理词嵌入中的第一个部分(一般可能是某种特征的前缀或标记)。
  • mapping_patch_{i}:用于处理剩余的词嵌入部分(可能是更详细的特征信息)。

这两个映射网络的结构是一样的,包含多层线性变换和激活函数。

每个词的嵌入都经过了 4 层线性变换,配合 LayerNorm 和 LeakyReLU 非线性激活函数。这一串处理使得每个词的输入特征被逐层变换,最终映射到一个特定的输出维度。这种设计旨在通过多层的变换和非线性处理,学习到更丰富和复杂的特征表示。

本质就是进行了每个词自身经过线性变化,同时同一批的其他词也进行线性变化,最终将两者的特征拼接起来

值得注意是本层是单独训练的,同时没有加入的Textual Restoration,也就是没有删除退化信息,直接用来指导Stable Diffusion来进行图像修复。

class Mapper(nn.Module):
    def __init__(self,
                 input_dim: int,  # 输入维度
                 output_dim: int,  # 输出维度
                 num_words: int,   # 要处理的词的数量
    ):
        super(Mapper, self).__init__()
        self.num_words = num_words  # 保存词的数量
        # 对每一个词,创建两个不同的映射网络
        for i in range(self.num_words):
            # 第一个映射网络 'mapping_{i}'
            setattr(self, f'mapping_{i}', nn.Sequential(nn.Linear(input_dim, 1280),  # 输入映射到1280维
                                                        nn.LayerNorm(1280),           # 层归一化
                                                        nn.LeakyReLU(),              # 激活函数
                                                        nn.Linear(1280, 1280),       # 线性层
                                                        nn.LayerNorm(1280),          # 层归一化
                                                        nn.LeakyReLU(),              # 激活函数
                                                        nn.Linear(1280, 1280),       # 线性层
                                                        nn.LayerNorm(1280),          # 层归一化
                                                        nn.LeakyReLU(),              # 激活函数
                                                        nn.Linear(1280, output_dim)))  # 输出映射到目标维度
            # 第二个映射网络 'mapping_patch_{i}'
            setattr(self, f'mapping_patch_{i}', nn.Sequential(nn.Linear(input_dim, 1280),  # 输入映射到1280维
                                                              nn.LayerNorm(1280),           # 层归一化
                                                              nn.LeakyReLU(),              # 激活函数
                                                              nn.Linear(1280, 1280),       # 线性层
                                                              nn.LayerNorm(1280),          # 层归一化
                                                              nn.LeakyReLU(),              # 激活函数
                                                              nn.Linear(1280, 1280),       # 线性层
                                                              nn.LayerNorm(1280),          # 层归一化
                                                              nn.LeakyReLU(),              # 激活函数
                                                              nn.Linear(1280, output_dim)))  # 输出映射到目标维度
    def forward(self, embs):
        hidden_states = ()  # 用于存储每个词的隐藏状态
        embs = embs[0]  # 获取输入的词嵌入

        # 对每个词进行处理
        for i in range(self.num_words):
            # 将词嵌入通过两个映射网络进行处理,计算最终的隐藏状态
            # mapping_{i} 处理 embs 的前1维
            # mapping_patch_{i} 处理 embs 的其他维度,并取平均
            hidden_state = getattr(self, f'mapping_{i}')(embs[:, :1]) + getattr(self, f'mapping_patch_{i}')(embs[:, 1:]).mean(dim=1, keepdim=True)
            hidden_states += (hidden_state,)  # 将隐藏状态添加到列表中
        hidden_states = torch.cat(hidden_states, dim=1)  # 将所有词的隐藏状态拼接成一个张量
        return hidden_states  # 返回最终的隐藏状态

Textual Restoration

hidden_states:初始化一个空元组,用于存储每个映射的输出。

循环处理每个词

  • 对于每个 i 从 0 到 num_words-1:
    • 使用 getattr 动态访问 mapping_{i},并将对应的嵌入传入(embs[:, i])。这里使用 unsqueeze(1) 是为了将张量的形状从 (batch_size, input_dim) 转换为 (batch_size, 1, input_dim),以适应线性层的输入要求。
    • 将每个映射的输出 hidden_state 加入 hidden_states 元组。

拼接输出

  • 使用 torch.cat(hidden_states, dim=1) 将所有的 hidden_states 在维度 1 上拼接,得到形状为 (batch_size, num_words * output_dim) 的输出。

总之一个深度前馈神经网络,具有多个线性变换、层归一化和激活函数,帮助学习输入特征到目标输出的复杂映射关系。通过这种结构,模型能够逐层提取特征并进行非线性变换,以适应输入数据的复杂性。

class CleanMapper(nn.Module):
    def __init__(self,
                 input_dim: int,  # 输入特征的维度
                 output_dim: int,  # 输出特征的维度
                 num_words: int,  # 要映射的词汇数量
    ):
        super(CleanMapper, self).__init__()

        self.num_words = num_words  # 保存词汇数量

        # 为每个词汇创建独立的映射网络
        for i in range(self.num_words):
            setattr(self, f'mapping_{i}', nn.Sequential(
                nn.Linear(input_dim, 1280),  # 输入到1280维的线性层
                nn.LayerNorm(1280),  # 对1280维输出进行层归一化
                nn.LeakyReLU(),  # Leaky ReLU激活函数
                nn.Linear(1280, 1280),  # 第二个线性层,保持1280维
                nn.LayerNorm(1280),  # 层归一化
                nn.LeakyReLU(),  # 激活函数
                nn.Linear(1280, 1280),  # 第三个线性层,保持1280维
                nn.LayerNorm(1280),  # 层归一化
                nn.LeakyReLU(),  # 激活函数
                nn.Linear(1280, output_dim)  # 最后一个线性层,输出到指定维度
            ))

    def forward(self, embs):
        hidden_states = ()  # 初始化一个空元组以存储每个词的隐藏状态

        # 对每个词汇进行前向传播
        for i in range(self.num_words):
            hidden_state = getattr(self, f"mapping_{i}")(embs[:, i].unsqueeze(1))  # 获取对应映射并处理嵌入
            hidden_states += (hidden_state,)  # 将隐藏状态添加到元组中

        # 将所有隐藏状态在维度1上拼接
        hidden_states = torch.cat(hidden_states, dim=1)

        return hidden_states  # 返回拼接后的隐藏状态

Guidance Generation

这一模块的使用依赖于训练完成的Mapper和CleanMapper权重参数,加载预训练参数来生成清晰图片即可。

Mapper
  • 功能:通常用于对输入特征进行映射或转换。它可能涉及多个层次的非线性变换,以学习输入数据的复杂关系。
  • 结构:Mapper 可能包含多个线性层、激活函数、归一化层等,以实现特征提取和转换。具体结构根据具体任务和输入数据类型可能有所不同。
  • 用途:常见于需要将数据从一个空间映射到另一个空间的应用,例如将图像特征映射到文本描述。
  • 训练方式:在训练时,Mapper是先进行训练的,在训练时,冻结其他模型参数,同时不会加入CleanMapper映射,直接用模糊图片Guidance Generation生成高清图片计算loss值更新Mapper中的参数。
CleanMapper
  • 功能:与 Mapper 相似,但通常更侧重于特征的“清理”或规范化。CleanMapper 的名字可能暗示它在映射过程中应用了一些特殊的处理步骤,以确保输入特征更加规范或干净。
  • 结构:包含多个线性层和激活函数,并应用层归一化,这有助于稳定训练并提高模型的收敛性。层归一化可以在不同的批次之间消除特征的分布差异。
  • 用途:可能用于需要更高精度或更清晰特征表达的任务,例如在自然语言处理或图像处理中的复杂映射。
  • 训练方式: 首先会加载Mapper训练出来的预训练参数,然后冻结Mapper层的参数和其他模型的参数,同样用模糊图片Guidance Generation生成高清图片计算loss值更新CleanMapper中的参数。
Dynamic Aggregation和Restoration Network

Dynamic Aggregation和Restoration Network是同时进行的主要步骤如下:

  1. 模型第一步是对inp_img(模糊图片)和ref_img(参考图片,也就是上面生成的 图片)进行多尺度的特征提取,也就是接个四个卷积和残差连接,得到各四个特征图
  2. 对于inp_img我们只使用最后一层inp_img[4]的提取的特征来跟ref_img四个层的特征进行融合,每个层对于特征块的个数不一样,由于ref_img在进行特征提取的时候,特征图越来越小,通道数越来越多,对于每个层次的分块大小每一层都要2的幂次方缩小,源码中是(1,2,4,8)的缩小块的大小
  3. 由于我们使用的inp_img[4]切块与ref_img不同层次的特征的每个块进行相似度计算,找出相似度最高的,然后将相似度高的两块cat拼接再一起,然后走一个transformer或者带有残差连接的transformer,进行特征融合。
  4. 四层的特征分别进行匹配之后,并不是简单的拼接,而是做了一个残差连接,最先是第一层的较大的块ref_img[1],然后ref_img[2],都经过上面特征匹配,拼接,并且经过transformer,ref_img[1]做完得到enc_img[1],然后经过下采样得到,加入到ref_img[2]中,ref_img[3]、ref_img[4]都是一样的过程,特征图融合会越来越小,相当于下采样。但是我们需要输出图片是原始图像,所以我们要图像进行上采样,同时上采样时,加入enc_img[1]、enc_img[2]、enc_img[3]逐级拼接上,丰富特征。(有点类似与U-net结构)

注:代码过程太长,我直接上图解释

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值