import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import DeformConv2d
class EnhancedSRU(nn.Module):
    """Combined channel and spatial attention with a learnable residual gate.

    A channel branch (global average pool followed by a two-layer 1x1-conv
    bottleneck) and a spatial branch (single 3x3 conv to one channel) each
    produce sigmoid weights. Their product re-weights the input, and a
    learnable scalar controls how strongly that re-weighting is applied on
    top of the unchanged input.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Divisor used to derive the bottleneck width of the channel
            branch. The width is clamped to at least one channel.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()

        # ``channels // reduction`` is 0 whenever ``channels < reduction``;
        # clamp so the bottleneck never collapses to zero channels.
        hidden_channels = max(channels // reduction, 1)

        # Channel attention: one weight per channel.
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden_channels, 1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, channels, 1),
            nn.Sigmoid(),
        )

        # Spatial attention: one weight per spatial position.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

        # Learnable strength of the attention term, initialised to 1.0.
        self.scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """Return ``x * (scale * channel_weight * spatial_weight + 1)``.

        Args:
            x: Feature map with ``channels`` channels, in a layout accepted
                by ``nn.Conv2d``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        channel_weight = self.channel_attn(x)
        spatial_weight = self.spatial_attn(x)

        # The channel weights have 1x1 spatial extent and the spatial weights
        # have a single channel, so the product broadcasts to the shape of x.
        attn_weight = channel_weight * spatial_weight

        # The "+ 1" keeps the original features: with scale == 0 the module
        # is an identity mapping.
        return x * (self.scale * attn_weight + 1)
class MultiScaleDeformableConv(nn.Module):
    """Parallel 3x3 deformable convolutions, one per dilation rate.

    A small offset network predicts sampling offsets for every branch from
    the input. Each branch is an mmcv ``DeformConv2d`` with its own dilation
    rate; the branch outputs are concatenated along the channel dimension.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by EACH branch. The module output has
            ``len(dilations) * out_channels`` channels.
        deform_groups: Number of deformable groups passed to every
            ``DeformConv2d``.
        dilations: Dilation rate of each branch. Defaults to ``(1, 2, 3)``.

    Raises:
        ValueError: If ``dilations`` is empty.
    """

    def __init__(self, in_channels, out_channels, deform_groups=4,
                 dilations=(1, 2, 3)):
        super().__init__()

        dilations = tuple(dilations)
        if not dilations:
            raise ValueError("dilations must contain at least one rate")
        self.num_scales = len(dilations)

        # mmcv's DeformConv2d expects ``deform_groups * 2 * kH * kW`` offset
        # channels. The kernel is fixed to 3x3 here, so every branch needs
        # ``deform_groups * 18`` channels (not a flat 18).
        offsets_per_scale = 2 * 3 * 3 * deform_groups
        total_offset_channels = offsets_per_scale * self.num_scales

        mid_channels = max(in_channels // 2, 1)
        # GroupNorm requires the group count to divide the channel count.
        # Prefer 4 groups and fall back to 2 or 1 when 4 does not divide.
        norm_groups = next(g for g in (4, 2, 1) if mid_channels % g == 0)

        # Offset prediction network shared by all branches.
        self.offset_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.GroupNorm(norm_groups, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, total_offset_channels, kernel_size=1),
        )
        self.offset_bn = nn.BatchNorm2d(total_offset_channels)

        # One deformable conv per dilation rate. ``padding == dilation`` with
        # a 3x3 kernel and stride 1 keeps the spatial size unchanged, so the
        # branch outputs can be concatenated.
        self.deform_convs = nn.ModuleList([
            DeformConv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=rate,
                dilation=rate,
                deform_groups=deform_groups,
            )
            for rate in dilations
        ])

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results.

        Args:
            x: Input feature map with ``in_channels`` channels.

        Returns:
            Branch outputs concatenated along dim 1.
        """
        offsets = self.offset_bn(self.offset_conv(x))
        # Squash the predicted offsets into the range [-1, 1].
        offsets = 2.0 * torch.sigmoid(offsets) - 1.0

        # Equal-sized slices, one per branch, in branch order.
        offset_chunks = torch.chunk(offsets, self.num_scales, dim=1)

        features = [
            conv(x, offset)
            for conv, offset in zip(self.deform_convs, offset_chunks)
        ]
        return torch.cat(features, dim=1)
class OptimizedDSC(nn.Module):
    """Deformable multi-scale convolution block with attention and residual.

    Pipeline: multi-scale deformable conv -> 1x1 fusion -> sigmoid edge gate
    -> ``EnhancedSRU`` attention -> depthwise-separable conv -> optional
    residual addition.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map. Must be at least 3
            because the width is split across three deformable branches.
        kernel_size: Currently unused; kept so existing callers that pass it
            keep working.
        use_residual: When true, the (projected) input is added to the output.

    Raises:
        ValueError: If ``out_channels`` is smaller than 3.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, use_residual=True):
        super().__init__()
        self.use_residual = use_residual

        # Width of each of the three deformable branches.
        branch_channels = out_channels // 3
        if branch_channels < 1:
            raise ValueError("out_channels must be >= 3, got %r" % (out_channels,))

        # Multi-scale deformable convolution bank.
        self.multi_scale_deform = MultiScaleDeformableConv(in_channels, branch_channels)

        # The bank concatenates its branches, so its real output width is
        # ``branches * branch_channels``. That differs from ``out_channels``
        # whenever ``out_channels`` is not a multiple of 3 (e.g. 1024 -> 1023),
        # so the fusion conv must take the real width as its input.
        fused_in_channels = branch_channels * len(self.multi_scale_deform.deform_convs)

        # Feature fusion back to ``out_channels``.
        self.fusion = nn.Sequential(
            nn.Conv2d(fused_in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        # Edge enhancement: a sigmoid gate computed from the fused features.
        self.edge_enhance = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Sigmoid(),
        )

        # Channel/spatial attention.
        self.attn = EnhancedSRU(out_channels, reduction=8)

        # Residual path: project only when the channel counts differ.
        if use_residual and in_channels != out_channels:
            self.res_conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.res_conv = nn.Identity()

        # Depthwise-separable refinement: depthwise 3x3 then pointwise 1x1.
        self.depthwise = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3,
                      padding=1, groups=out_channels),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Run the block on ``x``.

        Args:
            x: Input feature map with ``in_channels`` channels.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        # Computed before ``x`` is overwritten so the skip path sees the input.
        identity = self.res_conv(x)

        x = self.multi_scale_deform(x)
        x = self.fusion(x)

        # Gate in [0, 1] added on top of the features: x * (1 + edge).
        edge = self.edge_enhance(x)
        x = x * edge + x

        x = self.attn(x)
        x = self.depthwise(x)

        if self.use_residual:
            x = x + identity
        return x


# Text that followed the class on the same line in the original paste:
# Question (translated): "How can I add these modules to YOLOv11, and how
# should my YAML file be changed?"
# The YAML model config that follows begins with the key:
# backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 2, C3k2, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 2, nn.Module, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 2, C2PSA, [1024]] # 10
# YOLO11n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, C3k2, [512, False]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)
最新发布