即插即用的涨点模块之特征融合(DFM)详解及代码,可应用于检测、分割、分类等各种算法领域

目录

前言

一、DFM网络讲解

二、DFM参数

三、代码讲解  


前言

FCCDN: Feature constraint network for VHR image change detection

来源:ISPRS2022

官方代码:https://github.com/chenpan0615/FCCDN_pytorch

         Dense fusion module(DFM)由作者为孪生网络的变化检测所提出。双时态特征融合是变化检测的一个关键部分。这一任务难以处理有两个原因:(1)输入到孪生网络中的双时相图像在空间位置和颜色上往往存在偏差。(2)背景对象复杂多变。传统方法使用直接减法或连接来融合特征。不幸的是,尽管孪生网络通过双时像提取特征,但双时像特征之间仍存在很多不对齐的问题。也有许多研究者尝试用注意力机制来解决这个问题。然而,大多数现有的基于注意力的特征模块都会引入大量计算并消耗相当多的内存。所以提出了一个基于密集连接的简单而有效的特征融合模块,我们将这个模块命名为DFM。笔者认为该模块不仅可用于变化检测,可也用于目标检测、分割等领域中图像的特征融合。


一、DFM网络讲解

如图1DFM结构所示,DFM包括两个分支,和分支(Sum)和差分支(Diff)。和分支用于增强边缘信息,差分支用于生成变化特征。每个分支都由两个具有权重共享的密集连接流构建,所有的卷积操作都使用3×3的卷积核。构建DFM的目的是为了在每个流中集成多个特征,从而做出更好的决策。这种结构可以增加模型的鲁棒性,防止由于特征不对齐而导致的伪变化。此外,由于密集连接中丰富的残差连接,每个流中的最后两个特征可以视为前一个特征的残差,从某种程度上是前一个特征的校正,使得新的特征图更加对齐。作者通过在图2中可视化Diff的特征来验证这个模块。从特征图的可视化中,我们可以看到DFM确实减少了特征不对齐,并计算出了更准确的变化特征。

图1 DFM结构

图2 DFM可视化

讲解:

和分支(Sum Branch):通过加法操作融合来自两个时间点的特征。这种方式有助于增强图像中的边缘信息,因为变化区域往往在边缘上更为明显。

差分支(Diff Branch):通过差分操作生成反映变化的特征。这是直接反映两个时间点之间差异的方式,特别适用于变化检测。

这两个分支都采用了密集连接(Dense Connectivity)的策略,即每层的输出不仅连接到下一层,还与之前所有层的输出相连接,这样的结构有助于保留丰富的历史信息,并增强特征的表达能力。通过权重共享,减少了模型参数和避免了过拟合。

此外,DFM通过其密集连接还引入了一种自然的残差学习机制,后一层特征可以看作是对前一层特征的校正,从而在不断迭代中逐渐对齐和精化特征表示,最终产生更准确的变化检测结果。

二、DFM参数

利用thop库的profile函数计算FLOPs和Param。Input:(64,32,32)(64,32,32)。

ModuleFLOPsParam
DFM16043212899168

三、代码讲解  

import torch
import torch.nn as nn


class densecat_cat_add(nn.Module):
    def __init__(self, in_chn, out_chn):
        super(densecat_cat_add, self).__init__()

        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_chn, in_chn, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(in_chn, in_chn, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(in_chn, in_chn, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        )
        self.conv_out = torch.nn.Sequential(
            torch.nn.Conv2d(in_chn, out_chn, kernel_size=1, padding=0),
            torch.nn.BatchNorm2d(out_chn),
            torch.nn.ReLU(inplace=True),
        )

    def forward(self, x, y):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2+x1)

        y1 = self.conv1(y)
        y2 = self.conv2(y1)
        y3 = self.conv3(y2+y1)

        return self.conv_out(x1 + x2 + x3 + y1 + y2 + y3)

class densecat_cat_diff(nn.Module):
    def __init__(self, in_chn, out_chn):
        super(densecat_cat_diff, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_chn, in_chn, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(in_chn, in_chn, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(in_chn, in_chn, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
        )
        self.conv_out = torch.nn.Sequential(
            torch.nn.Conv2d(in_chn, out_chn, kernel_size=1, padding=0),
            torch.nn.BatchNorm2d(out_chn),
            torch.nn.ReLU(inplace=True),
        )

    def forward(self, x, y):

        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2+x1)

        y1 = self.conv1(y)
        y2 = self.conv2(y1)
        y3 = self.conv3(y2+y1)
        out = self.conv_out(torch.abs(x1 + x2 + x3 - y1 - y2 - y3))
        return out


class DF_Module(nn.Module):
    def __init__(self, dim_in, dim_out, reduction=True):
        super(DF_Module, self).__init__()
        if reduction:
            self.reduction = torch.nn.Sequential(
                torch.nn.Conv2d(dim_in, dim_in//2, kernel_size=1, padding=0),
                nn.BatchNorm2d(dim_in//2),
                torch.nn.ReLU(inplace=True),
            )
            dim_in = dim_in//2
        else:
            self.reduction = None
        self.cat1 = densecat_cat_add(dim_in, dim_out)
        self.cat2 = densecat_cat_diff(dim_in, dim_out)
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(dim_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x1, x2):
        if self.reduction is not None:
            x1 = self.reduction(x1)
            x2 = self.reduction(x2)
        x_add = self.cat1(x1, x2)
        x_diff = self.cat2(x1, x2)
        y = self.conv1(x_diff) + x_add
        return y

if __name__ =='__main__':
    from thop import profile
    model = DF_Module(64, 64, True)
    x1 = torch.randn(1, 64, 32, 32)
    x2 = torch.randn(1, 64, 32, 32)
    y = model(x1, x2)
    flops, params = profile(model, inputs=(x1, x2))
    print(f"FLOPs: {flops}, Params: {params}")

  • 16
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值