[Image Restoration] Learning Enriched Features for Real Image Restoration and Enhancement

Paper: https://arxiv.org/pdf/2003.06792.pdf

Code: https://github.com/swz30/MIRNet

This post records my notes from reading the paper; for the full details and code, see the original paper and the official implementation on GitHub. Personally, I find the figures in the paper very intuitive and the code highly readable and elegant!

Existing CNN-based methods typically operate either on full-resolution representations or on progressively lower-resolution ones.

The core of the paper is a multi-scale residual block with several key elements:

  • parallel multi-resolution convolution streams for extracting multi-scale features;
  • information exchange across the multi-resolution streams;
  • spatial and channel attention mechanisms for capturing contextual information;
  • attention based multi-scale feature aggregation.

The two most common CNN designs for image restoration and enhancement:

  1. Encoder-decoder structure

        Encoder-decoder models process features in a low-resolution space; they can learn broad context, but fine spatial details are lost.

  2. High-resolution (single-scale) feature processing

        The high-resolution stream applies no downsampling, so finer spatial information is preserved, but the receptive field is limited and contextual encoding is less effective.

Main contributions of the paper:

  1. A novel feature extraction model that obtains complementary feature sets at multiple spatial scales while maintaining the original high-resolution features to preserve precise spatial details (similar in spirit to HRNet).
  2. A regularly repeated information-exchange mechanism in which features across the multi-resolution branches are progressively fused together to improve representation learning.
  3. A new approach to fusing multi-scale features using a selective kernel network that dynamically combines variable receptive fields and faithfully preserves the original feature information at each spatial resolution.
  4. A recursive residual design that progressively breaks down the input signal to ease the overall learning process and allows very deep networks to be built.
  5. Experiments on five real-image benchmark datasets across different tasks, including image denoising, image super-resolution, and image enhancement, achieving SOTA results on all five datasets.

Unlike existing methods, this paper processes features at the original resolution to preserve spatial details while effectively fusing contextual information from multiple parallel branches. In my view, this resembles the idea behind HRNet.

The proposed network, MIRNet, learns enriched feature representations for image restoration and enhancement. MIRNet is based on a recursive residual design. At its core is the multi-scale residual block (MRB), whose main branch maintains a spatially precise high-resolution representation throughout the entire network, while a complementary set of parallel branches provides better-contextualized features. It also allows information exchange across the parallel streams via selective kernel feature fusion (SKFF), so that high-resolution features can be consolidated with the help of low-resolution features, and vice versa.


Key components of the multi-scale residual block:

  • Parallel multi-resolution convolution streams that extract (fine-to-coarse) semantically richer and (coarse-to-fine) spatially precise feature representations
  • Information exchange across the multi-resolution streams
  • Attention-based aggregation of features from the multiple streams
  • Dual attention units that capture contextual information along both the spatial and channel dimensions
  • Residual resizing modules that perform downsampling and upsampling

Overall Pipeline

As shown in the figure, given an input image I, the network first applies a convolution to extract low-level features X0. X0 then passes through N RRGs (recursive residual groups) to produce deep features Xd. A convolution is then applied to Xd to obtain the residual image R. The final restored image is Î = I + R.

Loss function: the network is optimized with the Charbonnier loss between the restored image Î and the ground truth I*, L(Î, I*) = √(‖Î − I*‖² + ε²), with ε set empirically to 10⁻³ in the paper.
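A minimal PyTorch sketch of the Charbonnier loss (the constant ε = 1e-3 follows the paper; the class name here is my own):

```python
import torch
import torch.nn as nn

class CharbonnierLoss(nn.Module):
    """Charbonnier loss: a smooth, differentiable variant of L1."""
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        diff = pred - target
        return torch.sqrt(diff * diff + self.eps ** 2).mean()

criterion = CharbonnierLoss()
x = torch.zeros(1, 3, 8, 8)
y = torch.zeros(1, 3, 8, 8)
print(criterion(x, y))  # → eps = 1e-3 when pred == target
```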


Multi-scale Residual Block (MRB)

The paper proposes the multi-scale residual block (MRB), shown in Fig. 1. It generates spatially precise outputs by maintaining high-resolution representations while receiving rich contextual information from the low-resolution streams.


Selective kernel feature fusion (SKFF)

SKFF operates on features coming from multiple convolutional streams and performs aggregation based on self-attention.

As shown in Fig. 2, the SKFF module dynamically adjusts receptive fields via two operations, Fuse and Select. Fuse generates global feature descriptors by combining information from the multi-resolution streams. Select uses these descriptors to recalibrate the feature maps (of the different streams) and then aggregates them.

(1) Fuse: SKFF receives inputs from three parallel convolution streams carrying different scales of information. We first combine these multi-scale features with an element-wise summation: L = L1 + L2 + L3. Global average pooling (GAP) across the spatial dimensions then computes channel-wise statistics s. Next, a channel-downscaling convolution layer produces a compact feature representation z. Finally, the feature vector z passes through three parallel channel-upscaling convolution layers (one per resolution stream), yielding three feature descriptors v1, v2, and v3, each of dimension 1×1×C.

(2) Select: this operation applies the softmax function across v1, v2, and v3, producing attention activations s1, s2, and s3 that adaptively recalibrate the multi-scale feature maps L1, L2, and L3, respectively.

The overall feature recalibration and aggregation process is defined as: U = s1·L1 + s2·L2 + s3·L3
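As a sanity check, the Select step can be reproduced with plain tensor ops (the shapes, names, and the random logits standing in for v1, v2, v3 are illustrative only, not the paper's code):

```python
import torch

N, C, H, W = 2, 64, 32, 32
# three streams already brought to a common resolution
L1, L2, L3 = (torch.randn(N, C, H, W) for _ in range(3))

# Fuse: element-wise sum, then GAP over the spatial dims
L = L1 + L2 + L3
s = L.mean(dim=(2, 3), keepdim=True)   # N x C x 1 x 1 channel-wise statistics

# Select: the descriptors v1..v3 would come from 1x1 convs on a compressed z;
# here random logits of the right shape stand in for them
v = torch.randn(N, 3, C, 1, 1)
attn = torch.softmax(v, dim=1)         # softmax across the three streams

# recalibrate and aggregate: U = s1*L1 + s2*L2 + s3*L3
U = (attn * torch.stack([L1, L2, L3], dim=1)).sum(dim=1)
print(U.shape)                                            # torch.Size([2, 64, 32, 32])
print(attn.sum(dim=1).allclose(torch.ones(N, C, 1, 1)))   # True: weights sum to 1
```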


Dual attention unit (DAU)

While the SKFF block fuses information across multi-resolution branches, we also need a mechanism to share information within a feature tensor, along both the spatial and channel dimensions.

(1) Channel attention (CA)

Given a feature map M, the squeeze operation applies global average pooling across the spatial dimensions to encode the global context, producing a feature descriptor d (Fig. 3, bottom). d then passes through two convolution layers and a sigmoid to produce d̂, which is multiplied with M.
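A functional sketch of this squeeze-and-excitation style recalibration (the channel count and reduction factor 8 are illustrative; the layers are randomly initialized, so only shapes and value ranges are meaningful here):

```python
import torch
import torch.nn as nn

M = torch.randn(4, 64, 32, 32)
# squeeze: GAP over the spatial dims -> descriptor d of shape N x C x 1 x 1
d = M.mean(dim=(2, 3), keepdim=True)
# excitation: channel downscale -> ReLU -> channel upscale -> sigmoid
excite = nn.Sequential(nn.Conv2d(64, 8, 1), nn.ReLU(),
                       nn.Conv2d(8, 64, 1), nn.Sigmoid())
d_hat = excite(d)        # per-channel weights in (0, 1)
out = M * d_hat          # recalibrated features, same shape as M
print(out.shape)         # torch.Size([4, 64, 32, 32])
```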

(2) Spatial attention (SA)
SA aims to generate a spatial attention map and use it to recalibrate the incoming features M. To generate the map, the SA branch first independently applies global average pooling (torch.mean(x, 1)) and max pooling (torch.max(x, 1)) along the channel dimension of M, then concatenates the outputs into a feature map f ∈ R^(H×W×2). f passes through a convolution and a sigmoid activation to obtain the spatial attention map f̂, which is used to rescale M.
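A minimal sketch of this pooling-and-rescaling step with plain PyTorch ops (the 5×5 kernel follows the code below; the randomly initialized conv is only illustrative):

```python
import torch
import torch.nn as nn

M = torch.randn(4, 64, 32, 32)                        # incoming feature map
# pool along the channel dimension and concatenate: N x 2 x H x W
f = torch.cat([torch.max(M, dim=1, keepdim=True)[0],
               torch.mean(M, dim=1, keepdim=True)], dim=1)
# 2 -> 1 channel conv, then sigmoid -> spatial attention map f_hat
f_hat = torch.sigmoid(nn.Conv2d(2, 1, 5, padding=2, bias=False)(f))
out = M * f_hat                                       # broadcast over channels
print(out.shape)  # torch.Size([4, 64, 32, 32])
```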

Residual resizing modules

The resizing modules (Fig. 4) perform downsampling and upsampling with two parallel branches whose outputs are summed: the down-sampling variant uses antialiased (blur-pool) downsampling in both branches, while the up-sampling variant uses a transposed convolution in the top branch and bilinear interpolation in the bottom one.


Code:

# convolution helper
# NOTE: the `padding` argument is accepted but ignored; "same" padding of
# kernel_size//2 is always used.
def conv(in_channels, out_channels, kernel_size, bias=False, padding = 1, stride = 1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias, stride = stride)


##---------- Selective Kernel Feature Fusion (SKFF) ----------
class SKFF(nn.Module):
    def __init__(self, in_channels, height=3,reduction=8,bias=False):
        super(SKFF, self).__init__()
        
        self.height = height
        d = max(int(in_channels/reduction),4)
        # global average pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Conv PReLU
        self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), nn.PReLU())

        self.fcs = nn.ModuleList([])
        for i in range(self.height):
            self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1,bias=bias))
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inp_feats):
        batch_size = inp_feats[0].shape[0]
        n_feats =  inp_feats[0].shape[1]
        
        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])
        # element-wise summation
        feats_U = torch.sum(inp_feats, dim=1)
        # GAP
        feats_S = self.avg_pool(feats_U)
        # Conv PReLU
        feats_Z = self.conv_du(feats_S)

        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        
        feats_V = torch.sum(inp_feats*attention_vectors, dim=1)
        
        return feats_V        

  

## ------ Channel Attention --------------
class ca_layer(nn.Module):
    def __init__(self, channel, reduction=8, bias=True):
        super(ca_layer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
                nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y

##---------- Spatial Attention ----------
# Conv -> BN -> ReLU
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=False, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

# GAP and GMP  
class ChannelPool(nn.Module):
    def forward(self, x):
        # concatenate channel-wise global max pooling and global average pooling
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class spatial_attn_layer(nn.Module):
    def __init__(self, kernel_size=5):
        super(spatial_attn_layer, self).__init__()
        self.compress = ChannelPool()
        # 2 input channels -> 1 output channel
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out) # broadcasting
        return x * scale

##---------- Dual Attention Unit (DAU) ----------
class DAU(nn.Module):
    def __init__(
        self, n_feat, kernel_size=3, reduction=8,
        bias=False, bn=False, act=nn.PReLU(), res_scale=1):

        super(DAU, self).__init__()
        # Conv PReLU Conv
        modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
        self.body = nn.Sequential(*modules_body)
        
        ## Spatial Attention
        self.SA = spatial_attn_layer()

        ## Channel Attention        
        self.CA = ca_layer(n_feat,reduction, bias=bias)

        self.conv1x1 = nn.Conv2d(n_feat*2, n_feat, kernel_size=1, bias=bias)

    def forward(self, x):
        res = self.body(x)
        sa_branch = self.SA(res)
        ca_branch = self.CA(res)
        res = torch.cat([sa_branch, ca_branch], dim=1)
        res = self.conv1x1(res)
        res += x
        return res

##---------- Resizing Modules ----------
# Fig 4(a)    
class ResidualDownSample(nn.Module):
    def __init__(self, in_channels, bias=False):
        super(ResidualDownSample, self).__init__()
        # Conv1x1 PReLU  ->  Conv3x3 PReLU  -> Antialiasing Down-sampling -> Conv1x1
        self.top = nn.Sequential(nn.Conv2d(in_channels, in_channels,   1, stride=1, padding=0, bias=bias),
                                nn.PReLU(),
                                nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1, bias=bias),
                                nn.PReLU(),
                                downsamp(channels=in_channels,filt_size=3,stride=2),
                                # channels doubled from C to 2C
                                nn.Conv2d(in_channels, in_channels*2, 1, stride=1, padding=0, bias=bias))
        # Antialiasing down-sampling -> Conv1x1, channels doubled from C to 2C
        self.bot = nn.Sequential(downsamp(channels=in_channels,filt_size=3,stride=2),
                                nn.Conv2d(in_channels, in_channels*2, 1, stride=1, padding=0, bias=bias))

    def forward(self, x):
        top = self.top(x)
        bot = self.bot(x)
        # element-wise sum of the two branches
        out = top+bot
        return out
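Note that `downsamp` is not defined in this snippet; in the official repo it is an antialiased (blur-pool) downsampling layer in the spirit of Zhang's "Making Convolutional Networks Shift-Invariant Again". A minimal stand-in might look like this (my own sketch, not the repo's exact implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BlurPoolSketch(nn.Module):
    """Anti-aliased downsampling: blur with a fixed binomial kernel, then stride."""
    def __init__(self, channels, filt_size=3, stride=2):
        super().__init__()
        self.stride = stride
        self.channels = channels
        a = torch.tensor([1., 2., 1.])      # binomial coefficients for filt_size=3
        filt = a[:, None] * a[None, :]
        filt = filt / filt.sum()            # normalize so the blur preserves the mean
        # one copy of the filter per channel (depthwise convolution)
        self.register_buffer('filt', filt[None, None].repeat(channels, 1, 1, 1))
        self.pad = nn.ReflectionPad2d(filt_size // 2)

    def forward(self, x):
        return F.conv2d(self.pad(x), self.filt,
                        stride=self.stride, groups=self.channels)

blur = BlurPoolSketch(channels=64)
x = torch.randn(1, 64, 32, 32)
print(blur(x).shape)  # torch.Size([1, 64, 16, 16])
```

Blurring before subsampling removes the high frequencies that plain strided downsampling would alias, which is why a fixed low-pass filter (not a learned one) is used here.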

# Fig 4 (b)
class ResidualUpSample(nn.Module):
    def __init__(self, in_channels, bias=False):
        super(ResidualUpSample, self).__init__()
        # Conv1x1 PReLU -> ConvTranspose3x3 (stride 2) PReLU -> Conv1x1
        self.top = nn.Sequential(nn.Conv2d(in_channels, in_channels,   1, stride=1, padding=0, bias=bias),
                                nn.PReLU(),
                                nn.ConvTranspose2d(in_channels, in_channels, 3, stride=2, padding=1, output_padding=1,bias=bias),
                                nn.PReLU(),
                                # channels halved from C to C/2
                                nn.Conv2d(in_channels, in_channels//2, 1, stride=1, padding=0, bias=bias))
        # Bilinear up-sampling -> Conv1x1, channels halved from C to C/2
        # (note: align_corners reuses the `bias` flag, i.e. False by default)
        self.bot = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=bias),
                                nn.Conv2d(in_channels, in_channels//2, 1, stride=1, padding=0, bias=bias))

    def forward(self, x):
        top = self.top(x)
        bot = self.bot(x)
        out = top+bot
        return out

class DownSample(nn.Module):
    def __init__(self, in_channels, scale_factor, stride=2, kernel_size=3):
        super(DownSample, self).__init__()
        self.scale_factor = int(np.log2(scale_factor))

        modules_body = []
        for i in range(self.scale_factor):
            modules_body.append(ResidualDownSample(in_channels))
            in_channels = int(in_channels * stride)
        
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        x = self.body(x)
        return x

class UpSample(nn.Module):
    def __init__(self, in_channels, scale_factor, stride=2, kernel_size=3):
        super(UpSample, self).__init__()
        self.scale_factor = int(np.log2(scale_factor))

        modules_body = []
        for i in range(self.scale_factor):
            modules_body.append(ResidualUpSample(in_channels))
            in_channels = int(in_channels // stride)
        
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        x = self.body(x)
        return x

 

##---------- Multi-Scale Residual Block (MSRB) ----------
# MRB
class MSRB(nn.Module):
    def __init__(self, n_feat, height, width, stride, bias):
        super(MSRB, self).__init__()

        self.n_feat, self.height, self.width = n_feat, height, width
        self.blocks = nn.ModuleList([nn.ModuleList([DAU(int(n_feat*stride**i))]*width) for i in range(height)])

        INDEX = np.arange(0,width, 2)
        FEATS = [int((stride**i)*n_feat) for i in range(height)]
        SCALE = [2**i for i in range(1,height)]

        self.last_up   = nn.ModuleDict()
        for i in range(1,height):
            self.last_up.update({f'{i}': UpSample(int(n_feat*stride**i),2**i,stride)})

        self.down = nn.ModuleDict()
        self.up   = nn.ModuleDict()

        i=0
        SCALE.reverse()
        for feat in FEATS:
            for scale in SCALE[i:]:
                self.down.update({f'{feat}_{scale}': DownSample(feat,scale,stride)})
            i+=1

        i=0
        FEATS.reverse()
        for feat in FEATS:
            for scale in SCALE[i:]:                
                self.up.update({f'{feat}_{scale}': UpSample(feat,scale,stride)})
            i+=1

        self.conv_out = nn.Conv2d(n_feat, n_feat, kernel_size=3, padding=1, bias=bias)

        self.selective_kernel = nn.ModuleList([SKFF(n_feat*stride**i, height) for i in range(height)])
        


    def forward(self, x):
        inp = x.clone()
        #col 1 only
        blocks_out = []
        for j in range(self.height):
            if j==0:
                inp = self.blocks[j][0](inp)
            else:
                inp = self.blocks[j][0](self.down[f'{inp.size(1)}_{2}'](inp))
            blocks_out.append(inp)

        #rest of grid
        for i in range(1,self.width):
            #Mesh
            # Replace condition(i%2!=0) with True(Mesh) or False(Plain)
            # if i%2!=0:
            if True:
                tmp=[]
                for j in range(self.height):
                    TENSOR = []
                    nfeats = (2**j)*self.n_feat
                    for k in range(self.height):
                        TENSOR.append(self.select_up_down(blocks_out[k], j, k)) 

                    selective_kernel_fusion = self.selective_kernel[j](TENSOR)
                    tmp.append(selective_kernel_fusion)
            #Plain
            else:
                tmp = blocks_out
            #Forward through either mesh or plain
            for j in range(self.height):
                blocks_out[j] = self.blocks[j][i](tmp[j])

        #Sum after grid
        out=[]
        for k in range(self.height):
            out.append(self.select_last_up(blocks_out[k], k))  

        out = self.selective_kernel[0](out)

        out = self.conv_out(out)
        out = out + x

        return out

    def select_up_down(self, tensor, j, k):
        if j==k:
            return tensor
        else:
            diff = 2 ** np.abs(j-k)
            if j<k:
                return self.up[f'{tensor.size(1)}_{diff}'](tensor)
            else:
                return self.down[f'{tensor.size(1)}_{diff}'](tensor)


    def select_last_up(self, tensor, k):
        if k==0:
            return tensor
        else:
            return self.last_up[f'{k}'](tensor)

##---------- Recursive Residual Group (RRG) ----------
class RRG(nn.Module):
    def __init__(self, n_feat, n_MSRB, height, width, stride, bias=False):
        super(RRG, self).__init__()
        modules_body = [MSRB(n_feat, height, width, stride, bias) for _ in range(n_MSRB)]
        modules_body.append(conv(n_feat, n_feat, kernel_size=3))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

##---------- MIRNet  -----------------------
class MIRNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_feat=64, kernel_size=3, stride=2, n_RRG=3, n_MSRB=2, height=3, width=2, bias=False):
        super(MIRNet, self).__init__()
        # conv_in: extract low-level features X0 from the input image I
        self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=bias)
        modules_body = [RRG(n_feat, n_MSRB, height, width, stride, bias) for _ in range(n_RRG)]
        self.body = nn.Sequential(*modules_body)
        # conv_out: map deep features Xd to the residual image R
        self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=bias)

    def forward(self, x):
        h = self.conv_in(x)
        h = self.body(h)
        h = self.conv_out(h)
        # residual connection: restored image = I + R
        h += x
        return h
