图像修复-CVPR2023-DRSformer-Learning A Sparse Transformer Network for Effective Image Deraining
DRSformer 通过自适应 Top-k 选择、自注意力的多尺度前馈网络和混合专家特征补偿器,实现了有效的特征聚合和协同优化,以提升图像去雨效果。
文章目录
论文链接:DRSformer-Learning A Sparse Transformer Network for Effective Image Deraining
主要创新点
-
自适应稀疏选择:提出了一种自适应 Top-k 选择操作符,可以保留
query-key
中最有用的自注意力值,避免不相关的特征干扰,从而提升特征聚合的效果。 -
多尺度信息建模:为了解决传统
Transformer
中缺乏多尺度信息的问题,提出了混合尺度的前馈网络(mixed-scale feed-forward network
),以生成更适合图像去雨的多尺度特征。 -
混合专家补偿器:结合 CNN 操作来增强局部上下文信息,通过混合专家特征补偿器(
mixture of experts feature compensator
)来进一步提升图像去雨的效果,实现协同优化。
模型架构图
模型由一下几部分组成:
- 自适应稀疏选择 (TKSA):为了实现稀疏选择自注意力机制,在q和k进行卷积之后的注意力结果,再与v做点积之前,进行特征选择,选择特征最好的k个作为和v计算注意力,实现对重要的特征的提取。
- 多尺度信息建模(MSFN):通过在特征传输过程中插入两个多尺度深度卷积路径来更好地去除多尺度的雨条纹,实验证明是有效的。
- 混合专家补偿器(MEFC):主要是进行了,利用稀疏卷积减少计算开销、自适应选择专家的表示、增强多样化的特征表示(注:只是字面解释为专家,我们可以理解为任务)。
- 开始和结尾3x3卷积,前面是将图像映射到特征图, 后面3x3是特征图还原成图像,可以理解为编码器(
encode
)和解码器(decoder
),但其实整个下采样都称为编码器(encode
),整个上采样过程称为解码器(decoder
)。
下面将从源码层面剖析模型代码的具体实现,直观的了解模型设计和代码实现之间的区别。
自适应稀疏选择 (TKSA)
稀疏选择与通道维度注意力:
TKSA
不再对所有查询-键对计算注意力,而是只保留对特征聚合最关键的信息,避免了大量不相关信息的干扰。其基本思路是稀疏化注意力,即只对最重要的注意力权重进行计算。
- 首先,
TKSA
使用1×1
卷积和3×3
深度卷积来对特征进行通道维度的编码。- 接着,
TKSA
在通道维度上应用注意力,而不是传统的空间维度。这种方式降低了时间和内存复杂度,因为减少了在空间维度上计算大规模注意力矩阵的开销。- Top-k 稀疏选择机制:
TKSA
计算重组后的查询和键矩阵中每对像素之间的相似度,并将那些注意力权重较低的元素屏蔽掉。换言之,只有最重要的前k
个注意力值会参与计算。这一过程产生一个稀疏的注意力矩阵,只包含最显著的注意力信息,从而减少了冗余计算量,也确保了更高效的特征聚合。
源码实现
在具体实现使,从源码中可以看出,对于稀疏选择时,采用4种不同的K来进行稀疏选择,并且为每个稀疏选择生成的注意力加上了可学习的参数,这样实现对不同注意力的自适应选择。
## Top-K稀疏注意力(TKSA)模块
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
# 温度参数,用于缩放注意力得分
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
# 定义q, k, v的投影层,1x1卷积和深度卷积3x3
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
# 输出投影层
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
# 注意力的dropout
self.attn_drop = nn.Dropout(0.)
# 每个注意力掩码的可学习权重
self.attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
def forward(self, x):
b, c, h, w = x.shape
# 计算q, k, v投影,并应用深度卷积
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
# 将q, k, v重排列为多头注意力的形式
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
# 对q和k进行归一化,确保稳定的注意力计算
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
_, _, C, _ = q.shape # C表示每个头的通道数
# 初始化不同top-k稀疏程度的掩码
mask1 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
mask2 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
mask3 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
mask4 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
# 计算缩放后的点积注意力
attn = (q @ k.transpose(-2, -1)) * self.temperature
# 对注意力进行top-k稀疏处理,创建不同的注意力掩码
index = torch.topk(attn, k=int(C/2), dim=-1, largest=True)[1]
mask1.scatter_(-1, index, 1.)
attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf')))
index = torch.topk(attn, k=int(C*2/3), dim=-1, largest=True)[1]
mask2.scatter_(-1, index, 1.)
attn2 = torch.where(mask2 > 0, attn, torch.full_like(attn, float('-inf')))
index = torch.topk(attn, k=int(C*3/4), dim=-1, largest=True)[1]
mask3.scatter_(-1, index, 1.)
attn3 = torch.where(mask3 > 0, attn, torch.full_like(attn, float('-inf')))
index = torch.topk(attn, k=int(C*4/5), dim=-1, largest=True)[1]
mask4.scatter_(-1, index, 1.)
attn4 = torch.where(mask4 > 0, attn, torch.full_like(attn, float('-inf')))
# 对每个掩码中的注意力权重进行softmax归一化
attn1 = attn1.softmax(dim=-1)
attn2 = attn2.softmax(dim=-1)
attn3 = attn3.softmax(dim=-1)
attn4 = attn4.softmax(dim=-1)
# 对每个掩码计算注意力输出
out1 = (attn1 @ v)
out2 = (attn2 @ v)
out3 = (attn3 @ v)
out4 = (attn4 @ v)
# 使用学习到的注意力权重结合多个输出
out = out1 * self.attn1 + out2 * self.attn2 + out3 * self.attn3 + out4 * self.attn4
# 将输出重排列回原始的维度
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
# 最后投影输出匹配输入维度
out = self.project_out(out)
return out
多尺度信息建模(MSFN)
输入和层归一化:给定一个输入张量
X
,首先对其进行层归一化操作(Layer Normalization
)。1x1卷积扩展通道维度:为了更好地捕捉多尺度信息,网络首先使用一个
1×1
卷积扩展通道维度,扩展比例为 r。分支设计:然后,扩展后的特征被送入两个并行的卷积分支中,每个分支使用不同的卷积核(如
3×3
和5×5
的深度卷积)来提取多尺度的局部特征。这些不同尺寸的卷积核有助于捕捉不同尺度的雨条纹信息。特征融合:这两条卷积路径的输出会被融合,以增强特征表达,最终通过这些多尺度特征的融合,网络能够更好地恢复去雨后的清晰图像。
源码实现
## Mixed-Scale Feed-forward Network (MSFN)
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
# 计算隐藏层特征维度(通过扩展因子调整)
hidden_features = int(dim * ffn_expansion_factor)
# 输入的1x1卷积,用于通道数扩展
self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
# 使用3x3深度卷积来处理局部特征
self.dwconv3x3 = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
stride=1, padding=1, groups=hidden_features * 2, bias=bias)
# 使用5x5深度卷积来处理更大尺度的局部特征
self.dwconv5x5 = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=5,
stride=1, padding=2, groups=hidden_features * 2, bias=bias)
# ReLU激活函数,用于增加非线性
self.relu3 = nn.ReLU()
self.relu5 = nn.ReLU()
# 对3x3和5x5卷积的输出进行1x1卷积,以减少通道数
self.dwconv3x3_1 = nn.Conv2d(hidden_features * 2, hidden_features, kernel_size=3,
stride=1, padding=1, groups=hidden_features, bias=bias)
self.dwconv5x5_1 = nn.Conv2d(hidden_features * 2, hidden_features, kernel_size=5,
stride=1, padding=2, groups=hidden_features, bias=bias)
# 激活函数
self.relu3_1 = nn.ReLU()
self.relu5_1 = nn.ReLU()
# 最后的1x1卷积层,用于将处理后的特征映射回原始通道数
self.project_out = nn.Conv2d(hidden_features * 2, dim, kernel_size=1, bias=bias)
def forward(self, x):
# 通过1x1卷积对输入进行通道数扩展
x = self.project_in(x)
# 对通过3x3卷积处理后的特征进行ReLU激活并拆分成两部分
x1_3, x2_3 = self.relu3(self.dwconv3x3(x)).chunk(2, dim=1)
# 对通过5x5卷积处理后的特征进行ReLU激活并拆分成两部分
x1_5, x2_5 = self.relu5(self.dwconv5x5(x)).chunk(2, dim=1)
# 将通过3x3和5x5卷积处理后的特征拼接起来
x1 = torch.cat([x1_3, x1_5], dim=1)
x2 = torch.cat([x2_3, x2_5], dim=1)
# 对拼接后的特征分别通过3x3和5x5卷积层
x1 = self.relu3_1(self.dwconv3x3_1(x1))
x2 = self.relu5_1(self.dwconv5x5_1(x2))
# 将两个卷积后的特征拼接起来
x = torch.cat([x1, x2], dim=1)
# 通过1x1卷积层输出最终的特征
x = self.project_out(x)
return x
混合专家补偿器(MEFC)
- 利用稀疏卷积减少计算开销:
MEFC
中的卷积操作专注于图像中的非稀疏区域,利用稀疏卷积可以减少计算量,从而提升处理效率。- 自适应选择专家的表示:
MEFC
并未使用传统的外部门控网络来选择合适的专家,而是通过自注意力机制来根据输入特征的内容稀疏性自动选择适合的专家。这使得模型能够更灵活地应对不同稀疏条件下的输入。- 增强多样化的特征表示:不同专家使用的卷积核大小(如
1×1、3×3、5×5、7×7
等)和操作类型(如分离卷积、空洞卷积等)不同,使得MEFC
能够从多尺度、不同语义层次上丰富特征信息,从而更好地去除雨痕并保留图像内容。简单来说:就是使用不同的卷积核对特征图进行卷积,例如使用(
1×1、3×3、5×5、7×7
)卷积核,卷积出来的特征图,每个特征图都对应一个参数W
(W1×1、W3×3、W5×5、W7×7
),模型训练过程会根据模型他们的效果来更新权重,也就是说控制他们在整个特征图中合成时所占的比例,实现自适应选择专家。
源码实现
## Mixture of Experts Feature Compensator (MEFC)
class subnet(nn.Module):
def __init__(self, dim, layer_num=1, steps=4):
super(subnet, self).__init__()
# 初始化特征维度、操作数、层数和步数
self._C = dim # 输入特征的通道数(维度)
self.num_ops = len(Operations) # 操作的种类数量(假设是某些特定操作的集合)
self._layer_num = layer_num # 网络中的层数
self._steps = steps # 每层包含的步数
# 定义网络层列表
self.layers = nn.ModuleList()
for _ in range(self._layer_num):
# 定义自适应权重选择层
attention = OALayer(self._C, self._steps, self.num_ops)
self.layers += [attention] # 添加自适应层到层列表
# 定义每层的混合专家操作
layer = GroupOLs(steps, self._C)
self.layers += [layer] # 添加混合专家层到层列表
def forward(self, x):
# 前向传播过程
for _, layer in enumerate(self.layers):
if isinstance(layer, OALayer):
# 如果是自适应权重选择层,计算特征权重
weights = layer(x)
weights = F.softmax(weights, dim=-1) # 使用 softmax 进行归一化
else:
# 使用混合专家操作层处理输入特征
x = layer(x, weights) # 基于权重将特征输入到专家模型中
return x # 返回输出特征
对于其中一些层数定义,可以看以下代码,对于
OALayer
和GroupOLs
的定义,就是使用不同卷积,两层,如果是初始化权重,则初始化权重,如果是混合操作层,则带上权重去进行卷积,但是其实在实际执行过程中,两者是成双的,一一对应,都是先初始化权重,再进行卷积。
OALayer权重初始化
# 权重初始化
class OALayer(nn.Module):
def __init__(self, channel, k, num_ops):
super(OALayer, self).__init__()
self.k = k # k 是每组操作的数量
self.num_ops = num_ops # num_ops 是操作的数量
self.output = k * num_ops # 输出的大小,即 k 乘以 num_ops
self.avg_pool = nn.AdaptiveAvgPool2d(1) # 自适应池化,用于缩小输入特征图至单个值
# 通道注意力的全连接层,使用两个全连接层和 ReLU 激活函数来生成权重
self.ca_fc = nn.Sequential(
nn.Linear(channel, self.output * 2), # 将输入通道数映射为两倍的输出大小
nn.ReLU(), # 非线性激活函数
nn.Linear(self.output * 2, self.k * self.num_ops) # 映射为 k * num_ops 的输出
)
def forward(self, x):
y = self.avg_pool(x) # 自适应平均池化,将输入特征图的每个通道池化为单个值,提取全局特征
y = y.view(x.size(0), -1) # 调整形状以适应全连接层
y = self.ca_fc(y) # 通过全连接层生成权重
y = y.view(-1, self.k, self.num_ops) # 将权重调整为 (B, k, num_ops) 形状,方便后续使用
return y # 返回自适应权重
GroupOLs卷积
Operations = [
'sep_conv_1x1',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'dil_conv_3x3',
'dil_conv_5x5',
'dil_conv_7x7',
'avg_pool_3x3'
]
OPS = {
'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
'sep_conv_1x1' : lambda C, stride, affine: SepConv(C, C, 1, stride, 0, affine=affine),
'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
'dil_conv_7x7' : lambda C, stride, affine: DilConv(C, C, 7, stride, 6, 2, affine=affine),
}
class OperationLayer(nn.Module):
def __init__(self, C, stride):
super(OperationLayer, self).__init__()
# 初始化一个操作列表,包含所有操作
self._ops = nn.ModuleList()
for o in Operations: # 遍历预定义的操作
op = OPS[o](C, stride, False) # 创建操作并加入到操作列表
self._ops.append(op)
# 最终输出层:将多个操作的输出通过1x1卷积降回到通道数C,并加ReLU激活
self._out = nn.Sequential(
nn.Conv2d(C * len(Operations), C, 1, padding=0, bias=False),
nn.ReLU()
)
def forward(self, x, weights):
weights = weights.transpose(1, 0) # 将权重矩阵进行转置,适配每个操作的维度
states = []
# 遍历权重和对应的操作,将操作结果与对应的权重相乘并收集
for w, op in zip(weights, self._ops):
states.append(op(x) * w.view([-1, 1, 1, 1])) # 每个操作的输出按权重加权
# 将所有操作的输出按通道维度连接,并通过1x1卷积整合
return self._out(torch.cat(states[:], dim=1))
class GroupOLs(nn.Module):
def __init__(self, steps, C):
super(GroupOLs, self).__init__()
# 初始化步骤数和通道数
self.preprocess = ReLUConv(C, C, 1, 1, 0, affine=False) # 用于预处理输入的ReLU+1x1卷积层
self._steps = steps # 操作层的步数
self._ops = nn.ModuleList() # 包含多个OperationLayer的模块列表
self.relu = nn.ReLU() # ReLU激活函数
stride = 1 # 步幅设为1
# 添加每一步的操作层
for _ in range(self._steps):
op = OperationLayer(C, stride) # 创建 OperationLayer 实例
self._ops.append(op) # 加入操作列表
def forward(self, s0, weights):
s0 = self.preprocess(s0) # 先进行预处理
for i in range(self._steps):
res = s0 # 保留当前状态,便于残差连接
s0 = self._ops[i](s0, weights[:, i, :]) # 调用 OperationLayer
s0 = self.relu(s0 + res) # 残差连接并通过ReLU激活
return s0