论文重点提取--Adversarial_Complementary_Learning_for_Multisource_Remote_Sensing_Classification

一种对抗互补学习(ACL)卷积神经网络(CNN)方法,用于多源遥感分类任务。

Adversarial Complementary Learning for Multisource Remote Sensing Classification | IEEE Journals & Magazine | IEEE Xplore

论文的架构很简单如下图所示,一目了然。

ACL主要包含两方面:
  1. 通过对抗生成器和判别器进行互补学习,从多源数据中提取出共同模式和特定模式,形成紧凑且判别性特征表示;

{生成器尝试最小化其损失(即生成更容易被判别器误认为是真实的数据),而判别器尝试最大化其损失(即更好地区分真实数据和生成数据)。}

       2.设计了模式采样模块(PSM)从特定模式中提取出互斥关系,以消除奇异噪声。

{如冗余频谱和斑点噪声,互斥关系指的是两种模式不能同时存在的关系,因为它们彼此之间存在冲突或矛盾。通过提取这些互斥关系,可以有效地减少噪声的影响,并提高分类的准确性。}

        此外,本文还设计了共享解码器保证特征完整性,并采用分类器进行标签预测。

优点

1.提出对抗互补方法,提取互补信息,从而提高了分类性能。

2. 模型在分类准确率、平均准确率等指标上优于多个基准方法。

缺点

1.大量的标注样本?-->半监督,无监督

2. 判别损失的权重要自己设置-->自适应权重

3.对抗学习网络复杂-->简化网络

ACL_CNN代码解读

论文并没有给开源代码,这个是自己编写的代码,希望对读者有帮助。

import torch
import torch.nn as nn
import torch.nn.functional as F

class ACL_CNN(nn.Module):
    def __init__(self, in_channels_pan, in_channels_sar, num_classes=7):
        super(ACL_CNN, self).__init__()
        # 定义用于处理PAN图像的卷积层,这里的in_channels应该与PAN图像的通道数相匹配
        self.pan_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels_pan, out_channels=32, kernel_size=3,  stride=1,padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        # 定义用于处理SAR图像的卷积层,这里的in_channels应该与SAR图像的通道数相匹配
        self.sar_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels_sar, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        # 定义共享的全连接层
        self.shared_fc = nn.Sequential(
            nn.Linear(64 * 16 * 16, 512),  # 假设每个卷积层输出64个通道,特征图大小为16x16
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)  # 假设有num_classes个类别
        )
        # 其他层保持不变...

    def forward(self, pan, sar):
        # PAN特征提取
        pan_features = self.pan_conv(pan)
        # SAR特征提取
        sar_features = self.sar_conv(sar)
        # 特征融合,这里需要根据论文中的描述来实现具体的融合策略
        # 例如,可以通过拼接、加权求和或其他方式来融合特征
        combined_features = torch.cat([pan_features, sar_features], dim=1)  # 假设特征图尺寸相同,直接拼接
        # 或者使用其他融合策略,如注意力机制等
        # 分类器
        # 假设combined_features是经过融合的特征图,尺寸为[batch_size, channels, height, width]
        # 我们需要将其转换为一维向量以输入到全连接层
        batch_size, channels, height, width = combined_features.size()
        # 计算展平后的特征尺寸
        flat_features_size = channels * height * width
        combined_features = combined_features.view(batch_size, -1)  # 展平特征图
        # 确保展平后的特征图尺寸与全连接层的输入尺寸相匹配
        # 确保展平后的特征图尺寸与全连接层的输入尺寸相匹配
        assert flat_features_size == combined_features.size(1), "The size of the flattened features does not match the expected input size for the fully connected layer."
        # 通过全连接层进行分类
        # 确保展平后的特征图尺寸与全连接层的输入尺寸相匹配
        # 这里我们动态地调整全连接层的输入特征尺寸
        self.shared_fc[0] = nn.Linear(flat_features_size, 512)
        output = self.shared_fc(combined_features)
        return output

# 检查卷积层的输出尺寸
model = ACL_CNN(4, 1, 7)
dummy_input_pan = torch.randn(1, 4, 16, 16)
dummy_input_sar = torch.randn(1, 1, 16, 16)
with torch.no_grad():
    output = model(dummy_input_pan, dummy_input_sar)

print(output)

完结撒花❀,感谢看到最后的每一个人~~

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值