Paper: https://arxiv.org/pdf/2003.06792.pdf
Code: https://github.com/swz30/MIRNet
This post records some notes from reading the paper; see the original paper and the official GitHub implementation for the full details. Personally I find the figures in the paper very intuitive and the code very readable and elegant.
Existing CNN-based methods typically operate either on full-resolution representations or on progressively lower-resolution ones.
The core of the paper is a multi-scale residual block containing several key elements:
- parallel multi-resolution convolution streams for extracting multi-scale features;
- information exchange across the multi-resolution streams;
- spatial and channel attention mechanisms for capturing contextual information;
- attention based multi-scale feature aggregation.
The two most common CNN designs for image restoration and enhancement:
- Encoder-decoder: processes features in a low-resolution space; it can learn broad context, but fine spatial details are lost.
- High-resolution (single-scale) feature processing: performs no downsampling, preserving finer spatial details, but it lacks receptive field and is less effective at encoding context.
Main contributions:
- A novel feature-extraction model that obtains complementary feature sets at multiple spatial scales, while maintaining the original high-resolution features to preserve precise spatial details (similar in spirit to HRNet).
- A regularly repeated information-exchange mechanism, in which features from the multi-resolution branches are progressively fused for improved representation learning.
- A new approach to fusing multi-scale features using a selective kernel network that dynamically combines variable receptive fields and faithfully preserves the original feature information at each spatial resolution.
- A recursive residual design that progressively breaks down the input signal in order to simplify the overall learning process and allows the construction of very deep networks.
- Experiments on five real-image benchmark datasets across different image-processing tasks, including image denoising, super-resolution, and image enhancement, achieving SOTA results on all five datasets.
Unlike existing methods, this paper processes features at the original resolution to preserve spatial details, while effectively fusing contextual information from multiple parallel branches. To me this resembles the idea behind HRNet.
The proposed network, MIRNet, learns rich feature representations for image restoration and enhancement and is built on a recursive residual design. At its core is the multi-scale residual block (MRB), whose main branch is dedicated to maintaining spatially precise high-resolution representations through the entire network, while a complementary set of parallel branches provides better-contextualized features. It also allows information exchange across the parallel streams via selective kernel feature fusion (SKFF), so that the high-resolution features are consolidated with the help of the low-resolution features, and vice versa.
Key components of the multi-scale residual block:
- parallel multi-resolution convolution streams for extracting (fine-to-coarse) semantically richer and (coarse-to-fine) spatially precise feature representations
- information exchange across the multi-resolution streams
- attention-based aggregation of features from the multiple streams
- dual attention units to capture contextual information along both the spatial and channel dimensions
- residual resizing modules to perform the downsampling and upsampling operations
Overall Pipeline:
As shown in the figure, given an input image I, the network first applies a convolution to extract low-level features X0. X0 then passes through N RRGs (recursive residual groups) to produce deep features Xd. A convolution is applied to Xd to obtain the residual image R. The final restored image is then Î = I + R.
Loss function: the network is optimized with the Charbonnier loss, L(Î, I*) = sqrt(||Î − I*||² + ε²), where I* denotes the ground-truth image and ε is a small constant (10⁻³ in the paper).
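As a concrete reference, here is a minimal pure-Python sketch of the Charbonnier penalty over flattened pixel lists (batching and tensor handling omitted; the function name is mine):

```python
import math

def charbonnier_loss(pred, target, eps=1e-3):
    # Charbonnier penalty: sqrt(||pred - target||^2 + eps^2), a smooth
    # L1-like loss: quadratic near zero, roughly linear for large errors
    sq_err = sum((p - t) ** 2 for p, t in zip(pred, target))
    return math.sqrt(sq_err + eps ** 2)
```

Compared with plain L2, the square root makes the penalty grow linearly for large residuals, which is less sensitive to outlier pixels while remaining differentiable at zero.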
Multi-scale Residual Block(MRB)
The paper proposes the multi-scale residual block (MRB), shown in Fig. 1. It can generate spatially precise outputs by maintaining high-resolution representations, while receiving rich contextual information from the low-resolution streams.
Selective kernel feature fusion(SKFF)
The SKFF module operates on features from multiple convolutional streams and performs aggregation based on self-attention.
As shown in Fig. 2, the SKFF module dynamically adjusts receptive fields via two operations, Fuse and Select. Fuse generates global feature descriptors by combining the information from the multi-resolution streams. Select uses these descriptors to recalibrate the feature maps of the different streams and then aggregates them.
(1) Fuse: SKFF receives inputs from three parallel convolution streams carrying information at different scales. It first combines these multi-scale features with an element-wise summation, L = L1 + L2 + L3; it then applies global average pooling (GAP) across the spatial dimensions to compute channel-wise statistics s. Next, a channel-downscaling convolution layer produces a compact feature representation z. Finally, the feature vector z passes through three parallel channel-upscaling convolution layers (one per resolution stream), providing three feature descriptors v1, v2, and v3, each of dimension 1×1×C.
(2) Select: this operation applies the softmax function to v1, v2, and v3, yielding attention activations s1, s2, and s3 that are used to adaptively recalibrate the multi-scale feature maps L1, L2, and L3, respectively.
The overall process of feature recalibration and aggregation is defined as U = s1·L1 + s2·L2 + s3·L3.
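The Fuse/Select math can be sketched with a toy NumPy version. The matrices W_down and W_up stand in for the 1×1 channel-downscaling/upscaling convolutions, a ReLU stands in for the PReLU, and all names are mine; the three inputs are assumed already resized to the same resolution, as in the actual SKFF module:

```python
import numpy as np

def skff_toy(feats, W_down, W_up):
    # Fuse: element-wise summation of the three streams, L = L1 + L2 + L3
    L = feats[0] + feats[1] + feats[2]
    # GAP over the spatial dims -> channel-wise statistics s, shape (C,)
    s = L.mean(axis=(1, 2))
    # channel-downscaling layer (a matrix here) -> compact descriptor z
    z = np.maximum(W_down @ s, 0.0)
    # three channel-upscaling layers -> descriptors v1, v2, v3, each (C,)
    v = np.stack([W @ z for W in W_up])          # shape (3, C)
    # Select: softmax across the three streams -> attention s1, s2, s3
    a = np.exp(v - v.max(axis=0))
    a = a / a.sum(axis=0)                        # columns sum to 1
    # recalibrate and aggregate: U = s1*L1 + s2*L2 + s3*L3
    U = sum(a_k[:, None, None] * L_k for a_k, L_k in zip(a, feats))
    return U, a
```

Because the softmax is taken across the stream axis, the three attention weights for each channel always sum to one, so SKFF computes a per-channel convex combination of the streams.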
Dual attention unit(DAU)
While the SKFF block fuses information across the multi-resolution branches, we also need a mechanism to share information within a feature tensor, along both its spatial and channel dimensions.
(1) Channel attention (CA)
Given a feature map M, the squeeze operation applies global average pooling across the spatial dimensions to encode global context, producing a feature descriptor d (Fig. 3, bottom). d then passes through two convolution layers and a sigmoid to produce activations d̂, which are multiplied with M.
(2) Spatial attention (SA)
As the code below shows, channel-wise max pooling and average pooling over M are concatenated into a two-channel map, which passes through a convolution and a sigmoid to give a spatial attention map that rescales M.
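The channel-attention squeeze-and-excite computation can be sketched in NumPy. W1 and W2 stand in for the two 1×1 convolutions (with a ReLU between them, as in the code below); all names are mine:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def channel_attention_toy(M, W1, W2):
    # squeeze: GAP across the spatial dims -> descriptor d, shape (C,)
    d = M.mean(axis=(1, 2))
    # excite: downscale, ReLU, upscale, sigmoid -> per-channel gates in (0, 1)
    d_hat = sigmoid(W2 @ np.maximum(W1 @ d, 0.0))
    # recalibrate: scale every channel of M by its gate
    return M * d_hat[:, None, None]
```

Since the gates lie in (0, 1), each channel of the output is a damped copy of the corresponding input channel, with the damping decided by global context.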
Residual resizing modules
Code:
# Imports used by the code below; in the official repo the antialiased
# (BlurPool) downsampling module is imported as `downsamp`
import numpy as np
import torch
import torch.nn as nn
from antialias import Downsample as downsamp

# convolution helper; note that the `padding` argument is ignored and the
# padding is always computed as kernel_size // 2
def conv(in_channels, out_channels, kernel_size, bias=False, padding=1, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias, stride=stride)
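Because the helper sets padding = kernel_size // 2, odd kernels at stride 1 preserve the spatial size. A quick pure-Python check with the standard convolution output-size formula (the function name is mine):

```python
def conv2d_out_size(n, kernel_size, stride=1, padding=0):
    # standard convolution output-size formula: floor((n + 2p - k) / s) + 1
    return (n + 2 * padding - kernel_size) // stride + 1
```

For example, a 3x3 kernel with padding 1 keeps a 64-pixel side at 64, while the same kernel at stride 2 halves it to 32.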
##---------- Selective Kernel Feature Fusion (SKFF) ----------
class SKFF(nn.Module):
    def __init__(self, in_channels, height=3, reduction=8, bias=False):
        super(SKFF, self).__init__()

        self.height = height
        d = max(int(in_channels/reduction), 4)

        # global average pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Conv + PReLU: channel-downscaling to the compact dimension d
        self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), nn.PReLU())

        # one channel-upscaling conv per stream
        self.fcs = nn.ModuleList([])
        for i in range(self.height):
            self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))

        self.softmax = nn.Softmax(dim=1)

    def forward(self, inp_feats):
        batch_size = inp_feats[0].shape[0]
        n_feats = inp_feats[0].shape[1]

        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])

        # Fuse: element-wise summation across streams
        feats_U = torch.sum(inp_feats, dim=1)
        # GAP -> channel-wise statistics s
        feats_S = self.avg_pool(feats_U)
        # Conv + PReLU -> compact descriptor z
        feats_Z = self.conv_du(feats_S)

        # Select: per-stream attention vectors, softmax across streams
        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        attention_vectors = self.softmax(attention_vectors)

        # recalibrate and aggregate
        feats_V = torch.sum(inp_feats*attention_vectors, dim=1)

        return feats_V
## ------ Channel Attention --------------
class ca_layer(nn.Module):
    def __init__(self, channel, reduction=8, bias=True):
        super(ca_layer, self).__init__()

        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
##---------- Spatial Attention ----------
# Conv (+ optional BN / ReLU)
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=False, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

# concatenate channel-wise max pooling and average pooling
class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

class spatial_attn_layer(nn.Module):
    def __init__(self, kernel_size=5):
        super(spatial_attn_layer, self).__init__()
        self.compress = ChannelPool()
        # 2 input channels (max + avg maps) -> 1 attention map
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale
##---------- Dual Attention Unit (DAU) ----------
class DAU(nn.Module):
    def __init__(
            self, n_feat, kernel_size=3, reduction=8,
            bias=False, bn=False, act=nn.PReLU(), res_scale=1):
        super(DAU, self).__init__()

        # Conv -> PReLU -> Conv
        modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
        self.body = nn.Sequential(*modules_body)

        ## Spatial Attention
        self.SA = spatial_attn_layer()
        ## Channel Attention
        self.CA = ca_layer(n_feat, reduction, bias=bias)
        self.conv1x1 = nn.Conv2d(n_feat*2, n_feat, kernel_size=1, bias=bias)

    def forward(self, x):
        res = self.body(x)
        sa_branch = self.SA(res)
        ca_branch = self.CA(res)
        res = torch.cat([sa_branch, ca_branch], dim=1)
        res = self.conv1x1(res)
        res += x
        return res
##---------- Resizing Modules ----------
# Fig 4(a)
class ResidualDownSample(nn.Module):
    def __init__(self, in_channels, bias=False):
        super(ResidualDownSample, self).__init__()

        # Conv1x1 -> PReLU -> Conv3x3 -> PReLU -> antialiased downsampling -> Conv1x1 (C -> 2C)
        self.top = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=bias),
                                 nn.PReLU(),
                                 nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1, bias=bias),
                                 nn.PReLU(),
                                 downsamp(channels=in_channels, filt_size=3, stride=2),
                                 nn.Conv2d(in_channels, in_channels*2, 1, stride=1, padding=0, bias=bias))

        # antialiased downsampling -> Conv1x1 (C -> 2C)
        self.bot = nn.Sequential(downsamp(channels=in_channels, filt_size=3, stride=2),
                                 nn.Conv2d(in_channels, in_channels*2, 1, stride=1, padding=0, bias=bias))

    def forward(self, x):
        top = self.top(x)
        bot = self.bot(x)
        # element-wise sum of the two branches
        out = top + bot
        return out

# Fig 4(b)
class ResidualUpSample(nn.Module):
    def __init__(self, in_channels, bias=False):
        super(ResidualUpSample, self).__init__()

        # Conv1x1 -> PReLU -> transposed Conv3x3 (2x upsampling) -> PReLU -> Conv1x1 (C -> C/2)
        self.top = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0, bias=bias),
                                 nn.PReLU(),
                                 nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1, output_padding=1, bias=bias),
                                 nn.PReLU(),
                                 nn.Conv2d(in_channels, in_channels//2, 1, stride=1, padding=0, bias=bias))

        # bilinear upsampling -> Conv1x1 (C -> C/2)
        self.bot = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=bias),
                                 nn.Conv2d(in_channels, in_channels//2, 1, stride=1, padding=0, bias=bias))

    def forward(self, x):
        top = self.top(x)
        bot = self.bot(x)
        out = top + bot
        return out
# cascade of log2(scale_factor) ResidualDownSample steps;
# each step doubles the channel count and halves the resolution
class DownSample(nn.Module):
    def __init__(self, in_channels, scale_factor, stride=2, kernel_size=3):
        super(DownSample, self).__init__()
        self.scale_factor = int(np.log2(scale_factor))

        modules_body = []
        for i in range(self.scale_factor):
            modules_body.append(ResidualDownSample(in_channels))
            in_channels = int(in_channels * stride)

        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        x = self.body(x)
        return x

# cascade of log2(scale_factor) ResidualUpSample steps;
# each step halves the channel count and doubles the resolution
class UpSample(nn.Module):
    def __init__(self, in_channels, scale_factor, stride=2, kernel_size=3):
        super(UpSample, self).__init__()
        self.scale_factor = int(np.log2(scale_factor))

        modules_body = []
        for i in range(self.scale_factor):
            modules_body.append(ResidualUpSample(in_channels))
            in_channels = int(in_channels // stride)

        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        x = self.body(x)
        return x
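To make the channel/resolution bookkeeping in DownSample and UpSample concrete, here is a pure-Python trace (the function name is mine) mirroring how each cascaded residual resizing step doubles the channels while halving the resolution on the way down, and does the reverse on the way up:

```python
import math

def resizing_trace(channels, resolution, scale_factor, stride=2, down=True):
    # DownSample/UpSample apply log2(scale_factor) residual resizing steps
    steps = int(math.log2(scale_factor))
    trace = [(channels, resolution)]
    for _ in range(steps):
        if down:
            channels, resolution = channels * stride, resolution // 2
        else:
            channels, resolution = channels // stride, resolution * 2
        trace.append((channels, resolution))
    return trace
```

With n_feat = 64 and a 128-pixel input, a 4x DownSample walks (64, 128) -> (128, 64) -> (256, 32), which matches the three MRB streams with height = 3.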
##---------- Multi-Scale Residual Block (MSRB) ----------
# MRB
class MSRB(nn.Module):
    def __init__(self, n_feat, height, width, stride, bias):
        super(MSRB, self).__init__()

        self.n_feat, self.height, self.width = n_feat, height, width

        # one row of DAUs per resolution stream
        self.blocks = nn.ModuleList([nn.ModuleList([DAU(int(n_feat*stride**i))]*width) for i in range(height)])

        INDEX = np.arange(0, width, 2)
        FEATS = [int((stride**i)*n_feat) for i in range(height)]
        SCALE = [2**i for i in range(1, height)]

        # final upsamplers that bring every stream back to full resolution
        self.last_up = nn.ModuleDict()
        for i in range(1, height):
            self.last_up.update({f'{i}': UpSample(int(n_feat*stride**i), 2**i, stride)})

        # resizing modules for every (channel count, scale) pair in the mesh
        self.down = nn.ModuleDict()
        self.up = nn.ModuleDict()

        i = 0
        SCALE.reverse()
        for feat in FEATS:
            for scale in SCALE[i:]:
                self.down.update({f'{feat}_{scale}': DownSample(feat, scale, stride)})
            i += 1

        i = 0
        FEATS.reverse()
        for feat in FEATS:
            for scale in SCALE[i:]:
                self.up.update({f'{feat}_{scale}': UpSample(feat, scale, stride)})
            i += 1

        self.conv_out = nn.Conv2d(n_feat, n_feat, kernel_size=3, padding=1, bias=bias)

        self.selective_kernel = nn.ModuleList([SKFF(n_feat*stride**i, height) for i in range(height)])

    def forward(self, x):
        inp = x.clone()

        # column 1 only: build each stream by downsampling the previous one
        blocks_out = []
        for j in range(self.height):
            if j == 0:
                inp = self.blocks[j][0](inp)
            else:
                inp = self.blocks[j][0](self.down[f'{inp.size(1)}_{2}'](inp))
            blocks_out.append(inp)

        # rest of grid
        for i in range(1, self.width):
            # Mesh
            # replace condition (i % 2 != 0) with True (Mesh) or False (Plain)
            # if i % 2 != 0:
            if True:
                tmp = []
                for j in range(self.height):
                    # resize every stream to stream j's resolution, then fuse with SKFF
                    TENSOR = []
                    nfeats = (2**j)*self.n_feat
                    for k in range(self.height):
                        TENSOR.append(self.select_up_down(blocks_out[k], j, k))

                    selective_kernel_fusion = self.selective_kernel[j](TENSOR)
                    tmp.append(selective_kernel_fusion)
            # Plain
            else:
                tmp = blocks_out
            # forward through either mesh or plain
            for j in range(self.height):
                blocks_out[j] = self.blocks[j][i](tmp[j])

        # sum after grid: bring all streams to full resolution and fuse
        out = []
        for k in range(self.height):
            out.append(self.select_last_up(blocks_out[k], k))

        out = self.selective_kernel[0](out)
        out = self.conv_out(out)
        out = out + x

        return out

    def select_up_down(self, tensor, j, k):
        if j == k:
            return tensor
        else:
            diff = 2 ** np.abs(j-k)
            if j < k:
                return self.up[f'{tensor.size(1)}_{diff}'](tensor)
            else:
                return self.down[f'{tensor.size(1)}_{diff}'](tensor)

    def select_last_up(self, tensor, k):
        if k == 0:
            return tensor
        else:
            return self.last_up[f'{k}'](tensor)
##---------- Recursive Residual Group (RRG) ----------
class RRG(nn.Module):
    def __init__(self, n_feat, n_MSRB, height, width, stride, bias=False):
        super(RRG, self).__init__()
        modules_body = [MSRB(n_feat, height, width, stride, bias) for _ in range(n_MSRB)]
        modules_body.append(conv(n_feat, n_feat, kernel_size=3))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res
##---------- MIRNet -----------------------
class MIRNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=64, kernel_size=3, stride=2, n_RRG=3, n_MSRB=2, height=3, width=2, bias=False):
        super(MIRNet, self).__init__()

        # conv on the input image I to get the low-level features X0
        self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=bias)
        modules_body = [RRG(n_feat, n_MSRB, height, width, stride, bias) for _ in range(n_RRG)]
        self.body = nn.Sequential(*modules_body)
        # conv on the deep features Xd to get the residual image R
        self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=bias)

    def forward(self, x):
        h = self.conv_in(x)
        h = self.body(h)
        h = self.conv_out(h)
        # residual connection: restored image = I + R
        h += x
        return h