深度学习中模块设计汇总(一)

深度学习模块最小构成部分
1.PixelShuffle

# Sub-pixel up-sampling layer: trades channels for spatial resolution by the
# given factor (e.g. with factor 2, GCIWTResUp below goes from in_channels to
# in_channels // 4 channels at twice the height/width).
# NOTE(review): `scailingFactor` (sic) is not defined in this snippet; the
# surrounding code must provide it.
nn.PixelShuffle(upscale_factor=scailingFactor)

2.Convolution
3.Strided Convolution

深度学习经典模块
Beyond Joint Demosaicking and Denoising: An Image Processing Pipeline for a Pixel-bin Image Sensor
1.Group Depth Attention Bottleneck Block
（原文此处为示意图，图片未保留）
2.Depth Attention Bottleneck Block
（原文此处为示意图，图片未保留）
3.Spatial Attention Block
（原文此处为示意图，图片未保留）

class SpatialAttention(nn.Module):
    """Spatial attention map built from channel-wise mean and max statistics.

    For an input ``(N, C, H, W)`` the per-pixel channel mean and channel max
    are stacked into two channels, passed through a bias-free
    ``kernel_size`` x ``kernel_size`` convolution whose padding keeps H and W,
    and squashed with a sigmoid, giving an ``(N, 1, H, W)`` map in (0, 1).

    Args:
        kernel_size: 3 or 7; any other value fails the assertion.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        half = kernel_size // 2  # 3 -> 1, 7 -> 3: "same" padding

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=half, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stats = torch.cat([mean_map, max_map], dim=1)  # mean first, then max
        return self.sigmoid(self.conv1(stats))
        
class SpatialAttentionBlock(nn.Module):
    """Gate a 3x3 convolution of the input with a spatial attention map.

    The attention map is computed from ``x`` by ``SpatialAttention`` (one
    channel) and broadcast-multiplied with ``conv(x)``, which keeps
    ``spatial_filter`` channels and the input's H and W.

    Args:
        spatial_filter: channel count of the input and output.
    """

    def __init__(self, spatial_filter=32):
        super(SpatialAttentionBlock, self).__init__()
        # Attribute name (sic) kept as-is: it determines the state_dict keys.
        self.spatialAttenton = SpatialAttention()
        self.conv = nn.Conv2d(spatial_filter, spatial_filter,  3, padding=1)

    def forward(self, x):
        attention = self.spatialAttenton(x)
        features = self.conv(x)
        return attention * features

AWNet: Attentive Wavelet Network for Image ISP
1.全局上下文 res-dense 模块
全局上下文 res-dense 模块包含一个残差密集块 (RDB) 和一个全局上下文块 (GCB)

全局上下文 res-dense 模块
1.1 残差密集块 (RDB)

class MakeDense(nn.Module):
    """Single dense layer of a residual dense block (RDB).

    Applies conv -> ReLU -> BatchNorm to ``x`` and concatenates the result onto
    ``x`` along the channel axis, so the output has
    ``in_channels + growth_rate`` channels. The padding ``(kernel_size - 1) // 2``
    keeps H and W unchanged for odd ``kernel_size``.

    Args:
        in_channels: channels of the incoming tensor.
        growth_rate: number of new channels this layer appends.
        kernel_size: convolution kernel size.
    """

    def __init__(self, in_channels, growth_rate, kernel_size=3):
        super(MakeDense, self).__init__()
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.norm_layer = nn.BatchNorm2d(growth_rate)

    def forward(self, x):
        out = F.relu(self.conv(x))
        out = self.norm_layer(out)
        # Dense connectivity: keep the input features and append the new ones.
        return torch.cat((x, out), 1)
# Fragment from an RDB constructor: chain `num_dense_layer` MakeDense layers
# into a residual dense block. Each layer concatenates its output onto its
# input, so the next layer's input width grows by `growth_rate`.
# NOTE(review): `num_dense_layer`, `modules` (a list), `_in_channels`,
# `growth_rate` and `self` come from an enclosing __init__ not shown here.
for _ in range(num_dense_layer):
    modules.append(MakeDense(_in_channels, growth_rate))
    _in_channels += growth_rate
self.residual_dense_layers = nn.Sequential(*modules)

1.2 全局上下文块 (GCB)

class ContextBlock2d(nn.Module):
    """Global context block (GCB): pool a global context vector and fuse it into ``x``.

    The output has the same shape as the input ``(N, inplanes, H, W)``.

    Args:
        inplanes: number of input (and output) channels.
        planes: together with ``ratio`` sets the bottleneck width of the
            fusion transforms (``planes // ratio`` channels).
        pool: ``'att'`` -- attention pooling: a 1x1 conv produces one logit
            per pixel, a softmax runs over all H*W positions, and the context
            is the softmax-weighted sum of the features; ``'avg'`` -- global
            average pooling.
        fusions: non-empty subset of ``'channel_add'`` (add a transformed
            context to every pixel) and ``'channel_mul'`` (scale every pixel
            by a sigmoid gate computed from the context). ``None`` (the
            default) means ``['channel_add']``.
        ratio: bottleneck reduction factor.
    """

    def __init__(self, inplanes=9, planes=32, pool='att', fusions=None, ratio=4):
        super(ContextBlock2d, self).__init__()
        # A mutable list default would be one object shared by every call;
        # build a fresh list per instance instead.
        if fusions is None:
            fusions = ['channel_add']
        assert pool in ['avg', 'att']
        assert all([f in ['channel_add', 'channel_mul'] for f in fusions])
        assert len(fusions) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.planes = planes
        self.pool = pool
        self.fusions = fusions
        if pool == 'att':
            # Context modeling: one attention logit per spatial position.
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            # Applied to (N, 1, H*W) logits: normalizes over positions.
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusions:
            self.channel_add_conv = self._make_transform(ratio)
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusions:
            self.channel_mul_conv = self._make_transform(ratio)
        else:
            self.channel_mul_conv = None

    def _make_transform(self, ratio):
        """Build the bottleneck 1x1 conv -> LayerNorm -> PReLU -> 1x1 conv back to ``inplanes``."""
        hidden = self.planes // ratio
        return nn.Sequential(
            nn.Conv2d(self.inplanes, hidden, kernel_size=1),
            nn.LayerNorm([hidden, 1, 1]),
            nn.PReLU(),
            nn.Conv2d(hidden, self.inplanes, kernel_size=1)
        )

    def spatial_pool(self, x):
        """Return the ``(N, C, 1, 1)`` global context vector of ``x``."""
        batch, channel, height, width = x.size()
        if self.pool == 'att':
            # reshape (not view) so non-contiguous inputs are accepted.
            # [N, C, H * W] -> [N, 1, C, H * W]
            input_x = x.reshape(batch, channel, height * width).unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W], softmax over spatial positions
            context_mask = self.softmax(context_mask.reshape(batch, 1, height * width))
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(3)
            # [N, 1, C, 1]: attention-weighted sum over positions
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.reshape(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1] gate in (0, 1), broadcast over H and W
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = x * channel_mul_term
        else:
            out = x
        if self.channel_add_conv is not None:
            # [N, C, 1, 1], broadcast-added to every position
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        return out

2.离散小波变换(DWT)
2.1离散小波变换
DWT 的本质是将输入特征图分解为低频和高频分量（每个子带的空间尺寸减半）。下面分别给出用于下采样的离散小波变换（DWT）和用于上采样的离散小波逆变换（IDWT）的实现。

def dwt_init(x):
    """One-level 2-D Haar wavelet decomposition of ``x`` (N, C, H, W).

    Each 2x2 pixel block is halved and combined into sums/differences,
    giving four (N, C, H/2, W/2) sub-bands (H and W must be even).

    Returns:
        ``(x_LL, bands)``: the low-low band alone, and all four bands
        ``[LL, HL, LH, HH]`` concatenated along the channel axis (4 * C channels).
    """
    even_rows = x[:, :, 0::2, :] / 2
    odd_rows = x[:, :, 1::2, :] / 2
    ee = even_rows[:, :, :, 0::2]  # even row, even column
    oe = odd_rows[:, :, :, 0::2]   # odd row, even column
    eo = even_rows[:, :, :, 1::2]  # even row, odd column
    oo = odd_rows[:, :, :, 1::2]   # odd row, odd column

    low_low = ee + oe + eo + oo
    bands = (
        low_low,
        -ee - oe + eo + oo,
        -ee + oe - eo + oo,
        ee - oe - eo + oo,
    )
    return low_low, torch.cat(bands, 1)
class DWT(nn.Module):
    """Parameter-free module wrapper around ``dwt_init``.

    ``forward`` returns the same ``(x_LL, stacked_bands)`` pair as ``dwt_init``.
    """

    def __init__(self):
        super().__init__()
        # Stored as a plain attribute only; the module registers no
        # parameters, so there is nothing for it to freeze.
        self.requires_grad = False

    def forward(self, x):
        result = dwt_init(x)
        return result

2.1离散小波逆变换 Inverse discrete wavelet transform (IDWT)

def iwt_init(x):
    """Inverse of ``dwt_init``: rebuild a 2x up-sampled signal from stacked sub-bands.

    ``x`` is (N, 4*C, H, W) holding the ``[LL, HL, LH, HH]`` bands along the
    channel axis (the second output of ``dwt_init``); the result is
    (N, C, 2H, 2W). If the channel count is not a multiple of 4, only the
    first ``4 * (channels // 4)`` channels are used.

    The output is floating point with the dtype of ``x / 2`` (so float64 and
    float16 inputs keep their precision) and lives on ``x``'s device.
    """
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    out_channel = in_channel // (r ** 2)
    out_height, out_width = r * in_height, r * in_width
    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
    # Allocate directly on the input's device with the sub-bands' dtype
    # instead of forcing float32 on the CPU and copying.
    h = torch.zeros(
        (in_batch, out_channel, out_height, out_width),
        dtype=x1.dtype,
        device=x.device,
    )

    # Each output 2x2 block is a signed combination of the four bands.
    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
    return h
    
class IWT(nn.Module):
    """Parameter-free module wrapper around ``iwt_init``.

    ``forward`` maps stacked wavelet sub-bands (N, 4*C, H, W) back to an
    (N, C, 2H, 2W) tensor, exactly as ``iwt_init`` does.
    """

    def __init__(self):
        super().__init__()
        # Plain attribute only; this module registers no parameters.
        self.requires_grad = False

    def forward(self, x):
        reconstructed = iwt_init(x)
        return reconstructed

2.2 Residual Wavelet Down-sampling Block

class GCWTResDown(nn.Module):
    """Residual wavelet down-sampling block.

    Two parallel paths halve the spatial size of ``x`` (N, C, H, W), with
    ``C = in_channels``:

    * ``stem``: 3x3 stride-2 conv -> [norm] -> PReLU -> 3x3 conv -> [norm] -> PReLU;
    * wavelet path: the LL band from ``self.dwt`` followed by a 1x1 conv.

    ``forward`` returns ``(out, dwt)``. ``out`` concatenates the two paths
    (2 * C channels). ``dwt`` is the second output of ``self.dwt`` (all four
    sub-bands stacked, 4 * C channels). The paths agree in size for even H and W.

    Args:
        in_channels: channel count C of ``x``.
        att_block: accepted for interface compatibility; not used.
        norm_layer: normalization class applied after each stem conv; a
            falsy value omits normalization.

    ``conv_down`` is constructed, so its parameters are registered, but it is
    not used in ``forward``.
    """

    def __init__(self, in_channels, att_block, norm_layer=nn.BatchNorm2d):
        super(GCWTResDown, self).__init__()
        self.dwt = DWT()

        stem_layers = []
        for stride in (2, 1):
            stem_layers.append(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1)
            )
            if norm_layer:
                stem_layers.append(norm_layer(in_channels))
            stem_layers.append(nn.PReLU())
        self.stem = nn.Sequential(*stem_layers)

        self.conv1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0)
        self.conv_down = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=2)

    def forward(self, x):
        downsampled = self.stem(x)
        low, bands = self.dwt(x)
        skip = self.conv1x1(low)
        return torch.cat([downsampled, skip], dim=1), bands

2.3 Residual Wavelet Up-sampling Block
（原文此处为示意图，图片未保留）

class GCIWTResUp(nn.Module):
    """Residual wavelet up-sampling block.

    ``forward(x, x_dwt)`` merges two 2x up-sampling paths:

    * ``x`` (``in_channels // 2`` channels) -> 1x1 conv to ``in_channels`` ->
      PixelShuffle(2) -> two (3x3 conv [-> norm] -> PReLU) stages at
      ``in_channels // 4`` channels;
    * ``x_dwt`` (``in_channels`` channels of stacked wavelet sub-bands) ->
      1x1 conv -> inverse wavelet transform (``self.iwt``) -> 1x1 conv, giving
      ``in_channels // 4`` channels.

    The concatenation (``in_channels // 2`` channels) is projected by
    ``last_conv`` to ``in_channels // 4`` channels. Presumably ``in_channels``
    is a multiple of 4; TODO confirm with callers.

    Args:
        in_channels: reference channel count (see above).
        att_block: accepted for interface compatibility; not used.
        norm_layer: optional normalization class applied after each stem conv
            (default: none).
    """

    def __init__(self, in_channels, att_block, norm_layer=None):
        super(GCIWTResUp, self).__init__()
        quarter = in_channels // 4

        stem_layers = [nn.PixelShuffle(2)]
        for _ in range(2):
            stem_layers.append(nn.Conv2d(quarter, quarter, kernel_size=3, padding=1))
            if norm_layer:
                stem_layers.append(norm_layer(quarter))
            stem_layers.append(nn.PReLU())
        self.stem = nn.Sequential(*stem_layers)

        self.pre_conv_stem = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1, padding=0)
        self.pre_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0)
        self.post_conv = nn.Conv2d(quarter, quarter, kernel_size=1, padding=0)
        self.iwt = IWT()
        self.last_conv = nn.Conv2d(in_channels // 2, quarter, kernel_size=1, padding=0)

    def forward(self, x, x_dwt):
        upsampled = self.stem(self.pre_conv_stem(x))
        wavelet = self.post_conv(self.iwt(self.pre_conv(x_dwt)))
        merged = torch.cat((upsampled, wavelet), dim=1)
        return self.last_conv(merged)

CycleISP: Real Image Restoration via Improved Data Synthesis
3.RRG: Recursive Residual Group
（原文此处为示意图，图片未保留）
递归残差组 (RRG) 包含多个双重注意力块 (DAB)，每个 DAB 都包含空间注意力和通道注意力两个模块。
3.1 Channel Attention

class CALayer(nn.Module):
    """Channel attention: rescale each channel of ``x`` by a learned gate in (0, 1).

    The gate is computed from the globally average-pooled input through a
    1x1 conv bottleneck (``channel // reduction`` channels) -> ReLU -> 1x1 conv
    back to ``channel`` -> sigmoid, giving one weight per (sample, channel).

    Args:
        channel: number of input channels.
        reduction: bottleneck reduction factor.
    """

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # (N, C, H, W) -> (N, C, 1, 1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, squeezed, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        weights = self.conv_du(self.avg_pool(x))
        return x * weights  # (N, C, 1, 1) gate broadcasts over H and W

3.2 Spatial Attention

class BasicConv(nn.Module):
    """Conv2d, optionally followed by BatchNorm2d and ReLU.

    BatchNorm (``eps=1e-5``, ``momentum=0.01``, affine) is added only when
    ``bn`` is true. ReLU is added when ``relu`` is true (the default). The
    conv has no bias unless ``bias=True``. ``out_channels`` records
    ``out_planes``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=False, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        out = self.conv(x)
        # Apply whichever optional post-processing stages were configured.
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out

class ChannelPool(nn.Module):
    """Compress channels into two maps: per-pixel channel max (first) and channel mean (second)."""

    def forward(self, x):
        max_map = x.max(dim=1, keepdim=True).values
        mean_map = x.mean(dim=1, keepdim=True)
        return torch.cat((max_map, mean_map), dim=1)

class spatial_attn_layer(nn.Module):
    """Spatial attention: gate every pixel by sigmoid(conv([channel max, channel mean])).

    ``ChannelPool`` reduces ``x`` to two channels, a bias-free, ReLU-free
    ``BasicConv`` maps them to one channel with "same" padding (for odd
    ``kernel_size``), and the sigmoid of that map scales ``x``.
    """

    def __init__(self, kernel_size=3):
        super(spatial_attn_layer, self).__init__()
        self.compress = ChannelPool()
        same_pad = (kernel_size - 1) // 2
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=same_pad, relu=False)

    def forward(self, x):
        gate = torch.sigmoid(self.spatial(self.compress(x)))
        return x * gate  # one-channel gate broadcasts over all channels

3.3 双重注意力块 (DAB)

class DAB(nn.Module):
    """Dual attention block (DAB).

    A two-conv body (conv [-> BN] -> act -> conv [-> BN]) feeds a spatial
    attention branch and a channel attention branch in parallel. Their
    outputs are concatenated (2 * n_feat channels), fused back to ``n_feat``
    channels by a 1x1 conv, and added to the block input.

    Args:
        conv: conv factory, called as ``conv(n_feat, n_feat, kernel_size, bias=bias)``.
        n_feat: feature channels (input and output).
        kernel_size: kernel size passed to ``conv``.
        reduction: channel-attention reduction factor.
        bias: forwarded to ``conv``.
        bn: if true, a BatchNorm2d follows each body conv.
        act: activation module inserted after the first body conv. The
            default instance is created once and shared by every DAB built
            with the default.
    """

    def __init__(
        self, conv, n_feat, kernel_size, reduction,
        bias=True, bn=False, act=nn.ReLU(True)):

        super(DAB, self).__init__()
        body = []
        for is_first in (True, False):
            body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                body.append(nn.BatchNorm2d(n_feat))
            if is_first:
                body.append(act)

        self.SA = spatial_attn_layer()        # spatial attention branch
        self.CA = CALayer(n_feat, reduction)  # channel attention branch
        self.body = nn.Sequential(*body)
        self.conv1x1 = nn.Conv2d(n_feat * 2, n_feat, kernel_size=1)

    def forward(self, x):
        features = self.body(x)
        branches = [self.SA(features), self.CA(features)]
        fused = self.conv1x1(torch.cat(branches, dim=1))
        fused += x
        return fused

3.4 Recursive Residual Group

class RRG(nn.Module):
    """Recursive residual group: ``num_dab`` dual attention blocks and a trailing conv, plus a skip connection.

    Every DAB receives the same ``act`` module instance. The trailing layer is
    ``conv(n_feat, n_feat, kernel_size)``, and the group input is added to
    the body output.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction, act,  num_dab):
        super(RRG, self).__init__()
        blocks = []
        for _ in range(num_dab):
            blocks.append(
                DAB(conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=act)
            )
        blocks.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        out = self.body(x)
        out += x
        return out

3.5 Color Correction

class CCM(nn.Module):
    """Color-correction feature network (CycleISP).

    ``forward`` reflect-pads and Gaussian-blurs the input, applies a conv head
    (3 -> 96 channels), 2x2 max-pool down-sampling, ``num_rrg`` recursive
    residual groups, a conv + activation, and finally a conv tail (96 -> 96
    channels) followed by a sigmoid.

    NOTE(review): the default ``conv`` factory and ``get_gaussian_kernel``
    are module-level names not defined in this snippet. ``conv`` is evaluated
    when the class is defined, so it must already exist at that point.
    """
    def __init__(self,  conv=conv):
        super(CCM, self).__init__()
        input_nc  = 3    # channels expected by the head conv
        output_nc = 96   # channels produced by the tail conv

        num_rrg = 2      # number of recursive residual groups in the body
        num_dab = 2      # dual attention blocks per RRG
        n_feats = 96     # internal feature width
        kernel_size = 3
        reduction = 8    # channel-attention reduction passed down to each DAB

        sigma = 12 ## GAUSSIAN_SIGMA, passed to get_gaussian_kernel

        # A single PReLU instance is reused inside every RRG/DAB and after the
        # body conv, so its learnable slopes are shared across all of them.
        act =nn.PReLU(n_feats)
        modules_head = [conv(input_nc, n_feats, kernel_size = kernel_size, stride = 1)]
        # MaxPool2d's stride defaults to kernel_size, so this halves H and W.
        modules_downsample = [nn.MaxPool2d(kernel_size=2)]
        self.downsample = nn.Sequential(*modules_downsample)
        modules_body = [
            RRG(
                conv, n_feats, kernel_size, reduction, act=act, num_dab=num_dab) \
            for _ in range(num_rrg)]

        modules_body.append(conv(n_feats, n_feats, kernel_size))
        modules_body.append(act)

        # Sigmoid keeps the output features in (0, 1).
        modules_tail = [conv(n_feats, output_nc, kernel_size),nn.Sigmoid()]

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)
        # get_gaussian_kernel returns (blur module, pad amount), used in forward.
        self.blur, self.pad = get_gaussian_kernel(sigma=sigma)
    def forward(self, x):
        # Reflect-pad every side by self.pad before blurring; presumably this
        # lets the (unpadded?) blur keep the input's H and W -- TODO confirm
        # against get_gaussian_kernel.
        x = F.pad(x, (self.pad, self.pad, self.pad, self.pad), mode='reflect')
        x = self.blur(x)
        x = self.head(x)
        x = self.downsample(x)
        x = self.body(x)
        x = self.tail(x)
        return x
  • 2
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值