图像修复领域-CVPR2024-Improving Image Restoration through Removing Degradations in Textual Representations
论文链接:Improving Image Restoration through Removing Degradations in Textual Representations(来自于CVPR2024)
基本思想
Improving Image Restoration through Removing Degradations in Textual Representations 提出的主要思想是通过在文本层面来消除退化信息,生成文本层面修复后的图像,然后用生成的图像来辅助图片层面的修复。这个我理解的论文的大致思想。
本论文还有涉及到两个重要的模型,分别是CLIP(图像和文本在特征空间上映射的模型),Stable Diffusion(稳定扩散模型,用来将文本信息作为指导生成不含噪声的图像)这两个模型可以查看我的前两篇文章(里面提供相应论文查看)。CLIP:CLIP原理
Stable Diffusion:Stable Diffusion原理(stable-diffusion-v2)
模型训练架构图
模型的总体架构中,需要训练四个模块,分别是Image-to-Text、Textual Restoration、Restoration Network、Dynamic Aggregation。其他模块都是固定参数的预训练模型,用来完成上诉四个模型的训练
CLIP Image Encoder :一个图像到文本的编码器,使用CLIP的预训练模型,能够很好的将图像特征映射到文本特征空间。
Image-to-Text :这个模型的作用就是将特征真正的映射到文字的特征,作者也对比了映射到显示文本(使用BLIPv2[53]将退化图像转换为图像标题)和隐式文本的对特征图还原的效果。
Textual Restoration :模型主要完成退化信息的删除,生成不含退化信息的文本信息,最后输入Diffusion Unet 模型来指导图像去噪
Diffusion Unet : 基于Unet的扩散模型进行图像去噪,原来就是Stable Diffusion中正向扩散,逆向去噪那一套,这里使用的指导信息是Textual Restoration去除退化信息后文本信息
Guidance Generation :这个就是(a)中训练好的两个模型权重拿来直接使用,用来生成一个文本层面删除退化信息图片来辅助图像层面的恢复。
Dynamic Aggregation :主要是在多层面的进行特征匹配和特征融合,将匹配到特征输入到Restoration Network中来辅助图像恢复
Restoration Network : 基于卷积和Transformer的图像恢复模块,通过transformer将Dynamic Aggregation匹配出来的特征融入到图像恢复中。
上面只是对的各个模型的作用做出粗略的解释,下面将从源码层面剖析具体的训练过程和模型内部是如何实现的。
Image-to-Text
每个词都会经过两个独立的映射网络:
mapping_{i}
:用于处理词嵌入中的第一个部分(一般可能是某种特征的前缀或标记)。mapping_patch_{i}
:用于处理剩余的词嵌入部分(可能是更详细的特征信息)。这两个映射网络的结构是一样的,包含多层线性变换和激活函数。
每个词的嵌入都经过了 4 层线性变换,配合 LayerNorm 和 LeakyReLU 非线性激活函数。这一串处理使得每个词的输入特征被逐层变换,最终映射到一个特定的输出维度。这种设计旨在通过多层的变换和非线性处理,学习到更丰富和复杂的特征表示。
本质就是进行了每个词自身经过线性变化,同时同一批的其他词也进行线性变化,最终将两者的特征拼接起来
值得注意是本层是单独训练的,同时没有加入的Textual Restoration,也就是没有删除退化信息,直接用来指导Stable Diffusion来进行图像修复。
class Mapper(nn.Module):
def __init__(self,
input_dim: int, # 输入维度
output_dim: int, # 输出维度
num_words: int, # 要处理的词的数量
):
super(Mapper, self).__init__()
self.num_words = num_words # 保存词的数量
# 对每一个词,创建两个不同的映射网络
for i in range(self.num_words):
# 第一个映射网络 'mapping_{i}'
setattr(self, f'mapping_{i}', nn.Sequential(nn.Linear(input_dim, 1280), # 输入映射到1280维
nn.LayerNorm(1280), # 层归一化
nn.LeakyReLU(), # 激活函数
nn.Linear(1280, 1280), # 线性层
nn.LayerNorm(1280), # 层归一化
nn.LeakyReLU(), # 激活函数
nn.Linear(1280, 1280), # 线性层
nn.LayerNorm(1280), # 层归一化
nn.LeakyReLU(), # 激活函数
nn.Linear(1280, output_dim))) # 输出映射到目标维度
# 第二个映射网络 'mapping_patch_{i}'
setattr(self, f'mapping_patch_{i}', nn.Sequential(nn.Linear(input_dim, 1280), # 输入映射到1280维
nn.LayerNorm(1280), # 层归一化
nn.LeakyReLU(), # 激活函数
nn.Linear(1280, 1280), # 线性层
nn.LayerNorm(1280), # 层归一化
nn.LeakyReLU(), # 激活函数
nn.Linear(1280, 1280), # 线性层
nn.LayerNorm(1280), # 层归一化
nn.LeakyReLU(), # 激活函数
nn.Linear(1280, output_dim))) # 输出映射到目标维度
def forward(self, embs):
hidden_states = () # 用于存储每个词的隐藏状态
embs = embs[0] # 获取输入的词嵌入
# 对每个词进行处理
for i in range(self.num_words):
# 将词嵌入通过两个映射网络进行处理,计算最终的隐藏状态
# mapping_{i} 处理 embs 的前1维
# mapping_patch_{i} 处理 embs 的其他维度,并取平均
hidden_state = getattr(self, f'mapping_{i}')(embs[:, :1]) + getattr(self, f'mapping_patch_{i}')(embs[:, 1:]).mean(dim=1, keepdim=True)
hidden_states += (hidden_state,) # 将隐藏状态添加到列表中
hidden_states = torch.cat(hidden_states, dim=1) # 将所有词的隐藏状态拼接成一个张量
return hidden_states # 返回最终的隐藏状态
Textual Restoration
hidden_states
:初始化一个空元组,用于存储每个映射的输出。循环处理每个词:
- 对于每个 i 从 0 到 num_words-1:
- 使用
getattr
动态访问mapping_{i}
,并将对应的嵌入传入(embs[:, i]
)。这里使用unsqueeze(1)
是为了将张量的形状从(batch_size, input_dim)
转换为(batch_size, 1, input_dim)
,以适应线性层的输入要求。- 将每个映射的输出
hidden_state
加入hidden_states
元组。拼接输出:
- 使用
torch.cat(hidden_states, dim=1)
将所有的hidden_states
在维度 1 上拼接,得到形状为(batch_size, num_words * output_dim)
的输出。总之一个深度前馈神经网络,具有多个线性变换、层归一化和激活函数,帮助学习输入特征到目标输出的复杂映射关系。通过这种结构,模型能够逐层提取特征并进行非线性变换,以适应输入数据的复杂性。
class CleanMapper(nn.Module):
def __init__(self,
input_dim: int, # 输入特征的维度
output_dim: int, # 输出特征的维度
num_words: int, # 要映射的词汇数量
):
super(CleanMapper, self).__init__()
self.num_words = num_words # 保存词汇数量
# 为每个词汇创建独立的映射网络
for i in range(self.num_words):
setattr(self, f'mapping_{i}', nn.Sequential(
nn.Linear(input_dim, 1280), # 输入到1280维的线性层
nn.LayerNorm(1280), # 对1280维输出进行层归一化
nn.LeakyReLU(), # Leaky ReLU激活函数
nn.Linear(1280, 1280), # 第二个线性层,保持1280维
nn.LayerNorm(1280), # 层归一化
nn.LeakyReLU(), # 激活函数
nn.Linear(1280, 1280), # 第三个线性层,保持1280维
nn.LayerNorm(1280), # 层归一化
nn.LeakyReLU(), # 激活函数
nn.Linear(1280, output_dim) # 最后一个线性层,输出到指定维度
))
def forward(self, embs):
hidden_states = () # 初始化一个空元组以存储每个词的隐藏状态
# 对每个词汇进行前向传播
for i in range(self.num_words):
hidden_state = getattr(self, f"mapping_{i}")(embs[:, i].unsqueeze(1)) # 获取对应映射并处理嵌入
hidden_states += (hidden_state,) # 将隐藏状态添加到元组中
# 将所有隐藏状态在维度1上拼接
hidden_states = torch.cat(hidden_states, dim=1)
return hidden_states # 返回拼接后的隐藏状态
Guidance Generation
这一模块的使用依赖于训练完成的Mapper和CleanMapper权重参数,加载预训练参数来生成清晰图片即可。
Mapper
- 功能:通常用于对输入特征进行映射或转换。它可能涉及多个层次的非线性变换,以学习输入数据的复杂关系。
- 结构:Mapper 可能包含多个线性层、激活函数、归一化层等,以实现特征提取和转换。具体结构根据具体任务和输入数据类型可能有所不同。
- 用途:常见于需要将数据从一个空间映射到另一个空间的应用,例如将图像特征映射到文本描述。
- 训练方式:在训练时,Mapper是先进行训练的,在训练时,冻结其他模型参数,同时不会加入CleanMapper映射,直接用模糊图片Guidance Generation生成高清图片计算loss值更新Mapper中的参数。
CleanMapper
- 功能:与 Mapper 相似,但通常更侧重于特征的“清理”或规范化。
CleanMapper
的名字可能暗示它在映射过程中应用了一些特殊的处理步骤,以确保输入特征更加规范或干净。- 结构:包含多个线性层和激活函数,并应用层归一化,这有助于稳定训练并提高模型的收敛性。层归一化可以在不同的批次之间消除特征的分布差异。
- 用途:可能用于需要更高精度或更清晰特征表达的任务,例如在自然语言处理或图像处理中的复杂映射。
- 训练方式: 首先会加载Mapper训练出来的预训练参数,然后冻结Mapper层的参数和其他模型的参数,同样用模糊图片Guidance Generation生成高清图片计算loss值更新CleanMapper中的参数。
Dynamic Aggregation和Restoration Network
Dynamic Aggregation和Restoration Network是同时进行的主要步骤如下:
- 模型第一步是对inp_img(模糊图片)和ref_img(参考图片,也就是上面生成的 图片)进行多尺度的特征提取,也就是接个四个卷积和残差连接,得到各四个特征图
- 对于inp_img我们只使用最后一层inp_img[4]的提取的特征来跟ref_img四个层的特征进行融合,每个层对于特征块的个数不一样,由于ref_img在进行特征提取的时候,特征图越来越小,通道数越来越多,对于每个层次的分块大小每一层都要2的幂次方缩小,源码中是(1,2,4,8)的缩小块的大小
- 由于我们使用的inp_img[4]切块与ref_img不同层次的特征的每个块进行相似度计算,找出相似度最高的,然后将相似度高的两块cat拼接再一起,然后走一个transformer或者带有残差连接的transformer,进行特征融合。
- 四层的特征分别进行匹配之后,并不是简单的拼接,而是做了一个残差连接,最先是第一层的较大的块ref_img[1],然后ref_img[2],都经过上面特征匹配,拼接,并且经过transformer,ref_img[1]做完得到enc_img[1],然后经过下采样得到,加入到ref_img[2]中,ref_img[3]、ref_img[4]都是一样的过程,特征图融合会越来越小,相当于下采样。但是我们需要输出图片是原始图像,所以我们要图像进行上采样,同时上采样时,加入enc_img[1]、enc_img[2]、enc_img[3]逐级拼接上,丰富特征。(有点类似与U-net结构)
注:代码过程太长,我直接上图解释