paper:HCF-Net: Hierarchical Context Fusion Network for Infrared Small Object Detection
1、Parallelized Patch-Aware Attention
红外小目标检测是一项重要的计算机视觉任务,涉及对红外图像中微小目标的识别和定位,而这些目标在红外图像中通常只占几个像素。由于目标尺寸小、背景复杂等原因,红外图像处理技术遇到了困难。论文提出了一种并行化补丁感知注意力(Parallelized Patch-Aware Attention, PPA)模块。在红外小目标检测任务中,小目标在多次下采样过程中容易丢失重要信息,而 PPA 则替代了编码器和解码器基本组件中的传统卷积。
PPA 的主要优势是多分支特征提取策略。PPA 通过采用多分支特征提取策略,使得每个分支都是在多尺度的层面上去提取特征。因而这种多分支策略有利于目标多尺度特征的捕获,从而提高了小目标检测的准确性。
具体来说,PPA 涉及三个部分:Patch-Aware、Feature fusion 和 Attention。对于一个给定的输入特征张量F:
1、首先,通过 PWConv 进行调整,得到 F'。然后,通过两条 Patch-Aware 分支和一条卷积分支,分别计算 F_local、F_global 和 F_conv。最后将三个结果相加,得到输出特征。在这其中,Patch-Aware 是将特征划分为多个补丁,再通过多尺度卷积来捕获特征,让模型能够更好地关注多个范围的区域。
2、在 Patch-Aware 尾部,包含一个 feature fusion 部分,在这里通过在 token 维度和 Channel 维度上进行自适应特征选择,以确保模型使用最合适的上下文信息进行检测。
3、在模块最后,使用一个注意力机制来进行自适应特征增强。注意力模块包括 通道注意力 和 空间注意力。
PPA 结构图:
2、代码实现
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SpatialAttentionModule(nn.Module):
    """CBAM-style spatial attention: gate each spatial location by a 7x7
    convolution over the channel-wise average and maximum maps."""

    def __init__(self):
        super().__init__()
        # Two pooled maps in, one gate map out; padding=3 keeps H x W.
        self.conv2d = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel-wise statistics, each of shape (B, 1, H, W).
        pooled_avg = torch.mean(x, dim=1, keepdim=True)
        pooled_max, _ = torch.max(x, dim=1, keepdim=True)
        gate = self.sigmoid(self.conv2d(torch.cat((pooled_avg, pooled_max), dim=1)))
        # Gate values lie in (0, 1); output is the re-weighted input.
        return x * gate
class LocalGlobalAttention(nn.Module):
    """Patch-aware attention branch of PPA (HCF-Net).

    Splits the feature map into non-overlapping P x P patches, embeds each
    patch with a small MLP, self-gates the embeddings with a softmax, scales
    them by their cosine similarity to a learned prompt vector, applies a
    learned linear transform, and finally upsamples the patch grid back to
    the input resolution.
    """

    def __init__(self, output_dim, patch_size):
        super().__init__()
        self.output_dim = output_dim
        self.patch_size = patch_size
        # Patch embedding MLP: P*P channel-averaged pixels -> output_dim.
        self.mlp1 = nn.Linear(patch_size * patch_size, output_dim // 2)
        self.norm = nn.LayerNorm(output_dim // 2)
        self.mlp2 = nn.Linear(output_dim // 2, output_dim)
        self.conv = nn.Conv2d(output_dim, output_dim, kernel_size=1)
        # Learned prompt vector; patch embeddings are kept in proportion to
        # their cosine similarity with it (clamped to [0, 1] below).
        self.prompt = torch.nn.parameter.Parameter(torch.randn(output_dim, requires_grad=True))
        # Learned linear map applied after the prompt gating; starts as identity.
        self.top_down_transform = torch.nn.parameter.Parameter(torch.eye(output_dim), requires_grad=True)

    def forward(self, x):
        # x: (B, C, H, W) -> channels-last for patch extraction.
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        P = self.patch_size  # assumes H and W are divisible by P — TODO confirm at call sites
        # Local branch
        local_patches = x.unfold(1, P, P).unfold(2, P, P)  # (B, H/P, W/P, C, P, P)
        # NOTE(review): unfold leaves C *before* the two P axes, so this
        # reshape interleaves channel and pixel positions rather than cleanly
        # yielding (patches, pixels, channels) — verify against the reference
        # HCF-Net implementation before changing.
        local_patches = local_patches.reshape(B, -1, P * P, C)  # (B, H/P*W/P, P*P, C)
        local_patches = local_patches.mean(dim=-1)  # average over last axis: (B, H/P*W/P, P*P)
        local_patches = self.mlp1(local_patches)  # (B, H/P*W/P, output_dim // 2)
        local_patches = self.norm(local_patches)  # (B, H/P*W/P, output_dim // 2)
        local_patches = self.mlp2(local_patches)  # (B, H/P*W/P, output_dim)
        # Self-gating over the embedding dimension.
        local_attention = F.softmax(local_patches, dim=-1)  # (B, H/P*W/P, output_dim)
        local_out = local_patches * local_attention  # (B, H/P*W/P, output_dim)
        # Cosine similarity of each patch embedding with the prompt: (B, N, 1).
        cos_sim = F.normalize(local_out, dim=-1) @ F.normalize(self.prompt[None, ..., None], dim=1)  # B, N, 1
        # Negative similarities are cut off; positive ones scale the patch.
        mask = cos_sim.clamp(0, 1)
        local_out = local_out * mask
        local_out = local_out @ self.top_down_transform
        # Restore shapes
        local_out = local_out.reshape(B, H // P, W // P, self.output_dim)  # (B, H/P, W/P, output_dim)
        local_out = local_out.permute(0, 3, 1, 2)  # back to channels-first
        # Upsample the patch grid back to the original spatial size.
        local_out = F.interpolate(local_out, size=(H, W), mode='bilinear', align_corners=False)
        output = self.conv(local_out)
        return output
class ECA(nn.Module):
    """Efficient Channel Attention: a 1-D convolution over the pooled channel
    descriptor produces a per-channel sigmoid gate (ECA-Net)."""

    def __init__(self, in_channel, gamma=2, b=1):
        super().__init__()
        # Kernel size adapts to the channel count and is forced odd so the
        # convolution stays centered (formula from the ECA-Net paper).
        t = int(abs((math.log(in_channel, 2) + b) / gamma))
        k = t if t % 2 else t + 1
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=1, kernel_size=k, padding=k // 2, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels = x.size(0), x.size(1)
        # (B, C, 1, 1) -> (B, 1, C): treat the channels as a 1-D sequence.
        weights = self.pool(x).view(batch, 1, channels)
        weights = self.conv(weights).view(batch, channels, 1, 1)
        return x * weights
class conv_block(nn.Module):
    """Conv2d followed by optional normalization ('bn' or 'gn') and an
    optional ReLU activation.

    With norm_type=None no norm layer is created; forward guards on
    norm_type, so that configuration is safe.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1),
                 dilation=(1, 1),
                 norm_type='bn',
                 activation=True,
                 use_bias=True,
                 groups=1
                 ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=use_bias,
            groups=groups,
        )
        self.norm_type = norm_type
        self.act = activation
        if norm_type == 'gn':
            # Cap the group count at the channel count for narrow layers.
            self.norm = nn.GroupNorm(min(32, out_features), out_features)
        elif norm_type == 'bn':
            self.norm = nn.BatchNorm2d(out_features)
        if self.act:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        out = self.conv(x)
        if self.norm_type is not None:
            out = self.norm(out)
        return self.relu(out) if self.act else out
class PPA(nn.Module):
    """Parallelized Patch-Aware Attention block (HCF-Net).

    Sums a 1x1 skip path, the outputs of a serial stack of three 3x3 conv
    blocks, and two patch-aware attention branches (patch sizes 2 and 4),
    then applies channel attention (ECA), spatial attention, dropout,
    batch norm and ReLU.
    """

    def __init__(self, in_features, filters) -> None:
        super().__init__()
        # 1x1 projection so the skip path matches the branch channel count.
        self.skip = conv_block(in_features=in_features, out_features=filters,
                               kernel_size=(1, 1), padding=(0, 0),
                               norm_type='bn', activation=False)
        self.c1 = conv_block(in_features=in_features, out_features=filters,
                             kernel_size=(3, 3), padding=(1, 1),
                             norm_type='bn', activation=True)
        self.c2 = conv_block(in_features=filters, out_features=filters,
                             kernel_size=(3, 3), padding=(1, 1),
                             norm_type='bn', activation=True)
        self.c3 = conv_block(in_features=filters, out_features=filters,
                             kernel_size=(3, 3), padding=(1, 1),
                             norm_type='bn', activation=True)
        self.sa = SpatialAttentionModule()
        self.cn = ECA(filters)
        # Patch-aware branches at two granularities.
        self.lga2 = LocalGlobalAttention(filters, 2)
        self.lga4 = LocalGlobalAttention(filters, 4)
        self.bn1 = nn.BatchNorm2d(filters)
        self.drop = nn.Dropout2d(0.1)
        self.relu = nn.ReLU()
        self.gelu = nn.GELU()  # kept for state-dict parity; unused in forward

    def forward(self, x):
        # Skip path feeds both patch-aware branches.
        identity = self.skip(x)
        branch2 = self.lga2(identity)
        branch4 = self.lga4(identity)
        # Serial conv stack; every intermediate output joins the fusion sum.
        y1 = self.c1(x)
        y2 = self.c2(y1)
        y3 = self.c3(y2)
        fused = y1 + y2 + y3 + identity + branch2 + branch4
        # Channel attention, then spatial attention, then regularization.
        fused = self.sa(self.cn(fused))
        fused = self.drop(fused)
        fused = self.bn1(fused)
        return self.relu(fused)
if __name__ == '__main__':
    # Smoke test: PPA keeps the spatial size and maps 64 -> 64 channels.
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(4, 64, 128, 128, device=device)
    model = PPA(64, 64).to(device)
    out = model(x)
    print(out.shape)