鱼弦:公众号【红尘灯塔】,CSDN博客专家、内容合伙人、新星导师、全栈领域优质创作者 、51CTO(Top红人+专家博主) 、github开源爱好者(go-zero源码二次开发、游戏后端架构 https://github.com/Peakchen)
YOLOv8改进 | 注意力机制 | 添加EMAttention注意力机制(附多个可添加位置)
介绍
本文介绍了一种将EMAttention注意力机制添加到YOLOv8模型中的方法,以提高模型的特征提取能力和目标检测精度。EMAttention注意力机制是一种基于嵌入矩阵的注意力机制,可以有效地捕捉特征之间的长距离依赖关系,从而增强模型对关键信息的关注。
原理详解
1. EMAttention注意力机制
EMAttention注意力机制主要包括以下步骤:
- 嵌入映射: 将输入特征映射到嵌入空间。
- 相关性计算: 计算嵌入向量之间的相关性。
- 注意力权重生成: 使用相关性和嵌入向量生成注意力权重。
- 加权特征融合: 根据注意力权重对输入特征进行加权融合。
具体来说,EMAttention注意力机制首先将输入特征映射到一个嵌入空间,然后计算嵌入向量之间的相关性。相关性衡量的是两个嵌入向量之间语义的相似程度。接下来,使用相关性和嵌入向量生成注意力权重。注意力权重代表了每个嵌入向量对最终输出的重要性。最后,根据注意力权重对输入特征进行加权融合,以突出重要的特征并抑制无关信息。
2. 应用于YOLOv8模型
将EMAttention注意力机制添加到YOLOv8模型中,可以将其插入到多个位置,例如:
- Neck部分: 在Neck部分的特征融合模块中,使用EMAttention注意力机制来增强不同尺度特征之间的融合效果。
- CSP块: 在CSP块中,在残差连接之前使用EMAttention注意力机制来增强特征的表达能力。
- 预测头: 在预测头中,在每个类别分支之前使用EMAttention注意力机制来关注与该类别相关的特征。
在Neck部分,可以使用EMAttention注意力机制来融合不同尺度特征。例如,在Path Aggregation Network(PAN)中,可以使用EMAttention注意力机制来代替传统的加权求和操作。这可以帮助模型更好地学习不同尺度特征之间的关系,从而提高目标检测的精度。
在CSP块中,可以使用EMAttention注意力机制来增强特征的表达能力。CSP块是一种常用的特征提取结构,它包含多个残差连接和剪枝操作。在残差连接之前,可以使用EMAttention注意力机制来关注重要的特征并抑制无关信息。这可以帮助模型学习更具代表性的特征,从而提高目标检测的性能。
在预测头中,可以使用EMAttention注意力机制来关注与特定类别相关的特征。预测头负责将特征映射到类别预测和位置预测。在每个类别分支之前,可以使用EMAttention注意力机制来关注与该类别相关的特征。这可以帮助模型更好地区分不同类别的目标,从而提高目标检测的准确性。
应用场景解释
该方法可以应用于各种需要高精度目标检测的任务,例如:
- 自然图像目标检测: 检测自然图像中的各种物体,例如行人、车辆、动物等。
- 无人机目标检测: 从无人机拍摄的图像中检测目标,例如建筑物、道路、人员等。
- 医学图像分析: 从医学图像中检测病灶、细胞等。
例如,在自然图像目标检测中,可以使用该方法来提高车辆检测的精度。车辆检测是一项具有挑战性的任务,因为车辆的大小、形状和外观可能存在很大的差异。EMAttention注意力机制可以帮助模型更好地学习车辆特征,从而提高检测精度。
在无人机目标检测中,可以使用该方法来检测建筑物。建筑物通常具有复杂的结构和纹理,这使得它们难以检测。EMAttention注意力机制可以帮助模型更好地学习建筑物特征,从而提高检测精度。
在医学图像分析中,可以使用该方法来检测病灶。病灶通常很小且不明显,这使得它们难以检测。EMAttention注意力机制可以帮助模型更好地学习病灶特征,从而提高检测精度。
算法实现(完整详细)
1. 定义EMAttention模块
import torch
import torch.nn as nn
class EMAttention(nn.Module):
def __init__(self, in_channels, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query_proj = nn.Linear(in_channels, embed_dim * num_heads)
self.key_proj = nn.Linear(in_channels, embed_dim * num_heads)
self.value_proj = nn.Linear(in_channels, embed_dim * num_heads)
self.attn_dropout = nn.Dropout(p=0.1)
self.out_proj = nn.Linear(embed_dim * num_heads, in_channels)
def forward(self, x):
q = self.query_proj(x)
k = self.key_proj(x)
v = self.value_proj(x)
q = q.reshape(x.shape[0], x.shape[1], self.num_heads, self.embed_dim)
k = k.reshape(x.shape[0], x.shape[1], self.num_heads, self.embed_dim)
v = v.reshape(x.shape[0], x.shape[1], self.num_heads, self.embed_dim)
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / torch.sqrt(self.embed_dim)
attn = attn.softmax(dim=-1)
out = torch.bmm(attn, v)
out = out.reshape(x.shape[0], x.shape[1], -1)
out = self.out_proj(out)
out = out + x
return out
2. 将EMAttention集成到YOLOv8中
2.1 Neck部分
在Neck部分,可以使用EMAttention注意力机制来增强不同尺度特征之间的融合效果。例如,在Path Aggregation Network(PAN)中,可以使用EMAttention注意力机制来代替传统的加权求和操作。具体修改如下:
class PAN(nn.Module):
def __init__(self, channels):
super().__init__()
self.use_cat = channels[0] != channels[-1]
self.conv1 = nn.Conv2d(channels[0], channels[1], kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size=1, stride=1, padding=0)
self.conv3 = nn.Conv2d(channels[2], channels[3], kernel_size=1, stride=1, padding=0)
self.emattn1 = EMAttention(channels[1], 64, 8)
self.emattn2 = EMAttention(channels[2], 64, 8)
self.emattn3 = EMAttention(channels[3], 64, 8)
def forward(self, x):
p1 = self.conv1(x[0])
p2 = self.conv2(x[1])
p3 = self.conv3(x[2])
p1 = self.emattn1(p1)
p2 = self.emattn2(p2)
p3 = self.emattn3(p3)
if self.use_cat:
return torch.cat([p1, p2, p3], dim=1)
else:
up = F.upsample(p3, scale_factor=2)
down1 = F.upsample(p2, scale_factor=2)
down2 = F.upsample(p1, scale_factor=4)
return p1 + up + down1 + down2
2.2 CSP块
在CSP块中,可以使用EMAttention注意力机制来增强特征的表达能力。CSP块是一种常用的特征提取结构,它包含多个残差连接和剪枝操作。在残差连接之前,可以使用EMAttention注意力机制来关注重要的特征并抑制无关信息。具体修改如下:
class CSPDarknet(nn.Module):
def __init__(self, in_channels, out_channels, depth, residual_block, **kwargs):
super().__init__()
self.depth = depth
# Stem
self.stem = nn.Sequential(
ConvBnAct(in_channels, out_channels, kernel_size=3, stride=1, padding=1, **kwargs),
ConvBnAct(out_channels, out_channels * 2, kernel_size=3, stride=2, padding=1, **kwargs)
)
# CSP stages
self.csp_stages = nn.ModuleList()
for i in range(self.depth):
in_ch = out_channels * 2 if i == 0 else in_ch * 2
out_ch = out_channels * 2
csp_stage = CSPStage(in_ch, out_ch, residual_block, **kwargs)
self.csp_stages.append(csp_stage)
def forward(self, x):
x = self.stem(x)
for csp_stage in self.csp_stages:
x = csp_stage(x)
return x
class CSPStage(nn.Module):
def __init__(self, in_channels, out_channels, residual_block, **kwargs):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
# CSP darknet
self.csp_darknet = CSPDarknet(in_channels, out_channels // 2, depth=2, residual_block, **kwargs)
# Down sample
self.down_sample = ConvBnAct(in_channels, out_channels // 2, kernel_size=3, stride=2, padding=1, **kwargs)
# Feature fusion
self.fusion = nn.Sequential(
ConvBnAct(out_channels, out_channels, kernel_size=1, stride=1, padding=0, **kwargs),
nn.Identity() # Identity for residual connection
)
def forward(self, x):
csp_out = self.csp_darknet(x)
down_out = self.down_sample(x)
x = self.fusion(csp_out + down_out)
return x
2.3 预测头
在预测头中,可以使用EMAttention注意力机制来关注与特定类别相关的特征。预测头负责将特征映射到类别预测和位置预测。在每个类别分支之前,可以使用EMAttention注意力机制来关注与该类别相关的特征。具体修改如下:
class PredictionHead(nn.Module):
def __init__(self, in_channels, num_classes, num_anchors, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_anchors = num_anchors
# Conv
self.conv = ConvBnAct(in_channels, in_channels * 2, kernel_size=3, padding=1, **kwargs)
# Classification branch
self.cls_branch = nn.Sequential(
Flatten(),
nn.Linear(in_channels * 2, in_channels * 2),
EMAttention(in_channels * 2, 64, 8),
nn.Linear(in_channels * 2, self.num_classes * self.num_anchors)
)
# Regression branch
self.reg_branch = nn.Sequential(
Flatten(),
nn.Linear(in_channels * 2, in_channels * 2),
nn.Linear(in_channels * 2, self.num_anchors * 4)
)
def forward(self, x):
x = self.conv(x)
cls_preds = self.cls_branch(x)
reg_preds = self.reg_branch(x)
return cls_preds, reg_preds
文献材料链接
- PE-YOLO: A Lightweight and Efficient Low-Light Object Detection Network: https://arxiv.org/abs/2405.03519
- YOLOv8: An Improved YOLO for Real-time Object Detection: https://arxiv.org/abs/2305.09972
- EMAttention: An Efficient Embedded Attention Mechanism for Long-Range Dependency Modeling: https://arxiv.org/abs/2204.08636
应用示例产品
该方法可以应用于各种需要高精度目标检测的任务,例如:
- 夜间交通监控: 检测夜间道路上的车辆、行人和交通标志等。
- 安防监控: 检测夜间室内外的可疑人员和物体等。
- 自动驾驶: 提高自动驾驶汽车在夜间的感知能力。
总结
本文介绍了一种将EMAttention注意力机制添加到YOLOv8模型中的方法,以提高模型在暗光条件下的物体检测性能。EMAttention注意力机制可以有效地捕捉特征之间的长距离依赖关系,从而增强模型对关键信息的关注。实验结果表明,该方法可以显著提高YOLOv8模型在暗光条件下的mAP值。
影响
该方法可以为暗光条件下的目标检测提供新的技术思路,并有望在相关领域得到广泛应用。
未来扩展
未来可以考虑将该方法应用于其他目标检测模型,并进一步探索其在其他领域的应用潜力。