Study Notes 1: FFCA-YOLO for Small Object Detection in Remote Sensing Images (a remote-sensing top journal paper), with code for the FEM and SCAM modules

Paper: https://ieeexplore.ieee.org/abstract/document/10423050

Abstract — Insufficient feature representation and background confusion make the detection of small objects in remote sensing imagery a demanding task. In particular, when the algorithm is to be deployed on board for real-time processing, accuracy and speed must be extensively optimized under limited computing resources. To address these problems, this paper proposes an efficient feature-enhancement, fusion, and context-aware YOLO (FFCA-YOLO) detector. FFCA-YOLO includes three innovative lightweight plug-and-play modules: the feature enhancement module (FEM), the feature fusion module (FFM), and the spatial context-aware module (SCAM). These three modules respectively improve the network's capabilities for local region perception, multiscale feature fusion, and global association across channels and space, while avoiding added complexity as much as possible, thereby strengthening the weak feature representations of small objects and suppressing confusable backgrounds. The effectiveness of FFCA-YOLO is verified on two public remote sensing small-object detection datasets, VEDAI and AI-TOD, and one self-built dataset, USOD. FFCA-YOLO reaches accuracies of 0.748, 0.617, and 0.909 (in terms of mAP50), surpassing several benchmark models and state-of-the-art methods. Its robustness is also validated under different simulated degradation conditions. In addition, to further reduce computing resource consumption while ensuring efficiency, a lite version, L-FFCA-YOLO, is obtained by rebuilding the backbone and neck of FFCA-YOLO on partial convolution (PConv). Compared with FFCA-YOLO, L-FFCA-YOLO is faster, has a smaller parameter scale and lower computing-power requirements, with only a small loss of accuracy. The source code is available at https://github.com/yemu1138178251/FFCA-YOLO.

One of the innovations: three lightweight plug-and-play modules are proposed: FEM, FFM, and SCAM. These three modules respectively improve the network's capabilities for local region perception, multiscale feature fusion, and global association across channels and space. They can be inserted as general-purpose modules into any detection network to strengthen the weak feature representations of small objects and suppress confusable backgrounds.
 

FEM is placed right after the backbone; SCAM is applied just before the head.
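As a rough sketch of where these plug-and-play modules sit in a generic detector, the snippet below wires stand-in modules into a backbone → FEM → SCAM → head pipeline. The backbone, head, and the 1x1-conv placeholders here are hypothetical stand-ins for illustration only, not the paper's actual network:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: the real FFCA-YOLO uses a YOLO backbone/neck and
# the FEM/SCAM modules defined in this post; this only marks insertion points.
backbone = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
fem      = nn.Conv2d(64, 64, kernel_size=1)   # placeholder for FEM (after backbone)
scam     = nn.Conv2d(64, 64, kernel_size=1)   # placeholder for SCAM (before head)
head     = nn.Conv2d(64, 85, kernel_size=1)   # placeholder detection head

x = torch.randn(1, 3, 192, 192)
feat = backbone(x)    # [1, 64, 96, 96]
feat = fem(feat)      # FEM enhances backbone features
feat = scam(feat)     # SCAM refines features right before the head
out = head(feat)      # [1, 85, 96, 96]
print(out.shape)
```

All four stages are drop-in `nn.Module`s, which is what makes FEM/SCAM usable in other detection networks as well.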

Overall YOLO network structure in the paper:

FEM structure:

It is split into four branches, built from standard convolutions and dilated convolutions.

import torch
import torch.nn as nn


class FEM(nn.Module):
    r""" Feature Enhancement Module (FEM).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super(FEM, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),         
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),           
            nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 1), padding=(1, 0)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=3, dilation=3)
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 1), padding=(1, 0)),
            nn.Conv2d(out_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=3, dilation=3)
        )
        self.branch4 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        self.conv_cat = nn.Conv2d(out_channels * 4, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x4 = self.branch4(x)
        x_cat = torch.cat([x1, x2, x3, x4], dim=1)
        out = self.conv_cat(x_cat)
        return out

Verify the output shape of each branch:

# Create input data
input_data = torch.randn(2, 3, 192, 192)  # assume 3 input channels, 192x192 images

# Create a model instance
model = FEM(in_channels=3, out_channels=64)

# Forward pass and print each branch's output shape
with torch.no_grad():
    x1 = model.branch1(input_data)
    print(f'x1 shape: {x1.shape}')
    x2 = model.branch2(input_data)
    print(f'x2 shape: {x2.shape}')
    x3 = model.branch3(input_data)
    print(f'x3 shape: {x3.shape}')
    x4 = model.branch4(input_data)
    print(f'x4 shape: {x4.shape}')
    x_cat = torch.cat([x1, x2, x3, x4], dim=1)
    print(f'x_cat shape: {x_cat.shape}')
    out = model.conv_cat(x_cat)
    print(f'out shape: {out.shape}')
The output shapes are as follows:
x1 shape: torch.Size([2, 64, 192, 192])
x2 shape: torch.Size([2, 64, 192, 192])
x3 shape: torch.Size([2, 64, 192, 192])
x4 shape: torch.Size([2, 64, 192, 192])
x_cat shape: torch.Size([2, 256, 192, 192])
out shape: torch.Size([2, 64, 192, 192])
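Every branch keeps the 192x192 spatial size. For the dilated 3x3 convolutions this depends on matching the padding to the dilation: the effective kernel size is `dilation * (k - 1) + 1 = 3 * 2 + 1 = 7`, so `padding=3` preserves height and width. A quick standalone check:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 192, 192)

# dilation=3 expands a 3x3 kernel to an effective 7x7 receptive field;
# padding = (effective_k - 1) // 2 = 3 keeps H and W unchanged.
dilated = nn.Conv2d(64, 64, kernel_size=3, padding=3, dilation=3)
print(dilated(x).shape)  # torch.Size([1, 64, 192, 192])

# With the default padding=0, the same layer shrinks the map by 6 pixels:
no_pad = nn.Conv2d(64, 64, kernel_size=3, dilation=3)
print(no_pad(x).shape)   # torch.Size([1, 64, 186, 186])
```

This is why the dilated branches in FEM can be concatenated directly with the 1x1 branch: all four outputs share the same spatial resolution.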

SCAM structure:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SCAM(nn.Module):
    def __init__(self, in_channels):
        super(SCAM, self).__init__()
        # Branch 1: average and max pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Branch 2: value conv
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Branch 3: query/key conv
        self.qk_conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        # Branch 4: final fusion convs
        self.final_r_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.final_l_conv = nn.Conv2d(2, in_channels, kernel_size=1)
     
    def forward(self, x):
        B, C, H, W = x.size()
        max_out = self.max_pool(x)   # [B, C, 1, 1]
        avg_out = self.avg_pool(x)   # [B, C, 1, 1]
        x1 = torch.cat((max_out, avg_out), dim=-2)  # [B, C, 2, 1]
        x1 = F.softmax(x1, dim=-2)                  # [B, C, 2, 1]
        # value
        x2 = self.value_conv(x)      # [B, C, H, W]
        x2 = x2.view(B, C, -1)       # [B, C, HW]
        # query/key
        x3 = self.qk_conv(x)                        # [B, 1, H, W]
        x3 = x3.view(B, 1, -1).permute(0, 2, 1)     # [B, 1, HW] -> [B, HW, 1]
        x3 = F.softmax(x3, dim=1)                   # [B, HW, 1]
        # Right branch: matrix multiplication
        bran_r = torch.matmul(x2, x3)       # [B, C, HW] x [B, HW, 1] = [B, C, 1]
        bran_r = bran_r.view(B, C, 1, 1)    # [B, C, 1] -> [B, C, 1, 1]
        bran_r = self.final_r_conv(bran_r)  # [B, C, 1, 1]
        # Left branch: matrix multiplication
        x11 = x1.view(B, C, 2)              # [B, C, 2, 1] -> [B, C, 2]
        x2l = x2.permute(0, 2, 1)           # [B, C, HW] -> [B, HW, C]
        bran_l = torch.matmul(x2l, x11)     # [B, HW, C] x [B, C, 2] = [B, HW, 2]
        bran_l = bran_l.view(B, H, W, 2)    # [B, H, W, 2]
        bran_l = bran_l.permute(0, 3, 1, 2) # [B, 2, H, W]
        bran_l = self.final_l_conv(bran_l)  # [B, C, H, W]
        out = bran_l * bran_r               # broadcast to [B, C, H, W]

        return out

Verify the module's input and output:

input_data = torch.randn(2, 3, 192, 96) 
model = SCAM(3)
Y = model(input_data)
Y.shape

Output:
torch.Size([2, 3, 192, 96])
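The right branch of SCAM is, in effect, attention-weighted global pooling: the softmax over the HW positions of the query/key map yields spatial weights, and the matmul with the value tensor sums each channel with those weights. A small sanity check of that equivalence, using random tensors independent of the SCAM class:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 4, 5, 5
value = torch.randn(B, C, H * W)    # flattened value features [B, C, HW]
scores = torch.randn(B, H * W, 1)   # flattened query/key logits [B, HW, 1]
weights = F.softmax(scores, dim=1)  # spatial attention weights, sum to 1

# The matmul form used in SCAM: [B, C, HW] x [B, HW, 1] -> [B, C, 1]
out_matmul = torch.matmul(value, weights)

# An explicit weighted sum over spatial positions gives the same result
out_sum = (value * weights.transpose(1, 2)).sum(dim=-1, keepdim=True)

print(torch.allclose(out_matmul, out_sum, atol=1e-6))  # True
```

So `bran_r` is a per-channel global descriptor pooled with learned spatial attention, which then rescales the left branch's `[B, C, H, W]` map by broadcasting.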
