paper:MULTI-SCALE ATTENTION NETWORK FOR SINGLE IMAGE SUPER-RESOLUTION
1、Multi-scale Large Kernel Attention
为了解决如何有效地建立不同区域之间的长距离相关性,并避免由于大卷积核带来的“块效应”问题。这篇论文在 LKA 的基础上提出了一种 多尺度大核注意力(Multi-scale Large Kernel Attention),MLKA 的设计动机是为了解决图像超分辨率任务中,MLKA 结合了 大卷积核分解 和 多尺度机制 来实现这一目标。
MLKA 的实现过程:
- 输入特征图 X: 输入特征图 X 被分解成多个组,每个组包含相同数量的通道。
- LKA 模块: 对每个组应用 LKA 模块,生成不同尺度上的注意力图 LKAi。
- 门控模块: 为了避免扩张卷积带来的“块效应”,对每个组生成的注意力图进行动态重校准。这样可以更好地保留局部纹理信息。通过对每个 LKAi 应用门控模块,生成门控注意力图 MLKAi。
- 聚合: 将所有 MLKAi 聚合,得到最终的注意力图。
MLKA 的优势:
- 更全面的长距离相关性学习: 通过多尺度机制,MLKA 可以学习不同尺度上的长距离相关性,从而更好地恢复图像细节。
- 避免“块效应”: 通过门控机制,MLKA 可以有效地避免扩张卷积带来的“块效应”,从而更好地保留图像的平滑性。
- 计算效率高: MLKA 通过大卷积核分解和门控机制,实现了计算效率的提升。
Multi-scale Large Kernel Attention 结构图:
2、代码实现
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm(nn.Module):
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class MLKA(nn.Module):
def __init__(self, n_feats, k=2, squeeze_factor=15):
super().__init__()
i_feats = 2 * n_feats
self.norm = LayerNorm(n_feats, data_format='channels_first')
self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
# Multiscale Large Kernel Attention
self.LKA7 = nn.Sequential(
nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),
nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),
nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
self.LKA5 = nn.Sequential(
nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),
nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),
nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
self.LKA3 = nn.Sequential(
nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),
nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),
nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)
self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)
self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)
self.proj_first = nn.Sequential(
nn.Conv2d(n_feats, i_feats, 1, 1, 0))
self.proj_last = nn.Sequential(
nn.Conv2d(n_feats, n_feats, 1, 1, 0))
def forward(self, x, pre_attn=None, RAA=None):
shortcut = x.clone()
x = self.norm(x)
x = self.proj_first(x)
a, x = torch.chunk(x, 2, dim=1)
a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)
a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)],
dim=1)
x = self.proj_last(x * a) * self.scale + shortcut
return x
if __name__ == '__main__':
x = torch.randn(4, 360, 64, 64).cuda()
model = MLKA(360).cuda()
out = model(x)
print(out.shape)