目录
📌 论文信息
📄 论文标题:Multisource Joint Representation Learning Fusion Classification for Remote Sensing Images
📚 发表期刊:IEEE TGRS, 2023
✍ 作者:Xueli Geng 等人
🔗 DOI: 10.1109/TGRS.2023.3296813
1. 研究背景
多源遥感影像(如 高光谱(HSI)、SAR(合成孔径雷达)、LiDAR)能够提供丰富的地表信息,广泛应用于 土地利用分类、生态监测、城市规划 等领域。然而,多源数据的成像机制存在差异,导致 模态间特征异质性,影响融合分类的效果。
💡 现有问题
- 优化不均衡问题:不同模态的数据学习速度不同,某些模态的特征可能欠优化。
- 简单的线性融合方式不足:高阶非线性关系被忽略,影响信息的有效集成。
- 冗余信息问题:多源融合可能会包含不必要的信息,影响分类精度。
🎯 解决方案
本文提出 MIBF-Net,基于 信息瓶颈(Information Bottleneck, IB)原理 进行 多源联合表示学习(Joint Representation Learning),通过:
- AD-NAL(属性驱动噪声自适应层):动态平衡不同模态的特征学习速率,从而提取有区别的单源固有信息。
- CRE(跨源关系编码模块):建模跨模态复杂关系,以增强融合表示的丰富性。
- IB-Fusion(信息瓶颈融合模块):降低冗余信息,提高融合特征的判别能力
2. 方法解析
📌 MIBF-Net 结构
MIBF-Net 采用 双分支结构,分别处理 高光谱数据(HSI) 和 SAR/LiDAR 数据,包括以下关键模块:
-
特征提取(Feature Extraction)
- MS-CNN 处理高光谱数据
- P-CNN 处理 SAR/LiDAR 数据
- 结合
Squeeze-and-Excitation (SE) Layer
进行通道注意力增强
-
自适应噪声对齐学习(AD-NAL)
- 通过 自适应噪声扰动(Adaptive Noise Perturbation),平衡两个分支的特征学习
- 随机选择部分样本,对其中一个模态(HS 或 SAR)添加 高斯噪声 ,防止模型对特定数据源产生偏倚
- 通过 伯努利随机变量 确定添加噪声的模态
- 本质上,AD-NAL可以被认为是一种正则化技术,它将噪声引入到高属性的源特征中,并鼓励低属性的源特征学习独立的特征。
其中,加号表示逐元素加法。ε 是从正态分布随机生成的噪声,即,ε ~ N(0,1).δt是概率为qt的伯努利随机变量,qt ≥ 1,即,δt ~ Bernoulli(qt). qt的值决定了添加到每个数据源的噪声的频率,它是从CRE模块中获得的
-
跨模态关系建模(CRE)
- 通过 加权双线性交互(Weighted Bilinear Interaction),建模跨源特征关系
- 外积 g(· , ·) 是实现双线性相互作用的一种简单方法。它可以为二阶特征交互计算,并且所得特征矩阵中每个元素的值可以被解释为来自不同数据源的特征之间的交互。
- 计算 重要性权重 qt,赋予不同模态不同的关注度
-
信息瓶颈融合(IB-Fusion)
-
信息瓶颈(IB)原理旨在通过丢弃尽可能多的与任务无关的输入信息来学习任务相关的表示,从而提高生成的表示的鲁棒性
-
约束联合表示
z
与 原始数据x
的互信息(减少冗余) -
约束联合表示
z
与 类别标签y
的互信息(保留判别信息) -
采用 变分推断(Variational Inference) 进行优化
-
论文中将这个损失函数进行了化简,下面是我根据论文内容做的笔记。
3. 代码解析
📌 代码架构
1️⃣ SELayer(通道注意力机制)
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1) # 全局平均池化
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=False),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid() # 计算通道注意力权重
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c) # 计算通道的全局特征
y = self.fc(y).view(b, c, 1, 1) # 变回原尺寸
return x * y.expand_as(x) # 对每个通道进行加权
✅ 作用:
- 计算每个通道的重要性,并通过
Sigmoid
归一化,增强关键特征,抑制无关信息。
2️⃣ MSCNN(高光谱特征提取网络)
class MSCNN(nn.Module):
def __init__(self, channel_msi):
super(MSCNN, self).__init__()
self.conv3x3_1 = nn.Sequential(
nn.Conv2d(channel_msi, 128, 3, 1, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=False)
)
self.conv3x3_2 = nn.Sequential(
nn.Conv2d(128, 64, 1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=False)
)
self.conv3x3_3 = nn.Sequential(
nn.Conv2d(64, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2) # 降低分辨率 (H, W) → (H/2, W/2)
)
self.SELayer = SELayer(64) # 结合通道注意力机制
def forward(self, x):
x = self.conv3x3_1(x)
x = self.conv3x3_2(x)
x = self.SELayer(x) # 提取重要通道
x = self.conv3x3_3(x)
return x
✅ 作用:
- 提取 高光谱图像(HSI) 特征,并通过 SE 模块增强通道信息。
3️⃣ PCNN(SAR/LiDAR 特征提取网络)
class PCNN(nn.Module):
def __init__(self, channel_pan):
super(PCNN, self).__init__()
self.conv3x3_1 = nn.Sequential(
nn.Conv2d(channel_pan, 16, 3, 1, 1, bias=False),
nn.BatchNorm2d(16),
nn.ReLU(inplace=False)
)
self.conv3x3_2 = nn.Sequential(
nn.Conv2d(16, 32, 1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=False)
)
self.conv3x3_3 = nn.Sequential(
nn.Conv2d(32, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2) # 降低分辨率
)
self.SELayer = SELayer(32) # 结合通道注意力
def forward(self, x):
x = self.conv3x3_1(x)
x = self.conv3x3_2(x)
x = self.SELayer(x) # 提取重要通道
x = self.conv3x3_3(x)
return x
✅ 作用:
- 提取 SAR/LiDAR 结构信息,并利用 SELayer 进行特征增强。
4️⃣ NormInteraction(跨模态特征交互)
class NormInteraction(nn.Module):
def __init__(self, inplanes=128):
super(NormInteraction, self).__init__()
self.conv = nn.Conv2d(inplanes * 2, inplanes, kernel_size=1, stride=1) # 1×1 卷积用于跨模态融合
self.conv3 = nn.Sequential(
nn.Conv2d(64, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=False)
)
def forward(self, x1, x2):
concat = torch.cat([x1, x2], 1) # 拼接两个模态特征
gate = torch.sigmoid(self.conv(concat)) # 计算模态注意力权重
p1 = x1 * gate
p2 = x2 * (1 - gate)
p = torch.matmul(p1, p2) # 计算模态关联
p = self.conv3(p)
return p, torch.mean(gate) # 返回融合特征
✅ 作用:
- 通过 门控机制 选择不同模态的重要性,增强跨模态信息交互。
5️⃣ IBfusion(信息瓶颈融合)
class IBfusion(nn.Module):
def __init__(self, dim, num_classes, inner=128):
super(IBfusion, self).__init__()
self.encoder = nn.Sequential(nn.Linear(dim, inner), nn.ReLU())
self.fc_mu = nn.Linear(inner, dim) # 计算 `μ`
self.fc_std = nn.Linear(inner, dim) # 计算 `σ`
self.decoder = nn.Linear(dim, num_classes) # 用 `z` 进行分类
def encode(self, x):
x = self.encoder(x)
return self.fc_mu(x), F.softplus(self.fc_std(x) - 5)
def reparameterise(self, mu, std):
eps = torch.randn_like(std)
return mu + std * eps # 生成 `z`
def forward(self, x):
mu, std = self.encode(x)
z = self.reparameterise(mu, std)
return self.decoder(z), mu, std
✅ 作用:
- 使用变分信息瓶颈(VIB)优化特征表达,减少冗余信息,提高分类效果。
6️⃣ MIBfusion(最终模型)
class MIBfusion(nn.Module):
def __init__(self, in_channels1, in_channels2, num_classes):
super(MIBfusion, self).__init__()
self.net1 = MSCNN(in_channels1)
self.net2 = PCNN(in_channels2)
self.interaction = NormInteraction(64)
self.m = torch.tensor(0.5) #随机掩码控制因子
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) #将输入的 H × W 特征图转换为 1 × 1 特征向量
self.f1 = nn.Sequential(
nn.Conv2d(64, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=False)
)
self.f2 = nn.Sequential(
nn.Conv2d(64, 64, 3, 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=False)
)
self.ff = nn.Sequential( #融合特征变换
nn.Conv2d(256, 128, 1),
nn.BatchNorm2d(128),
nn.ReLU(),
)
self.fc = nn.Sequential(
nn.Linear(128, num_classes),
)
self.IB_fusion = IBfusion(128, 128, num_classes)
self.ii = torch.tensor(0.0) #记录训练时的步数,用于控制 branch1 和 branch2 施加噪声的时机
def forward(self, HS, SAR):
self.ii = self.ii + 1 #每次前向传播,步数加 1
branch1 = self.net1(HS)
branch2 = self.net2(SAR) #提取 HS 和 SAR 的特征
#AD-NAL,平衡两个分支的特征提取,减轻对特定源数据的依赖
if self.training and self.ii % 1 == 0:#训练模式and恒成立
o = torch.bernoulli(1-self.m) #随机生成 0 或 1
size = branch1[0].size() #单个样本大小
indices = torch.randint(0, branch1.shape[0], [int(branch1.shape[0] * 0.7), 1]) #70% 样本会被选中施加噪声
if o == 0:
for ind in indices:
rand_g = torch.randn(size, dtype=torch.float32).to(branch1.device)
branch1[ind] += rand_g
else:
for ind in indices:
rand_g = torch.randn(size, dtype=torch.float32).to(branch2.device)
branch2[ind] += rand_g
align_fusion, self.m = self.interaction(branch1, branch2) #交互融合
branch1_1 = self.f1(branch1) #增强特征
branch2_1 = self.f2(branch2)
align_fusion = torch.cat([branch1_1, align_fusion, branch2_1], 1) #拼接三者
align_fusion = self.ff(align_fusion)
x = self.avg_pool(align_fusion)
x = torch.flatten(x, 1) #变成向量 (batch_size, 128)
out, mu, std = self.IB_fusion(x)
return out, mu, std
✅ 作用:
- 联合 MSCNN 和 PCNN 进行多模态特征提取
- 通过 AD-NAL 进行动态平衡不同模态的特征学习速率
- 利用 NormInteraction 进行跨模态融合
- 使用 IB-Fusion 进行信息压缩,提高分类能力
7️⃣ IBloss_R(损失函数)
class IBloss_R(nn.Module):
def __init__(self):
super(IBloss_R, self).__init__()
self.alpha = 1e-3
def forward(self, output, y):
logit, mu, std = output
class_loss = F.cross_entropy(logit, y.squeeze())
info_loss = 0.5 * torch.mean(mu.pow(2) + std.pow(2) - 2 * std.log() - 1)
return class_loss + self.alpha * info_loss
✅ 作用:
- 交叉熵损失 用于分类任务
- 信息瓶颈损失 限制
z
的冗余信息,增强判别能力
总结
✅ MIBfusion 通过:
- MSCNN & PCNN 提取模态特征
- AD-NAL 平衡特征学习速率
- NormInteraction 进行模态交互
- IB-Fusion 进行信息约束
- 最终用 IBloss_R 进行优化
🔹 提高了遥感影像分类的准确率,同时减少了冗余信息,提高了模型的泛化能力! 🚀
4. 复现实验与收获
实验结果表明,我跑的MUUFL和Houston-2013数据集和论文中的结果相比差距很小,属于正常偏差。
数据集 | 精度(OA) | AA | Kappa | LR |
---|---|---|---|---|
MUUFL | 92.64% | 80.18% | 90.27% | 0.001 |
Houston-2013 | 96.63% | 96.93% | 96.35% | 0.0005 |
🔹 个人收获
1️⃣ 深入理解信息瓶颈理论
- 这篇论文让我理解了如何在分类任务中利用信息瓶颈(IB)进行特征优化。
- IB-Fusion 让我思考:如何有效减少冗余信息,同时增强判别信息,这对大规模遥感影像分析具有重要启发。
2️⃣ 代码复现
- 从最基础的debug开始,改文件路径,参数等等。
- 在复现论文时,我看懂了模型的代码模块化设计,作者让 MSCNN、PCNN、NormInteraction 互相独立,便于后续改进。
- 当我试图尝试其他数据集比如Trento时,我设置了很多不同的学习率,发现训练和测试准确率在几个epoch之后就持续100%,结果异常的好,可能是数据集的原因,或者这是正常现象……
3️⃣ 多源数据融合的挑战
- 不同模态的数据分布不同,简单拼接可能会影响模型效果,而 MIBF-Net 提供了一种基于联合表征学习的解决方案。
- 实验中,我发现不同超参数(如LR)对结果影响很大,这让我思考如何通过参数调优进一步提升模型性能。
5. 结论
✅ MIBF-Net 有效融合多源数据,避免信息冗余
✅ 信息瓶颈约束减少了冗余,提高了判别能力,使得模型更具泛化性。
✅ 复现过程中,我收获了深度学习与信息理论结合的实践经验,为后续研究提供了方向。 🚀