文章目录
1、Bitemporal Feature Aggregation Module
B2CNet 论文中提出的 双时相特征聚合模块(Bitemporal Feature Aggregation Module)、改变边界感知模块(Change Boundary-Aware Module)、深度特征提取模块(Deep Feature Extraction Module) 都是为了解决遥感图像变化检测中的关键问题,并提升检测精度。以下是每个模块的设计动机、原理和实现过程:
由于遥感图像变化检测需要综合考虑低级纹理信息和高级语义信息。而现有的方法往往只关注其中一种信息,导致检测结果不够完整。所以 BFAM 通过聚合来自不同感受野的低级纹理信息和高层次语义信息。使特征更加丰富。
对于输入特征,BFAM 的实现过程:
- 对输入的双时相特征进行通道拼接。
- 使用不同膨胀率的卷积操作提取多尺度特征。
- 使用 1x1 卷积进行通道降维。
- 使用 SimAM 注意力机制强化特征。
- 使用 SimAM 注意力机制分别提取不同时相特征的共性特征。
- 将共性特征与各自时相特征相乘,得到相似性特征。
- 将低级纹理特征、相似性特征和高级语义特征进行求和。
- 使用 SimAM 注意力机制进行特征聚合。
Bitemporal Feature Aggregation Module 结构图:
2、Change Boundary-Aware Module
在遥感图像变化检测中,边界信息对于准确识别变化区域至关重要。现有的方法往往忽略了边界信息与变化信息之间的关系,导致边界特征提取能力不足,影响检测精度。CBM 利用边界信息来辅助提取变化区域特征,其通过边缘增强操作,强化边界信息,使其在特征图中更加突出。此外,CBM 学习如何更好地利用边界信息来定位变化区域,并指导其他模块进行特征解耦。
对于输入特征,CBM 的实现过程:
- 使用 SimAM 注意力机制识别特征图中的显著区域。
- 通过池化、减法和卷积操作提取边缘特征。
- 使用 SimAM 注意力机制再次强化边缘特征。
- 对增强后的边缘特征进行特征差异操作,得到边缘增强的差异特征。
- 对 CBM 分支的预测结果进行监督学习,指导模型更好地定位变化区域。
Change Boundary-Aware Module 结构图:
3、Deep Feature Extraction Module
遥感图像变化检测需要提取更深层次的语义特征,以更好地理解复杂场景。现有的方法往往由于特征解耦过程的信息损失,导致高级语义信息不足。DFEM 聚合 CBM 和 BFAM 的特征,进行深度特征提取。然后利用残差操作融合特征信息,并提取更深层次的高层次语义特征。
对于输入特征,DFEM 的实现过程:
- 对 CBM 和 BFAM 的特征进行拼接和求和。
- 使用 1x1 卷积进行通道降维。
- 使用残差拼接操作保留信息完整性。
- 使用 3x3 卷积提取深度特征。
- 将上一层的 CBM 输出特征与深度特征进行求和。
- 使用 SimAM 注意力机制进行特征聚合。
Deep Feature Extraction Module 结构图:
4、代码实现
import torch
import torch.nn as nn
from einops import rearrange
class simam_module(torch.nn.Module):
def __init__(self, channels=None, e_lambda=1e-4):
super(simam_module, self).__init__()
self.activaton = nn.Sigmoid()
self.e_lambda = e_lambda
def __repr__(self):
s = self.__class__.__name__ + '('
s += ('lambda=%f)' % self.e_lambda)
return s
@staticmethod
def get_module_name():
return "simam"
def forward(self, x):
b, c, h, w = x.size()
n = w * h - 1
x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
return x * self.activaton(y)
class BFAM(nn.Module):
def __init__(self, inp, out):
super(BFAM, self).__init__()
self.pre_siam = simam_module()
self.lat_siam = simam_module()
out_1 = int(inp / 2)
self.conv_1 = nn.Conv2d(inp*2, out_1, padding=1, kernel_size=3, groups=out_1,
dilation=1)
self.conv_2 = nn.Conv2d(inp*2, out_1, padding=2, kernel_size=3, groups=out_1,
dilation=2)
self.conv_3 = nn.Conv2d(inp*2, out_1, padding=3, kernel_size=3, groups=out_1,
dilation=3)
self.conv_4 = nn.Conv2d(inp*2, out_1, padding=4, kernel_size=3, groups=out_1,
dilation=4)
self.fuse = nn.Sequential(
nn.Conv2d(out_1 * 4, out_1*2, kernel_size=1, padding=0),
nn.BatchNorm2d(out_1*2),
nn.ReLU(inplace=True)
)
self.fuse_siam = simam_module()
self.out = nn.Sequential(
nn.Conv2d(out_1*2, out, kernel_size=3, padding=1),
nn.BatchNorm2d(out),
nn.ReLU(inplace=True)
)
def forward(self, inp1, inp2, last_feature=None):
x = torch.cat([inp1, inp2], dim=1)
c1 = self.conv_1(x)
c2 = self.conv_2(x)
c3 = self.conv_3(x)
c4 = self.conv_4(x)
cat = torch.cat([c1, c2, c3, c4], dim=1)
fuse = self.fuse(cat)
inp1_siam = self.pre_siam(inp1)
inp2_siam = self.lat_siam(inp2)
inp1_mul = torch.mul(inp1_siam, fuse)
inp2_mul = torch.mul(inp2_siam, fuse)
fuse = self.fuse_siam(fuse)
if last_feature is None:
out = self.out(fuse + inp1 + inp2 + inp2_mul + inp1_mul)
else:
out = self.out(fuse + inp2_mul + inp1_mul + last_feature + inp1 + inp2)
out = self.fuse_siam(out)
return out
class diff_moudel(nn.Module):
def __init__(self, in_channel):
super(diff_moudel, self).__init__()
self.avg_pool = nn.AvgPool2d((3, 3), stride=1, padding=1)
self.conv_1 = nn.Conv2d(in_channel, in_channel, kernel_size=1, stride=1, padding=0)
self.bn1 = nn.BatchNorm2d(in_channel)
self.sigmoid = nn.Sigmoid()
self.simam = simam_module()
def forward(self, x):
x = self.simam(x)
edge = x - self.avg_pool(x) # Xi=X-Avgpool(X)
weight = self.sigmoid(self.bn1(self.conv_1(edge)))
# weight = self.conv_1(edge)
out = weight * x + x
out = self.simam(out)
return out
class CBM(nn.Module):
def __init__(self, in_channel):
super(CBM, self).__init__()
self.diff_1 = diff_moudel(in_channel)
self.diff_2 = diff_moudel(in_channel)
self.simam = simam_module()
def forward(self, x1, x2):
d1 = self.diff_1(x1)
d2 = self.diff_2(x2)
d = torch.abs(d1 - d2)
d = self.simam(d)
return d
class DFEM(nn.Module):
def __init__(self, inc, outc):
super(DFEM, self).__init__()
self.Conv_1 = nn.Sequential(nn.Conv2d(inc*2, outc, kernel_size=1),
nn.BatchNorm2d(outc),
nn.ReLU(inplace=True)
)
self.Conv = nn.Sequential(nn.Conv2d(outc, outc, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(outc),
nn.ReLU(inplace=True),
)
self.relu = nn.ReLU(inplace=True)
self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
def forward(self, diff, accom):
cat = torch.cat([accom, diff], dim=1)
cat = self.Conv_1(cat) + diff + accom
c = self.Conv(cat) + cat
c = self.relu(c) + diff
c = self.Up(c)
return c
if __name__ == '__main__':
x = torch.randn(4, 64, 128, 128).cuda()
y = torch.randn(4, 64, 128, 128).cuda()
# model = BFAM(64,64).cuda()
# out = model(x,y)
# model = CBM(64).cuda()
# out = model(x, y)
model = DFEM(64,64).cuda()
out = model(x, y)
print(out.shape)