[Image Deraining] DCSFN: Deep Cross-scale Fusion Network for Single Image Rain Removal

Code: https://github.com/Ohraincu/DCSFN

 

For the deraining task, training a single network structure without considering cross-scale relationships can lead to information loss. To learn features at different scales, the paper proposes a multi-sub-network structure in which the sub-networks are fused in a cross-scale manner via gated recurrent units, so that the information at different scales within these sub-networks is learned jointly and fully exploited.

First, the authors design an inner-scale connection block that fuses features at different scales by modeling the correlation between scales, which helps the network learn rain features better.

Second, to maximize information flow, capture long-range spatial dependencies, and effectively exploit the feature activations of later layers, the paper introduces an encoder and a decoder with densely connected structures, whose blocks are internally linked by skip connections.

Finally, a cross-scale fusion network is proposed to learn features at different scales, where the proposed cross-scale manner connects features across scales through gated recurrent units (GRUs), making full use of multi-scale information.

Main contributions:

  • An inner-scale connection block is built by modeling the correlation between scales, so that rain features are learned better.
  • An encoder and a decoder with densely connected structures, internally linked by skip connections, are introduced to improve deraining performance.
  • A cross-scale fusion network is proposed to learn features at different scales, where the cross-scale manner connects features across scales through gated recurrent units, making full use of multi-scale information.
  • Experimental results on synthetic and real-world datasets demonstrate the superiority of the proposed method, which outperforms state-of-the-art approaches.


Overall network architecture:

The network consists of three sub-networks at different scales. At each scale they share the same structure, with densely connected encoders and decoders (see Figure 2). At the end of each encoder, GRUs fuse the features across scales in a cross-scale manner to make full use of multi-scale information. Finally, all features at different scales are fused to generate the estimated rain layer, which yields the final estimated rain-free image.

Inner-scale Connection Block
The encoder in Figure 2 consists of a series of inner-scale connection blocks (see Figure 3).
In an inner-scale connection block, global max pooling is first used to obtain features at different scales. Next, features produced by several convolutions are connected across scales to encourage the exploration of correlated information. Finally, all features at different scales are fused to learn the main features.

Code:
Inner-scale Connection Block

In the paper and figures, scale_num is 3, corresponding to the 1x, 1/2x, and 1/4x scales.

# Figure 3 Inner-scale Connection Block
class Inner_scale_connection_block(nn.Module):
    def __init__(self):
        super(Inner_scale_connection_block, self).__init__()
        self.channel = settings.channel
        self.scale_num = settings.scale_num
        self.conv_num = settings.conv_num  # 4 in the paper
        self.scale1 = nn.ModuleList()
        self.scale2 = nn.ModuleList()
        self.scale4 = nn.ModuleList()
        self.scale8 = nn.ModuleList()
        if settings.scale_num == 4:
            for i in range(self.conv_num):
                self.scale1.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))
                self.scale2.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))
                self.scale4.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))
                self.scale8.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))
            self.fusion84 = nn.Sequential(nn.Conv2d(2 * self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))
            self.fusion42 = nn.Sequential(nn.Conv2d(2 * self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))
            self.fusion21 = nn.Sequential(nn.Conv2d(2 * self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))
            self.pooling8 = nn.MaxPool2d(8, 8)
            self.pooling4 = nn.MaxPool2d(4, 4)
            self.pooling2 = nn.MaxPool2d(2, 2)
            self.fusion_all = nn.Sequential(nn.Conv2d(4 * self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))
        # scale_num is 3 in the paper (K = 3)
        elif settings.scale_num == 3:
            for i in range(self.conv_num):
                # use LeakyReLU with alpha = 0.2
                self.scale1.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))
                self.scale2.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))
                self.scale4.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))
            self.fusion42 = nn.Sequential(nn.Conv2d(2 * self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))
            self.fusion21 = nn.Sequential(nn.Conv2d(2 * self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))
            self.pooling4 = nn.MaxPool2d(4, 4)
            self.pooling2 = nn.MaxPool2d(2, 2)
            self.fusion_all = nn.Sequential(nn.Conv2d(3 * self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))
        elif settings.scale_num == 2:
            for i in range(self.conv_num):
                self.scale1.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))
                self.scale2.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))
            self.fusion21 = nn.Sequential(nn.Conv2d(2 * self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))
            self.pooling2 = nn.MaxPool2d(2, 2)
            self.fusion_all = nn.Sequential(nn.Conv2d(2 * self.channel, self.channel, 1, 1), nn.LeakyReLU(0.2))
        elif settings.scale_num == 1:
            for i in range(self.conv_num):
                self.scale1.append(nn.Sequential(nn.Conv2d(self.channel, self.channel, 3, 1, 1), nn.LeakyReLU(0.2)))

    def forward(self, x):
        if settings.scale_num == 4:
            feature8 = self.pooling8(x)
            b8, c8, h8, w8 = feature8.size()
            feature4 = self.pooling4(x)
            b4, c4, h4, w4 = feature4.size()
            feature2 = self.pooling2(x)
            b2, c2, h2, w2 = feature2.size()
            feature1 = x
            b1, c1, h1, w1 = feature1.size()
            for i in range(self.conv_num):
                feature8 = self.scale8[i](feature8)
            scale8 = feature8
            feature4 = self.fusion84(torch.cat([feature4, F.upsample(scale8, [h4, w4])], dim=1))
            for i in range(self.conv_num):
                feature4 = self.scale4[i](feature4)
            scale4 = feature4
            feature2 = self.fusion42(torch.cat([feature2, F.upsample(scale4, [h2, w2])], dim=1))
            for i in range(self.conv_num):
                feature2 = self.scale2[i](feature2)

            scale2 = feature2
            feature1 = self.fusion21(torch.cat([feature1, F.upsample(scale2, [h1, w1])], dim=1))
            for i in range(self.conv_num):
                feature1 = self.scale1[i](feature1)
            scale1 = feature1
            fusion_all = self.fusion_all(torch.cat([scale1, F.upsample(scale2, [h1, w1]), F.upsample(scale4, [h1, w1]), F.upsample(scale8, [h1, w1])], dim=1))
            return fusion_all + x
        # scale_num is 3 in the paper
        elif settings.scale_num == 3:
            # Global max pooling with k x k kernels and k x k strides.
            # Firstly, global max pooling is utilized to obtain features at different scales.
            feature4 = self.pooling4(x) # 1/4 x
            b4, c4, h4, w4 = feature4.size()
            feature2 = self.pooling2(x) # 1/2 x
            b2, c2, h2, w2 = feature2.size()
            feature1 = x # 1x
            b1, c1, h1, w1 = feature1.size()
            # Secondly, features after several convolutions are connected between different
            # scales to boost the correlation information exploration.
            for i in range(self.conv_num):
                feature4 = self.scale4[i](feature4)
            scale4 = feature4
            feature2 = self.fusion42(torch.cat([feature2, F.upsample(scale4, [h2, w2])], dim=1))
            for i in range(self.conv_num):
                feature2 = self.scale2[i](feature2)
            scale2 = feature2
            feature1 = self.fusion21(torch.cat([feature1, F.upsample(scale2, [h1, w1])], dim=1))
            for i in range(self.conv_num):
                feature1 = self.scale1[i](feature1)
            scale1 = feature1
            # Lastly, all features at different scales are fused to learn the main features.
            fusion_all = self.fusion_all(torch.cat([scale1, F.upsample(scale2, [h1, w1]), F.upsample(scale4, [h1, w1])],dim=1))
            return fusion_all + x
        elif settings.scale_num == 2:
            feature2 = self.pooling2(x)
            b2, c2, h2, w2 = feature2.size()
            feature1 = x
            b1, c1, h1, w1 = feature1.size()

            for i in range(self.conv_num):
                feature2 = self.scale2[i](feature2)
            scale2 = feature2
            feature1 = self.fusion21(torch.cat([feature1, F.upsample(scale2, [h1, w1])], dim=1))
            for i in range(self.conv_num):
                feature1 = self.scale1[i](feature1)
            scale1 = feature1
            fusion_all = self.fusion_all(
                torch.cat([scale1, F.upsample(scale2, [h1, w1])], dim=1))
            return fusion_all + x
        elif settings.scale_num == 1:
            feature1 = x
            # self.scale1 is a ModuleList and is not callable itself,
            # so it must be applied layer by layer
            for i in range(self.conv_num):
                feature1 = self.scale1[i](feature1)
            fusion_all = feature1
            return fusion_all + x

Encoder:

Scale_block = Inner_scale_connection_block

# Figure 2(a)
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.unit_num = settings.Num_encoder
        self.units = nn.ModuleList()
        self.channel_num = settings.channel
        self.conv1x1 = nn.ModuleList()
        # The encoder consists of a series of inner-scale connection blocks with dense connections
        for i in range(self.unit_num):
            self.units.append(Scale_block())
            self.conv1x1.append(nn.Sequential(nn.Conv2d((i + 2) * self.channel_num, self.channel_num, 1, 1), nn.LeakyReLU(0.2)))

    def forward(self, x):
        catcompact = []
        catcompact.append(x)
        feature = []
        out = x
        for i in range(self.unit_num):
            tmp = self.units[i](out)
            feature.append(tmp)
            catcompact.append(tmp)
            out = self.conv1x1[i](torch.cat(catcompact, dim=1))
        return out, feature
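In this dense connection, the concatenated feature after the i-th block has (i + 2) * C channels (the input plus every block output so far), and the 1x1 convolution squeezes it back to C. A minimal sketch of that channel bookkeeping, using a plain conv as a stand-in for Scale_block (names here are illustrative, not from the repo):

```python
import torch
import torch.nn as nn

C, n_units = 8, 3
conv1x1 = nn.ModuleList([
    nn.Sequential(nn.Conv2d((i + 2) * C, C, 1), nn.LeakyReLU(0.2))
    for i in range(n_units)
])
unit = nn.Sequential(nn.Conv2d(C, C, 3, 1, 1), nn.LeakyReLU(0.2))  # stand-in for Scale_block

x = torch.randn(1, C, 32, 32)
cat, out = [x], x
for i in range(n_units):
    tmp = unit(out)
    cat.append(tmp)                  # keep every intermediate feature
    stacked = torch.cat(cat, dim=1)  # (i + 2) * C channels after unit i
    print(i, stacked.shape[1])       # prints 0 16 / 1 24 / 2 32
    out = conv1x1[i](stacked)        # 1x1 conv squeezes back to C channels
```

This is why conv1x1[i] is built with (i + 2) * self.channel_num input channels in both the encoder and the decoder.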

Decoder:

# Figure 2(b)
# The decoder has the same structure as the encoder, plus skip connections
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.unit_num = settings.Num_encoder
        self.units = nn.ModuleList()
        self.channel_num = settings.channel
        self.conv1x1 = nn.ModuleList()

        for i in range(self.unit_num):
            self.units.append(Scale_block())
            self.conv1x1.append(nn.Sequential(nn.Conv2d((i + 2) * self.channel_num, self.channel_num, 1, 1), nn.LeakyReLU(0.2)))

    def forward(self, x, feature):
        catcompact=[]
        catcompact.append(x)
        out = x
        for i in range(self.unit_num):
            tmp = self.units[i](out + feature[i])
            catcompact.append(tmp)
            out = self.conv1x1[i](torch.cat(catcompact, dim=1))
        return out

class Multi_model_fusion_learning(nn.Module):
    def __init__(self):
        super(Multi_model_fusion_learning, self).__init__()
        self.channel_num = settings.channel
        self.convert = nn.Sequential(
            nn.Conv2d(3, self.channel_num, 3, 1, 1),
            nn.LeakyReLU(0.2)
        )
        self.encoder_scale1 = Encoder()
        self.encoder_scale2 = Encoder()
        self.encoder_scale4 = Encoder()
        if settings.Net_cross is True:
            self.fusion_1_2_1 = nn.Sequential(
                nn.Conv2d(2 * self.channel_num, self.channel_num, 1, 1),
                nn.LeakyReLU(0.2))
            self.fusion_1_2_2 = nn.Sequential(
                nn.Conv2d(2 * self.channel_num, self.channel_num, 1, 1),
                nn.LeakyReLU(0.2))
            self.fusion_2_2_1 = nn.Sequential(
                nn.Conv2d(2 * self.channel_num, self.channel_num, 1, 1),
                nn.LeakyReLU(0.2))
            self.fusion_2_2_2 = nn.Sequential(
                nn.Conv2d(2 * self.channel_num, self.channel_num, 1, 1),
                nn.LeakyReLU(0.2))
        self.decoder_scale1 = Decoder()
        self.decoder_scale2 = Decoder()
        self.decoder_scale4 = Decoder()
        # first GRU of each scale
        self.rec1_1 = RecUnit(self.channel_num, self.channel_num, 3, 1)
        self.rec1_2 = RecUnit(self.channel_num, self.channel_num, 3, 1)
        self.rec1_4 = RecUnit(self.channel_num, self.channel_num, 3, 1)
        # second GRU of each scale
        self.rec2_1 = RecUnit(self.channel_num, self.channel_num, 3, 1)
        self.rec2_2 = RecUnit(self.channel_num, self.channel_num, 3, 1)
        self.rec2_4 = RecUnit(self.channel_num, self.channel_num, 3, 1)
        # conv1x1 -> LeakyReLU -> conv3x3
        self.merge = nn.Sequential(
            nn.Conv2d(3 * self.channel_num, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(3, 3, 3, 1, 1)
        )
        self.pooling2 = nn.MaxPool2d(2, 2)
        self.pooling4 = nn.MaxPool2d(4, 4)

    def forward(self, x):
        convert = self.convert(x)  # 3x3 conv
        feature1 = convert  # 1x scale
        feature2 = self.pooling2(convert)  # 1/2x scale
        feature4 = self.pooling4(convert)  # 1/4x scale
        b1, c1, h1, w1 = feature1.size()
        b2, c2, h2, w2 = feature2.size()
        b4, c4, h4, w4 = feature4.size()
        # Encoder
        scale1_encoder, scale1_feature = self.encoder_scale1(feature1)
        scale2_encoder, scale2_feature = self.encoder_scale2(feature2)
        scale4_encoder, scale4_feature = self.encoder_scale4(feature4)

        if settings.Net_cross is True:
            current1_4, rec1_4 = self.rec1_4(scale4_encoder)
            rec1_4_ori = rec1_4
            if settings.uint == "LSTM":
                rec1_4[0], rec1_4[1] = F.upsample(rec1_4_ori[0], [h2,w2]), F.upsample(rec1_4_ori[1], [h2,w2])
                current1_2, rec1_2 = self.rec1_2(scale2_encoder, rec1_4)
                rec1_2_ori = rec1_2
                rec1_2[0], rec1_2[1] = F.upsample(rec1_2_ori[0], [h1, w1]), F.upsample(rec1_2_ori[1], [h1, w1])
                rec1_4[0], rec1_4[1] = F.upsample(rec1_4_ori[0], [h1, w1]), F.upsample(rec1_4_ori[1], [h1, w1])
                current1_1, rec1_1 = self.rec1_1(scale1_encoder, [self.fusion_1_2_1(torch.cat([rec1_2[0], rec1_4[0]],dim=1)), self.fusion_1_2_2(torch.cat([rec1_2[1], rec1_4[1]], dim=1))])
            else:
                rec1_4 = F.upsample(rec1_4_ori, [h2, w2])
                current1_2, rec1_2 = self.rec1_2(scale2_encoder, rec1_4)
                rec1_2_ori = rec1_2
                rec1_2 = F.upsample(rec1_2_ori, [h1, w1])
                rec1_4 = F.upsample(rec1_4_ori, [h1, w1])
                current1_1, rec1_1 = self.rec1_1(scale1_encoder, self.fusion_1_2_1(torch.cat([rec1_2, rec1_4], dim=1)))
        else:
            current1_1 = scale1_encoder
            current1_2 = scale2_encoder
            current1_4 = scale4_encoder

        if settings.Net_cross is True:
            current2_1, rec2_1 = self.rec2_1(current1_1)
            rec2_1_ori = rec2_1
            if settings.uint == "LSTM":
                rec2_1[0], rec2_1[1] = F.upsample(rec2_1_ori[0], [h2, w2]), F.upsample(rec2_1_ori[1], [h2, w2])
                current2_2, rec2_2 = self.rec2_2(current1_2, rec2_1)
                rec2_2_ori = rec2_2
                rec2_2[0], rec2_2[1] = F.upsample(rec2_2_ori[0], [h4, w4]), F.upsample(rec2_2_ori[1], [h4, w4])
                rec2_1[0], rec2_1[1] = F.upsample(rec2_1_ori[0], [h4, w4]), F.upsample(rec2_1_ori[1], [h4, w4])
                current2_4, rec2_4 = self.rec2_4(current1_4, [self.fusion_2_2_1(torch.cat([rec2_1[0],rec2_2[0]],dim=1)), self.fusion_2_2_2(torch.cat([rec2_1[1], rec2_2[1]],dim=1))])
            else:
                rec2_1 = F.upsample(rec2_1_ori, [h2, w2])
                current2_2, rec2_2 = self.rec2_2(current1_2, rec2_1)
                rec2_2_ori = rec2_2
                rec2_2 = F.upsample(rec2_2_ori, [h4, w4])
                rec2_1 = F.upsample(rec2_1_ori, [h4, w4])
                current2_4, rec2_4 = self.rec2_4(current1_4, self.fusion_2_2_1(torch.cat([rec2_1, rec2_2],dim=1)))
        else:
            current2_1 = current1_1
            current2_2 = current1_2
            current2_4 = current1_4
        scale1_decoder = self.decoder_scale1(current2_1, scale1_feature)
        scale2_decoder = self.decoder_scale2(current2_2, scale2_feature)
        scale4_decoder = self.decoder_scale4(current2_4, scale4_feature)
        merge = self.merge(torch.cat([scale1_decoder, F.upsample(scale2_decoder,[h1,w1]), F.upsample(scale4_decoder,[h1,w1])],dim=1))

        return x - merge
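RecUnit is imported from the repo and not shown in this post. Judging from its usage above in the default (non-LSTM) setting, it behaves like a convolutional GRU cell: it takes a feature map and an optional recurrent state and returns (output, new state). A hypothetical sketch of such a cell follows; the class name, argument names, and exact gating details are my assumptions, not the repo's implementation:

```python
import torch
import torch.nn as nn

class ConvGRU(nn.Module):
    """Hypothetical stand-in for RecUnit: a convolutional GRU cell."""
    def __init__(self, in_ch, hid_ch, kernel=3, dilation=1):
        super().__init__()
        pad = dilation * (kernel // 2)
        self.update = nn.Conv2d(in_ch + hid_ch, hid_ch, kernel, padding=pad, dilation=dilation)
        self.reset = nn.Conv2d(in_ch + hid_ch, hid_ch, kernel, padding=pad, dilation=dilation)
        self.candidate = nn.Conv2d(in_ch + hid_ch, hid_ch, kernel, padding=pad, dilation=dilation)

    def forward(self, x, h=None):
        if h is None:  # no recurrent state passed in: start from zeros
            h = torch.zeros(x.size(0), self.update.out_channels,
                            x.size(2), x.size(3), device=x.device)
        z = torch.sigmoid(self.update(torch.cat([x, h], dim=1)))      # update gate
        r = torch.sigmoid(self.reset(torch.cat([x, h], dim=1)))       # reset gate
        n = torch.tanh(self.candidate(torch.cat([x, r * h], dim=1)))  # candidate state
        h_new = (1 - z) * h + z * n
        # return (current output, recurrent state), matching the rec(x, state) call pattern
        return h_new, h_new

gru = ConvGRU(16, 16)
out, state = gru(torch.randn(2, 16, 32, 32))
print(out.shape)  # torch.Size([2, 16, 32, 32])
```

Because the recurrent state is a spatial feature map, it must be resized with F.upsample/F.interpolate before being passed to a GRU at a different scale, which is exactly what the forward pass above does between the 1/4x, 1/2x, and 1x branches.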

For other details, refer to the source code and the original paper.
