论文介绍
题目:
B2CNet: A Progressive Change Boundary-to-Center Refinement Network for Multitemporal Remote Sensing Images Change Detection
论文地址:
https://ieeexplore.ieee.org/document/10547405
创新点
- 设计了B2CNet模型:提出了一种名为B2CNet的边界到中心精细化的变化检测网络,无需额外的边缘提取算法或边界标签,仅通过简单操作和常规标签就能激活边界特征,指导模型达到更好的变化检测效果。
- 引入了变化边界感知模块(CBM):解决了对边界信息利用不足和精细检测边界不充分的问题。CBM模块能够从变化区域提取粗略的边界信息,并结合一般标签中的精确变化区域信息进行细化,从而提升边界的精细度,减少伪变化的影响,并指导其他分支的信息聚合。
- 提出了双时相特征聚合模块(BFAM):通过多个感受野聚合低层次纹理信息与高层次语义信息,增强了变化区域的内部完整性,弥补深层特征中细节信息的丢失问题。
- 引入了深度特征提取模块(DFEM):该模块通过增强各分支之间的特征融合,提取更深层的高层次语义信息,提高了网络的特征提取能力,加强了分支间的特征交互,从而提高了解耦效率。
方法
整体结构
B2CNet模型基于编码器-解码器架构,由变化边界感知模块(CBM)、双时相特征聚合模块(BFAM)和深度特征提取模块(DFEM)组成。CBM负责边界信息的提取和增强,BFAM聚合多尺度的纹理和语义信息以提升区域内部完整性,DFEM进一步融合高层语义信息和边界特征,确保变化区域的精确检测和边界的完整描述,从而实现对变化区域的细致识别与表达。
- 编码器部分:使用了ResNet18作为预训练的特征提取网络,通过输入两张不同时间的遥感图像,得到多层次的特征向量。这些特征向量用于后续模块中的特征聚合和边界检测。
- 变化边界感知模块(CBM):CBM模块接收编码器输出的特征,通过SimAM注意力机制和边缘增强操作提取边界信息。CBM模块增强了边界信息的表达能力,使其在解码过程中可以更清晰地描述变化区域的边界。
- 双时相特征聚合模块(BFAM):BFAM模块用于聚合多感受野的纹理信息和语义信息,通过并行的膨胀卷积捕捉不同尺度的变化区域信息,并通过SimAM注意力机制优化低层特征与高层语义信息的结合,提升变化区域内部的完整性。
- 深度特征提取模块(DFEM):该模块将CBM和BFAM模块的输出特征进行深层次融合,提取更高层次的变化语义信息,并通过残差连接保持特征的完整性。最终,DFEM输出整个模型的最终预测结果。
即插即用模块作用
BFAM 作为一个即插即用模块,主要适用于:
- 多尺度特征聚合:通过聚合不同感受野的特征,BFAM能够捕捉不同尺度的变化,提升模型对变化区域的敏感度。
- 内部一致性增强:通过融合低层次纹理和高层次语义信息,BFAM提高了变化区域的完整性,减少边界模糊和信息丢失。
- 细节信息保留:在特征融合过程中保留细节信息,增强模型对细微变化的辨识能力,提升检测的精确度。
消融实验结果
- 该表评估了不同模块(CBM、BFAM、DFEM和分支监督)的有效性。消融实验逐步替换基础网络中的特征拼接(FC)操作为各模块,结果显示每添加一个模块,模型性能均有所提升。具体来说,加入CBM后在边界特征检测上有显著改善,加入BFAM后提升了变化区域内部的完整性,而加入DFEM后进一步增强了高层语义信息的提取能力。最终,通过在CBM分支上加入监督,进一步提高了变化区域信息的精确度,验证了各模块对模型性能提升的重要性。
即插即用模块代码
#论文:B2CNet: A Progressive Change Boundary-to-Center Refinement Network for Multitemporal Remote Sensing Images Change Detection
#论文地址:https://ieeexplore.ieee.org/document/10547405
import torch
import torch.nn as nn
#Simam: A simple, parameter-free attention module for convolutional neural networks (ICML 2021)
class simam_module(torch.nn.Module):
def __init__(self, e_lambda=1e-4):
super(simam_module, self).__init__()
self.activaton = nn.Sigmoid()
self.e_lambda = e_lambda
def forward(self, x):
b, c, h, w = x.size()
n = w * h - 1
x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
return x * self.activaton(y)
#bitemporal feature aggregation module (BFAM)
class BFAM(nn.Module):
def __init__(self,inp,out):
super(BFAM, self).__init__()
self.pre_siam = simam_module()
self.lat_siam = simam_module()
out_1 = int(inp/2)
self.conv_1 = nn.Conv2d(inp, out_1 , padding=1, kernel_size=3,groups=out_1,
dilation=1)
self.conv_2 = nn.Conv2d(inp, out_1, padding=2, kernel_size=3,groups=out_1,
dilation=2)
self.conv_3 = nn.Conv2d(inp, out_1, padding=3, kernel_size=3,groups=out_1,
dilation=3)
self.conv_4 = nn.Conv2d(inp, out_1, padding=4, kernel_size=3,groups=out_1,
dilation=4)
self.fuse = nn.Sequential(
nn.Conv2d(out_1 * 4, out_1, kernel_size=1, padding=0),
nn.BatchNorm2d(out_1),
nn.ReLU(inplace=True)
)
self.fuse_siam = simam_module()
self.out = nn.Sequential(
nn.Conv2d(out_1, out, kernel_size=3, padding=1),
nn.BatchNorm2d(out),
nn.ReLU(inplace=True)
)
def forward(self,inp1,inp2,last_feature=None):
x = torch.cat([inp1,inp2],dim=1)
c1 = self.conv_1(x)
c2 = self.conv_2(x)
c3 = self.conv_3(x)
c4 = self.conv_4(x)
cat = torch.cat([c1,c2,c3,c4],dim=1)
fuse = self.fuse(cat)
inp1_siam = self.pre_siam(inp1)
inp2_siam = self.lat_siam(inp2)
inp1_mul = torch.mul(inp1_siam,fuse)
inp2_mul = torch.mul(inp2_siam,fuse)
fuse = self.fuse_siam(fuse)
if last_feature is None:
out = self.out(fuse + inp1 + inp2 + inp2_mul + inp1_mul)
else:
out = self.out(fuse+inp2_mul+inp1_mul+last_feature+inp1+inp2)
out = self.fuse_siam(out)
return out
if __name__ == '__main__':
block = BFAM(inp=128, out=256)
inp1 = torch.rand(1, 128 // 2, 16, 16) # B C H W
inp2 = torch.rand(1, 128 // 2, 16, 16)# B C H W
last_feature = torch.rand(1, 128 // 2, 16, 16)# B C H W
# 通过BFAM模块,这里没有提供last_feature的话,可以为None
output = block(inp1, inp2, last_feature)
# output = bfam(inp1, inp2)
# 打印输入和输出的shape
print(inp1.size())
print(inp2.size()) print(output.size())