【第六期论文复现赛-变化检测】SNUNet-CD

本文档介绍了SNUNet-CD模型,一种用于高分辨率图像变化检测的深度学习网络,结合了Siamese Network和Nested UNet。模型通过密集连接保持高分辨率特征,并利用通道注意力模块融合不同层次信息。在CDD测试集上,SNUNet-CD实现了95.54%的F1-Score。此外,提供了训练、验证和预测的详细步骤,以及模型导出和TIPC测试流程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

转自AI Studio,原文链接:

【第六期论文复现赛-变化检测】SNUNet-CD - 飞桨AI Studio

 

第六期论文复现赛-变化检测】A Densely Connected Siamese Network for Change Detection of VHR Images

一、前言介绍

论文简介

SNUNET-CD的结构如下图所示。作者提出了一种用于变化检测的稠密链接网络,即SNUNet-CD (siamese network和NestedUNet的组合),受DenseNet和NestedUNet的启发,设计了一个密集的连接的连体网络用于变更检测。通过编码器和解码器之间、解码器和解码器之间的密集跳过连接,它可以保持高分辨率、细粒度的表示。提出了集成通道关注模块(ECAM)的深度监控方法。通过ECAM,可以细化不同语义层次的最具代表性的特征,并用于最终的分类。

二、网络结构

:这里有关网络的分析,参考CSDN博客:狗都能看懂的变化检测网络Siam-NestedUNet讲解——解决工业检测的痛点

  • 模型的整体架构已在上图展示,该网络是典型的encoder–decoder结构,可以分为三大部分进行拆解:
    • 网络的backbone,类似于UNet++
    • 提取两幅图像差异信息的孪生网络结构
    • 网络最后为了加强不同级别输出的信息的ECAM模块

2.1 Backbone

  • 可以看到模型的主干由UNet++衍生而来。我们不看双输入部分,只看backbone,从x1,1x^{1,1}x1,1 下采样到x4,0x^{4,0}x4,0,然后再通过上采样到x0,4x^{0,4}x0,4这一部分,它是呈现一个U型结构,和UNet类似,是经典的图像分割中非常经典的Encoder-Decoder结构。
  • 同时在两层卷积中使用了类似残差网络的连接,这是参考DenseNet采用密集残差边,可以解决两个问题:
    • 梯度回传时,浅层网络难以优化的问题
    • 加强特征融合,使得深层网络可以结合浅层网络的特征,同时融合了低层的细节信息和高层的语义信息,增大了低层的感受野,使得低层在做小目标检测时能获得更多上下文信息

2.2 Siamese Network-孪生网络

  • 如上图所示,孪生网络有两个输入,其诞生的初衷是为了解决小数据集泛化性差的问题。一个输入对应一个网络,最终会的得到两个输出,这两个输出对应这两个输入的高维特征,对其简单做差可近似看为二者的loss,loss越小代表差异越小,loss越大代表差异越大。通常情况下,两个输入的网络权重是共享的
  • SNUNet中,在输入图片时,将两个时相的图片分别进行encode,并只要在跳跃连接时将两组特征concat起来,再进行相对应的decode,得到该级别的输出

上述两个模块的paddle源代码如下所示

class SNUNet(nn.Layer, KaimingInitMixin):
    """
    Args:
        in_channels (int): The number of bands of the input images.
        num_classes (int): The number of target classes.
        width (int, optional): The output channels of the first convolutional layer. Default: 32.
    """

    def __init__(self, in_channels, num_classes, width=32):
        super(SNUNet, self).__init__()

        filters = (width, width * 2, width * 4, width * 8, width * 16)

        self.conv0_0 = ConvBlockNested(in_channels, filters[0], filters[0])
        self.conv1_0 = ConvBlockNested(filters[0], filters[1], filters[1])
        self.conv2_0 = ConvBlockNested(filters[1], filters[2], filters[2])
        self.conv3_0 = ConvBlockNested(filters[2], filters[3], filters[3])
        self.conv4_0 = ConvBlockNested(filters[3], filters[4], filters[4])
        self.down1 = MaxPool2x2()
        self.down2 = MaxPool2x2()
        self.down3 = MaxPool2x2()
        self.down4 = MaxPool2x2()
        self.up1_0 = Up(filters[1])
        self.up2_0 = Up(filters[2])
        self.up3_0 = Up(filters[3])
        self.up4_0 = Up(filters[4])

        self.conv0_1 = ConvBlockNested(filters[0] * 2 + filters[1], filters[0],
                                       filters[0])
        self.conv1_1 = ConvBlockNested(filters[1] * 2 + filters[2], filters[1],
                                       filters[1])
        self.conv2_1 = ConvBlockNested(filters[2] * 2 + filters[3], filters[2],
                                       filters[2])
        self.conv3_1 = ConvBlockNested(filters[3] * 2 + filters[4], filters[3],
                                       filters[3])
        self.up1_1 = Up(filters[1])
        self.up2_1 = Up(filters[2])
        self.up3_1 = Up(filters[3])

        self.conv0_2 = ConvBlockNested(filters[0] * 3 + filters[1], filters[0],
                                       filters[0])
        self.conv1_2 = ConvBlockNested(filters[1] * 3 + filters[2], filters[1],
                                       filters[1])
        self.conv2_2 = ConvBlockNested(filters[2] * 3 + filters[3], filters[2],
                                       filters[2])
        self.up1_2 = Up(filters[1])
        self.up2_2 = Up(filters[2])

        self.conv0_3 = ConvBlockNested(filters[0] * 4 + filters[1], filters[0],
                                       filters[0])
        self.conv1_3 = ConvBlockNested(filters[1] * 4 + filters[2], filters[1],
                                       filters[1])
        self.up1_3 = Up(filters[1])

        self.conv0_4 = ConvBlockNested(filters[0] * 5 + filters[1], filters[0],
                                       filters[0])

        self.ca_intra = ChannelAttention(filters[0], ratio=4)
        self.ca_inter = ChannelAttention(filters[0] * 4, ratio=16)

        self.conv_out = Conv1x1(filters[0] * 4, num_classes)

        self.init_weight()

    def forward(self, t1, t2):
        x0_0_t1 = self.conv0_0(t1)
        x1_0_t1 = self.conv1_0(self.down1(x0_0_t1))
        x2_0_t1 = self.conv2_0(self.down2(x1_0_t1))
        x3_0_t1 = self.conv3_0(self.down3(x2_0_t1))

        x0_0_t2 = self.conv0_0(t2)
        x1_0_t2 = self.conv1_0(self.down1(x0_0_t2))
        x2_0_t2 = self.conv2_0(self.down2(x1_0_t2))
        x3_0_t2 = self.conv3_0(self.down3(x2_0_t2))
        x4_0_t2 = self.conv4_0(self.down4(x3_0_t2))

        x0_1 = self.conv0_1(
            paddle.concat([x0_0_t1, x0_0_t2, self.up1_0(x1_0_t2)], 1))
        x1_1 = self.conv1_1(
            paddle.concat([x1_0_t1, x1_0_t2, self.up2_0(x2_0_t2)], 1))
        x0_2 = self.conv0_2(
            paddle.concat([x0_0_t1, x0_0_t2, x0_1, self.up1_1(x1_1)], 1))

        x2_1 = self.conv2_1(
            paddle.concat([x2_0_t1, x2_0_t2, self.up3_0(x3_0_t2)], 1))
        x1_2 = self.conv1_2(
            paddle.concat([x1_0_t1, x1_0_t2, x1_1, self.up2_1(x2_1)], 1))
        x0_3 = self.conv0_3(
            paddle.concat([x0_0_t1, x0_0_t2, x0_1, x0_2, self.up1_2(x1_2)], 1))

        x3_1 = self.conv3_1(
            paddle.concat([x3_0_t1, x3_0_t2, self.up4_0(x4_0_t2)], 1))
        x2_2 = self.conv2_2(
            paddle.concat([x2_0_t1, x2_0_t2, x2_1, self.up3_1(x3_1)], 1))
        x1_3 = self.conv1_3(
            paddle.concat([x1_0_t1, x1_0_t2, x1_1, x1_2, self.up2_2(x2_2)], 1))
        x0_4 = self.conv0_4(
            paddle.concat(
                [x0_0_t1, x0_0_t2, x0_1, x0_2, x0_3, self.up1_3(x1_3)], 1))

        out = paddle.concat([x0_1, x0_2, x0_3, x0_4], 1)

        intra = paddle.sum(paddle.stack([x0_1, x0_2, x0_3, x0_4]), axis=0)
        m_intra = self.ca_intra(intra)
        out = self.ca_inter(out) * (out + paddle.tile(m_intra, (1, 4, 1, 1)))

        pred = self.conv_out(out)
        return [pred]

2.3 Ensemble Channel Attention Module- 集成通道注意力模块

  • 在经过SNUNet的encode-decode之后,最终获得4个和原图大小相同的输出。虽然大小一样,但是不同的输出之间语义层次和空间位置的表达也不相同。
  • 浅层的输出具有准确的空间位置信息,而深层的输出具有更细致的语义信息,因此融合这些特征时需要考虑不同层次输出之间语义信息空间位置的差异,SNUNet采用ECAM模块来进行融合
  • ECAM模块的具体结构在整体结构图中的(b)部分,实现的逻辑可以看成是一个残差块 + 两个通道注意力模块构成,上面贴的代码已经将ECAM大致逻辑写好,下面贴通道注意力机制的代码
class ChannelAttention(nn.Layer):
    """
    Args:
        in_ch (int): The number of channels of the input features.
        ratio (int, optional): The channel reduction ratio. Default: 8.
    """

    def __init__(self, in_ch, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.max_pool = nn.AdaptiveMaxPool2D(1)
        self.fc1 = Conv1x1(in_ch, in_ch // ratio, bias=False, act=True)
        self.fc2 = Conv1x1(in_ch // ratio, in_ch, bias=False)

    def forward(self, x):
        avg_out = self.fc2(self.fc1(self.avg_pool(x)))
        max_out = self.fc2(self.fc1(self.max_pool(x)))
        out = avg_out + max_out
        return F.sigmoid(out)

三、复现精度

在CDD的测试集的测试效果如下表,达到验收指标,F1-Score=95.3%

Networkoptepochbatch_sizedatasetF1-ScoremIOU
SNUNET-32AdamW10016CDD95.54%95.12%

注意:验收评估的模型为SNUNet-32

四、环境与数据准备

  • 克隆仓库

In [1]

!git clone https://github.com/kongdebug/SNUNet-Paddle.git
正克隆到 'SNUNet-Paddle'...
remote: Enumerating objects: 1027, done.
remote: Counting objects: 100% (1027/1027), done.
remote: Compressing objects: 100% (804/804), done.
remote: Total 1027 (delta 214), reused 960 (delta 189), pack-reused 0
接收对象中: 100% (1027/1027), 12.22 MiB | 7.44 MiB/s, 完成.
处理 delta 中: 100% (214/214), 完成.
检查连接... 完成。
  • 解压数据,并进行处理

In [2]

# 解压数据
!unzip -qo data/data29275/CDData.zip -d ./work/

In [ ]

# 安装相应依赖
%cd SNUNet-Paddle/
!pip install -r requirements.txt

In [4]

# 生成模型训练需要的.txt文件
!python ./data/process_cdd_data.py --data_dir=../work/Real/subset
数据集划分已完成。

五、快速体验

  • 模型训练
    • 注意:由于SNUNET-CD在训练时没有使用Normalize处理,所以可能会导致前几个epoch的loss比较大,在第10到第16个epoch时可正常

In [ ]

!python tutorials/train/snunet.py --data_dir=../work/Real/subset --out_dir=./output/snunet/
  • 模型验证
    • 最优模型权重已放入work/output/snunet/best_model文件夹下
    • 可将--weight_path参数替换为自己训练出的模型权重路径

In [10]

!python tutorials/eval/snunet_eval.py --data_dir=../work/Real/subset \
                                      --weight_path=../work/output/snunet/best_model/model.pdparams
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
[04-26 02:26:41 MainThread @logger.py:242] Argv: tutorials/eval/snunet_eval.py --data_dir=../work/Real/subset --weight_path=../work/output/snunet/best_model/model.pdparams
[04-26 02:26:41 MainThread @utils.py:79] WRN paddlepaddle version: 2.2.2. The dynamic graph version of PARL is under development, not fully tested and supported
2022-04-26 02:26:42 [INFO]	10000 samples in file ../work/Real/subset/train.txt
2022-04-26 02:26:42 [INFO]	3000 samples in file ../work/Real/subset/test.txt
W0426 02:26:42.497771  4362 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0426 02:26:42.502462  4362 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2022-04-26 02:26:45 [INFO]	Loading pretrained model from ../work/output/snunet/best_model/model.pdparams
2022-04-26 02:26:45 [INFO]	There are 186/186 variables loaded into SNUNet.
2022-04-26 02:26:45 [INFO]	Start to evaluate(total_samples=3000, total_steps=3000)...
OrderedDict([('miou', 0.9511789327930941), ('category_iou', array([0.98762963, 0.91472823])), ('oacc', 0.989078862508138), ('category_acc', array([0.99284724, 0.96190638])), ('kappa', 0.9492419817634077), ('category_F1-score', array([0.99377632, 0.95546534]))])
  • 模型预测
    • 使用最优模型权重对模型进行预测
    • 参数介绍:
      • weight 训练好的权重
      • A,B, 是T1影像路径,T2影像路径
      • pre 预测图片存储的位置

In [11]

!python tutorials/predict/snunet_pred.py --weight=../work/output/snunet/best_model/model.pdparams \
                                         --A=../work/Real/subset/test/A/00002.jpg --B=../work/Real/subset/test/B/00002.jpg \
                                         --pre=../work/pre.png
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
[04-26 02:30:00 MainThread @logger.py:242] Argv: tutorials/predict/snunet_pred.py --weight=../work/output/snunet/best_model/model.pdparams --A=../work/Real/subset/test/A/00002.jpg --B=../work/Real/subset/test/B/00002.jpg --pre=../work/pre.png
[04-26 02:30:00 MainThread @utils.py:79] WRN paddlepaddle version: 2.2.2. The dynamic graph version of PARL is under development, not fully tested and supported
W0426 02:30:00.780316  4623 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0426 02:30:00.785306  4623 device_context.cc:465] device: 0, cuDNN Version: 7.6.
finish!

In [13]

# 展示预测的结果,最后一张为真值
import matplotlib.pyplot as plt
from PIL import Image

T1 = Image.open(r"../work/Real/subset/test/A/00002.jpg")
T2 = Image.open(r"../work/Real/subset/test/B/00002.jpg")
GT = Image.open(r"../work/Real/subset/test/OUT/00002.jpg")
pred = Image.open(r"../work/pre.png")

plt.figure(figsize=(16, 8))
plt.subplot(1,4,1), plt.title('T1')
plt.imshow(T1), plt.axis('off')
plt.subplot(1,4,2), plt.title('T2') 
plt.imshow(T2), plt.axis('off')
plt.subplot(1,4,3), plt.title('pred') 
plt.imshow(pred), plt.axis('off')
plt.subplot(1,4,4), plt.title('GT') 
plt.imshow(GT), plt.axis('off')
plt.show()

<Figure size 1152x576 with 4 Axes>
  • SNUNet模型导出

In [14]

!python deploy/export/export_model.py --model_dir=../work/output/snunet/best_model/ \
                                      --save_dir=./inference_model/ 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. 
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
[04-26 02:34:02 MainThread @logger.py:242] Argv: deploy/export/export_model.py --model_dir=../work/output/snunet/best_model/ --save_dir=./inference_model/
[04-26 02:34:02 MainThread @utils.py:79] WRN paddlepaddle version: 2.2.2. The dynamic graph version of PARL is under development, not fully tested and supported
W0426 02:34:02.990654  4975 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0426 02:34:02.995680  4975 device_context.cc:465] device: 0, cuDNN Version: 7.6.
2022-04-26 02:34:05 [INFO]	Model[SNUNet] loaded.
2022-04-26 02:34:08 [INFO]	The model for the inference deployment is saved in ./inference_model/.

六、TIPC基础链条测试

该部分依赖auto_log,需要进行安装,安装方式如下:

auto_log的详细介绍参考https://github.com/LDOUBLEV/AutoLog

In [ ]

!git clone https://github.com/LDOUBLEV/AutoLog
!pip3 install -r requirements.txt
!python3 setup.py bdist_wheel
!pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
  • 准备数据

In [ ]

!bash ./test_tipc/prepare.sh test_tipc/configs/SNUNET/train_infer_python.txt 'lite_train_lite_infer'
  • 测试

In [ ]

!bash test_tipc/test_train_inference_python.sh test_tipc/configs/SNUNET/train_infer_python.txt 'lite_train_lite_infer'

七、项目总结

  • 本项目对SNUNet进行了简单的介绍,包括总体的模型结构和具体的网络细节,帮助大家更好的理解SNUNet网络
  • 同时本项目给出了SNUNet的Paddle复现仓库的使用方法,可以快速进行训练、评估和预测,以及SNUNet模型的导出
  • 遥感类的复现使用PaddleRS套件可以赢在起跑线上,同时官方给出的《论文复现赛指南》非常有借鉴意义,教程中的每个复现关键节点都已经指出,教会了我如何对一篇论文进行复现,除了本次比赛,对今后的学习也帮助很大

八、致谢

  • 再次由衷的感谢飞桨团队提供的算力支持,感谢RD小姐姐的解答与帮助。也很感激飞桨能够开源很多套件帮助我的学习与科研。同时再次感谢古代飞奔向未来的样子两位的帮助。
  • 再话痨一下,自己当初接触AI studio就是因为要做毕设了,但实验室的卡要排队,在网上看到百度能每天提供免费的算力,果断加入了”白嫖“的行列。后来毕设要用到GAN,苦于如何入门的时候,飞桨恰好推出生成对抗网络七日打卡营的课程。真的很感谢飞桨,希望以后越来越好,谢谢!
### ELGC-Net 模型架构与实现介绍 #### 1. 背景与目标 ELGC-Net 是一种专为遥感变化检测设计的深度学习框架,其核心目的是解决传统 CNN 和基于变换器模型在处理高分辨率卫星图像时遇到的挑战。这些挑战主要包括复杂的背景干扰、局部和全局上下文信息的有效融合以及计算资源的需求限制[^2]。 #### 2. 主要贡献 ELGC-Net 的主要创新点在于提出了 **Efficient Local-Global Context Aggregation (ELGCA)** 方法,这是一种能够有效聚合局部和全局特征的技术。通过这一技术,ELGC-Net 不仅提高了变化检测的准确性,还显著降低了模型的计算复杂度和参数量[^3]。 #### 3. 模型架构详解 ELGC-Net 的整体结构可以分为以下几个部分: ##### (1)输入层 模型接收两幅不同时期的遥感影像作为输入数据。为了便于后续处理,通常会对这两幅图像进行预处理操作,例如标准化、裁剪或增强等。 ##### (2)特征提取模块 此阶段采用轻量化卷积神经网络(如 ResNet 或 MobileNet)来分别提取两张输入图像的空间特征图。这种设计有助于保留更多的细节信息,从而提升最终的变化检测精度。 ##### (3)HAFB 层级注意融合块 这是 ELGC-Net 中的关键组件之一——自研模块 Hierarchical Attention Fusion Block (HAFB)[^1]。它负责将来自两个时间点的不同层次特征图进行深度融合,并利用多尺度注意力机制突出显示潜在发生变化区域的位置及其特性。具体来说,HAFB 可以完成以下任务: - 自动调整权重分配给各个通道上的重要程度; - 增强跨时间段间差异信号的表现力; - 抑制噪声影响并保持边界清晰度。 ```python class HAFB(nn.Module): def __init__(self, in_channels): super(HAFB, self).__init__() # 定义一些必要的子模块... def forward(self, x_t1, x_t2): # 实现具体的前向传播逻辑... pass ``` ##### (4)上下文聚合单元 在此之后,ELGC-Net 进一步引入了 Efficient Local-Global Context Aggregation (ELGCA),用于综合考虑更大范围内的依赖关系。这种方法不仅增强了对于大尺寸对象或者远距离关联的理解能力,同时也大幅削减了不必要的冗余运算开销。 ##### (5)决策分类头 最后,在经过上述一系列处理步骤后得到的结果会被送入到二元交叉熵损失函数指导下的全连接层当中去预测每一个像素属于“未改变”还是“已发生变更”的类别标签。 #### 4. 实验验证与性能表现 根据官方实验结果显示,相比于其他主流算法(例如 FC-EF、SNUNet 等),ELGC-Net 在多个公开基准测试集上均取得了领先的 F1-score 和 IoU 数值指标成绩。特别是在面对含有大量遮挡物场景下依旧保持着稳健可靠的判断水平。 --- ###
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值