Multi-Scale Boosted Dehazing Network with Dense Feature Fusion: Notes and Code

The main contributions of this paper are the SOS boosting strategy and dense feature fusion; both ideas are adapted from other low-level vision domains.

Abstract

  • Proposes a multi-scale boosted dehazing network with dense feature fusion, built on the U-Net architecture.

  • The method is designed around two principles, boosting and error feedback, and the paper shows that both are applicable to the dehazing problem.

  • By incorporating the Strengthen-Operate-Subtract (SOS) boosting strategy into the decoder, a simple yet effective boosted decoder is developed to progressively restore the haze-free image.

  • To preserve spatial information within the U-Net architecture, a dense feature fusion (DFF) module is designed using a back-projection feedback scheme. Results show that the DFF module can simultaneously remedy the spatial information missing from high-resolution features and exploit features from non-adjacent levels.

Proposed Method

Network Architecture

The network consists of three parts: the encoder GEnc, the boosted decoder GDec, and the feature restoration module GRes. The overall architecture is shown below.

Figure: MSBDN architecture

To progressively restore the haze-free image from the feature j^L produced by the feature restoration module GRes, the decoder GDec is designed based on the SOS boosting strategy.

SOS Boosting Strategy

The SOS boosting strategy is formulated as:

J^{n+1} = g(I + J^n) − J^n

where J^{n+1} is the estimated haze-free image after the (n+1)-th iteration, g(·) is the dehazing method, and I + J^n denotes strengthening the current estimate J^n with the hazy image I.

The figure below shows five different boosting modules; for completeness, four alternatives to the SOS boosted module are also listed. The diffusion [44] and twicing [6] schemes can be used to design boosting modules, as shown in (a) and (b); they can be formulated as j^n = G^n_{θn}((j^{n+1})↑2) and j^n = G^n_{θn}(i^n − (j^{n+1})↑2) + (j^{n+1})↑2, respectively, where G^n_{θn} is the trainable refinement unit at level n, i^n is the corresponding encoder feature, and (j^{n+1})↑2 is the coarser estimate upsampled by a factor of two.

Since the refinement units in (5) and (6) do not make full use of the features in i^n, which carry richer spatial information than the upsampled feature (j^{n+1})↑2, the SOS structure is adopted instead: j^n = G^n_{θn}(i^n + (j^{n+1})↑2) − (j^{n+1})↑2. A minimal sketch of one such decoder level is given after the figure.

Figure: SOS boosting module and its alternatives
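
The following is a minimal PyTorch sketch of one SOS-boosted decoder level, assuming the refinement unit G^n_{θn} is a small convolutional block; the class name SOSLevel and its layers are illustrative, not the paper's exact implementation (in MSBDN the refinement unit is a group of residual blocks).

import torch
import torch.nn as nn
import torch.nn.functional as F

class SOSLevel(nn.Module):
    # One SOS-boosted decoder level (illustrative sketch):
    #   j^n = G(i^n + up2(j^{n+1})) - up2(j^{n+1})
    def __init__(self, channels):
        super(SOSLevel, self).__init__()
        # G: trainable refinement unit (MSBDN uses a group of residual blocks here)
        self.G = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, i_n, j_coarse):
        # upsample the coarser estimate j^{n+1} to the resolution of the encoder feature i^n
        j_up = F.interpolate(j_coarse, size=i_n.shape[2:], mode='bilinear', align_corners=False)
        # strengthen with the encoder feature, operate with G, then subtract the upsampled estimate
        return self.G(i_n + j_up) - j_up

The same pattern appears in the forward pass of MSBDN-RDFF.py below: res8x = torch.add(res16x, res8x) strengthens the encoder feature with the upsampled one, the residual group dense_4 operates on this sum (with an extra skip connection), and the upsampled feature res16x is subtracted afterwards.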

Dense Feature Fusion Module

As shown in the architecture diagram, two DFF modules are introduced at each level: one placed before the residual group in the encoder, and one after the SOS boosting module in the decoder. The output of each DFF module in the encoder/decoder is directly connected to all subsequent DFF modules in the same branch for feature fusion.

Compared with other upsampling-and-concatenation fusion methods, the DFF module, owing to its feedback mechanism, can better extract high-frequency information from the high-resolution features of preceding layers. By progressively fusing these error signals back into the downsampled latent features, the missing spatial information can be compensated. In addition, the module can exploit all preceding higher-level features, acting as an error-correcting feedback mechanism that refines the boosted features for better results.

The dense feature fusion module is illustrated below, followed by a minimal sketch of the back-projection idea.

Figure: Dense feature fusion (DFF) module
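
The sketch below is a rough illustration of the back-projection feedback idea only, not the repository's actual Encoder_MDCBlock1 / Decoder_MDCBlock1 implementation: it fuses previously stored higher-resolution features into the current lower-resolution feature by projecting the current feature up, measuring the error against each stored feature, and feeding the back-projected error into the current feature. The class name DFFSketch and the assumption that every stored feature sits exactly one scale above the current one (with the same channel count) are simplifications.

import torch
import torch.nn as nn

class DFFSketch(nn.Module):
    # Illustrative back-projection fusion (simplified).
    def __init__(self, channels):
        super(DFFSketch, self).__init__()
        # projection operators between the current scale and the (single) higher scale
        self.up = nn.ConvTranspose2d(channels, channels, kernel_size=4, stride=2, padding=1)
        self.down = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, cur, prev_feats):
        for prev in prev_feats:
            up_cur = self.up(cur)         # project the current feature up to the stored feature's resolution
            err = up_cur - prev           # error between the projection and the stored high-resolution feature
            cur = cur + self.down(err)    # back-project the error and correct the current feature
        return cur

In the actual network each DFF module receives the whole memory list (feature_mem in the encoder, feature_mem_up in the decoder), so non-adjacent levels are fused as well.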

Code Walkthrough

MSBDN-RDFF.py

This file defines the full network architecture. The paper states that all images are 3×256×256, so I feed an all-ones tensor of that size and observe the output of every module. Following the architecture diagram in the paper, I add print statements that label each module and report the feature map's shape (spatial size and channel count) after every layer, for example:

        print('Residual Group', x.shape)

Code (with shape prints for each module):

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from networks.base_networks import Encoder_MDCBlock1, Decoder_MDCBlock1
def make_model(args, parent=False):
    return Net()

class make_dense(nn.Module):
  def __init__(self, nChannels, growthRate, kernel_size=3):
    super(make_dense, self).__init__()
    self.conv = nn.Conv2d(nChannels, growthRate, kernel_size=kernel_size, padding=(kernel_size-1)//2, bias=False)
  def forward(self, x):
    out = F.relu(self.conv(x))
    out = torch.cat((x, out), 1)
    return out

# Residual dense block (RDB) architecture
class RDB(nn.Module):
  def __init__(self, nChannels, nDenselayer, growthRate, scale = 1.0):
    super(RDB, self).__init__()
    nChannels_ = nChannels
    self.scale = scale
    modules = []
    for i in range(nDenselayer):
        modules.append(make_dense(nChannels_, growthRate))
        nChannels_ += growthRate
    self.dense_layers = nn.Sequential(*modules)
    self.conv_1x1 = nn.Conv2d(nChannels_, nChannels, kernel_size=1, padding=0, bias=False)
  def forward(self, x):
    out = self.dense_layers(x)
    out = self.conv_1x1(out) * self.scale
    out = out + x
    return out

class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer, self).__init__()
        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)  # reflection padding on all four sides
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out


class UpsampleConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
      super(UpsampleConvLayer, self).__init__()
      self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)

    def forward(self, x):
        out = self.conv2d(x)
        return out


class ResidualBlock(torch.nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.relu = nn.PReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out) * 0.1
        out = torch.add(out, residual)
        return out

class Net(nn.Module):
    def __init__(self, res_blocks=18):
        super(Net, self).__init__()

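        # Encoder: an input conv followed by four stride-2 downsampling stages; at each stage
        # half of the channels go through an RDB (conv1..conv4) and half through an encoder DFF block (fusion1..fusion4)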
        self.conv_input = ConvLayer(3, 16, kernel_size=11, stride=1)
        self.dense0 = nn.Sequential(
            ResidualBlock(16),
            ResidualBlock(16),
            ResidualBlock(16)
        )

        self.conv2x = ConvLayer(16, 32, kernel_size=3, stride=2)
        self.conv1 = RDB(16, 4, 16)
        self.fusion1 = Encoder_MDCBlock1(16, 2, mode='iter2')
        self.dense1 = nn.Sequential(
            ResidualBlock(32),
            ResidualBlock(32),
            ResidualBlock(32)
        )

        self.conv4x = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.conv2 = RDB(32, 4, 32)
        self.fusion2 = Encoder_MDCBlock1(32, 3, mode='iter2')
        self.dense2 = nn.Sequential(
            ResidualBlock(64),
            ResidualBlock(64),
            ResidualBlock(64)
        )

        self.conv8x = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.conv3 = RDB(64, 4, 64)
        self.fusion3 = Encoder_MDCBlock1(64, 4, mode='iter2')
        self.dense3 = nn.Sequential(
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128)
        )

        self.conv16x = ConvLayer(128, 256, kernel_size=3, stride=2)
        self.conv4 = RDB(128, 4, 128)
        self.fusion4 = Encoder_MDCBlock1(128, 5, mode='iter2')

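        # GRes: feature restoration module, 18 residual blocks on the 256-channel 1/16-resolution feature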
        self.dehaze = nn.Sequential()
        for i in range(0, res_blocks):
            self.dehaze.add_module('res%d' % i, ResidualBlock(256))

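        # Boosted decoder: each level upsamples with a transposed conv, refines with a residual group,
        # and splits the channels between an RDB (conv_4..conv_1) and a decoder DFF block (fusion_4..fusion_1)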
        self.convd16x = UpsampleConvLayer(256, 128, kernel_size=3, stride=2)
        self.dense_4 = nn.Sequential(
            ResidualBlock(128),
            ResidualBlock(128),
            ResidualBlock(128)
        )
        self.conv_4 = RDB(64, 4, 64)
        self.fusion_4 = Decoder_MDCBlock1(64, 2, mode='iter2')

        self.convd8x = UpsampleConvLayer(128, 64, kernel_size=3, stride=2)
        self.dense_3 = nn.Sequential(
            ResidualBlock(64),
            ResidualBlock(64),
            ResidualBlock(64)
        )
        self.conv_3 = RDB(32, 4, 32)
        self.fusion_3 = Decoder_MDCBlock1(32, 3, mode='iter2')

        self.convd4x = UpsampleConvLayer(64, 32, kernel_size=3, stride=2)
        self.dense_2 = nn.Sequential(
            ResidualBlock(32),
            ResidualBlock(32),
            ResidualBlock(32)
        )
        self.conv_2 = RDB(16, 4, 16)
        self.fusion_2 = Decoder_MDCBlock1(16, 4, mode='iter2')

        self.convd2x = UpsampleConvLayer(32, 16, kernel_size=3, stride=2)
        self.dense_1 = nn.Sequential(
            ResidualBlock(16),
            ResidualBlock(16),
            ResidualBlock(16)
        )
        self.conv_1 = RDB(8, 4, 8)
        self.fusion_1 = Decoder_MDCBlock1(8, 5, mode='iter2')

        self.conv_output = ConvLayer(16, 3, kernel_size=3, stride=1)


    def forward(self, x):
        # Encoder
        res1x = self.conv_input(x)
        print('Conv_stride1', res1x.shape)

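        # keep half of the level-1 feature as the first entry of the encoder DFF memory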
        res1x_1, res1x_2 = res1x.split([(res1x.size()[1] // 2), (res1x.size()[1] // 2)], dim=1)
        feature_mem = [res1x_1]
        x = self.dense0(res1x) + res1x
        print('Residual Group', x.shape)

        res2x = self.conv2x(x)
        print('Conv_stride2',res2x.shape)

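        # encoder DFF: fuse half of the channels with the stored higher-resolution features (back-projection),
        # pass the other half through an RDB, then concatenate; the same pattern repeats at every encoder level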
        res2x_1, res2x_2 = res2x.split([(res2x.size()[1] // 2), (res2x.size()[1] // 2)], dim=1)
        res2x_1 = self.fusion1(res2x_1, feature_mem)
        res2x_2 = self.conv1(res2x_2)
        print('Dense Feature', res2x_2.shape)
        feature_mem.append(res2x_1)

        res2x = torch.cat((res2x_1, res2x_2), dim=1)
        res2x =self.dense1(res2x) + res2x
        print('Residual Group', res2x.shape)


        res4x =self.conv4x(res2x)
        print('Conv_stride2', res4x.shape)

        res4x_1, res4x_2 = res4x.split([(res4x.size()[1] // 2), (res4x.size()[1] // 2)], dim=1)
        res4x_1 = self.fusion2(res4x_1, feature_mem)
        res4x_2 = self.conv2(res4x_2)
        print('Dense Feature', res4x_2.shape)
        feature_mem.append(res4x_1)


        res4x = torch.cat((res4x_1, res4x_2), dim=1)
        res4x = self.dense2(res4x) + res4x
        print('Residual Group', res4x.shape)


        res8x = self.conv8x(res4x)
        print('Conv_stride2', res8x.shape)

        res8x_1, res8x_2 = res8x.split([(res8x.size()[1] // 2), (res8x.size()[1] // 2)], dim=1)
        res8x_1 = self.fusion3(res8x_1, feature_mem)
        res8x_2 = self.conv3(res8x_2)
        print('Dense Feature', res8x_2.shape)
        feature_mem.append(res8x_1)


        res8x = torch.cat((res8x_1, res8x_2), dim=1)
        res8x = self.dense3(res8x) + res8x

        res16x = self.conv16x(res8x)
        print('Encoder Output', res16x.shape)

        # Gres
        res16x_1, res16x_2 = res16x.split([(res16x.size()[1] // 2), (res16x.size()[1] // 2)], dim=1)
        res16x_1 = self.fusion4(res16x_1, feature_mem)
        res16x_2 = self.conv4(res16x_2)
        res16x = torch.cat((res16x_1, res16x_2), dim=1)
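        # SOS boosting at the bottleneck: strengthen the feature (here simply doubled), refine it with the
        # residual groups, then subtract the original feature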
        res_dehaze = res16x
        
        in_ft = res16x*2
        res16x = self.dehaze(in_ft) + in_ft - res_dehaze
        res16x_1, res16x_2 = res16x.split([(res16x.size()[1] // 2), (res16x.size()[1] // 2)], dim=1)
        feature_mem_up = [res16x_1]

        # Boosted Decoder
        print('Decoder Input',res16x.shape)

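        # boosted decoder level: strengthen (add the upsampled feature), operate (residual group), subtract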
        res16x = self.convd16x(res16x)
        res16x = F.interpolate(res16x, res8x.size()[2:], mode='bilinear', align_corners=False)
        res8x = torch.add(res16x, res8x)
        print('Deconv_stride2', res8x.shape)


        res8x = self.dense_4(res8x) + res8x - res16x
        print('Residual Group',res8x.shape)


        res8x_1, res8x_2 = res8x.split([(res8x.size()[1] // 2), (res8x.size()[1] // 2)], dim=1)
        res8x_1 = self.fusion_4(res8x_1, feature_mem_up)
        res8x_2 = self.conv_4(res8x_2)

        feature_mem_up.append(res8x_1)
        res8x = torch.cat((res8x_1, res8x_2), dim=1)
        print('Dense Feature',res8x.shape)

        res8x = self.convd8x(res8x)
        res8x = F.interpolate(res8x, res4x.size()[2:], mode='bilinear', align_corners=False)
        print('Deconv_stride2', res8x.shape)

        res4x = torch.add(res8x, res4x)
        res4x = self.dense_3(res4x) + res4x - res8x
        print('Residual Group', res4x.shape)

        res4x_1, res4x_2 = res4x.split([(res4x.size()[1] // 2), (res4x.size()[1] // 2)], dim=1)
        res4x_1 = self.fusion_3(res4x_1, feature_mem_up)
        res4x_2 = self.conv_3(res4x_2)
        feature_mem_up.append(res4x_1)
        res4x = torch.cat((res4x_1, res4x_2), dim=1)
        print('Dense Feature',res4x.shape)

        res4x = self.convd4x(res4x)
        res4x = F.interpolate(res4x, res2x.size()[2:], mode='bilinear', align_corners=False)
        print('Deconv_stride2', res4x.shape)

        res2x = torch.add(res4x, res2x)
        res2x = self.dense_2(res2x) + res2x - res4x
        print('Residual Group', res2x.shape)

        res2x_1, res2x_2 = res2x.split([(res2x.size()[1] // 2), (res2x.size()[1] // 2)], dim=1)
        res2x_1 = self.fusion_2(res2x_1, feature_mem_up)
        res2x_2 = self.conv_2(res2x_2)
        feature_mem_up.append(res2x_1)
        res2x = torch.cat((res2x_1, res2x_2), dim=1)
        print('Dense Feature', res2x.shape)

        res2x = self.convd2x(res2x)
        res2x = F.interpolate(res2x, x.size()[2:], mode='bilinear', align_corners=False)
        x = torch.add(res2x, x)
        x = self.dense_1(x) + x - res2x
        print('Residual Group', x.shape)

        x_1, x_2 = x.split([(x.size()[1] // 2), (x.size()[1] // 2)], dim=1)
        x_1 = self.fusion_1(x_1, feature_mem_up)
        x_2 = self.conv_1(x_2)
        x = torch.cat((x_1, x_2), dim=1)
        print('Dense Feature', x.shape)

        x = self.conv_output(x)
        print('Conv_stride1',x.shape)

        return x

image_example = np.ones(shape=(3,256,256))
image = torch.Tensor(image_example).unsqueeze(0)

print('Input:', image.shape)
net = Net()
out = net(image)
print('Output:', out.shape)
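
With the 1×3×256×256 input, the encoder prints show the spatial size halving and the channel count doubling at each stage (16×256×256 → 32×128×128 → 64×64×64 → 128×32×32 → 256×16×16), and the decoder mirrors this back until the final output print reports torch.Size([1, 3, 256, 256]).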
