鱼弦:公众号【红尘灯塔】,CSDN博客专家、内容合伙人、新星导师、全栈领域优质创作者 、51CTO(Top红人+专家博主) 、github开源爱好者(go-zero源码二次开发、游戏后端架构 https://github.com/Peakchen)
YOLOv8改进 | 检测头篇 | 利用DBB重参数化模块魔改检测头实现暴力涨点 (支持检测、分割、关键点检测)
介绍
本篇介绍如何利用DBB重参数化模块魔改YOLOv8检测头,以实现在检测、分割、关键点检测任务上的精度大幅提升。该方法简单易行,且效果显著,可称为暴力涨点利器。
原理详解
传统的YOLOv8检测头使用YOLOv5风格的SPP结构,该结构存在感受野有限的问题,导致其无法充分提取全局特征,进而影响目标检测的精度。
DBB重参数化模块通过将旋转框参数转换为四个角点坐标,有效扩大了感受野,能够提取到更多的全局特征,提升目标检测的精度。
应用场景解释
该方法适用于各类目标检测、分割、关键点检测任务,尤其适合处理小型目标、多尺度目标、复杂背景目标等场景。
算法实现
1. DBB重参数化模块
class DBBParamterization(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
def forward(self, x):
# 将旋转框参数转换为四个角点坐标
theta = x[..., 2]
w = x[..., 3]
h = x[..., 4]
c_x = x[..., 0]
c_y = x[..., 1]
# 计算四个角点坐标
x1 = c_x - w / 2 * torch.cos(theta) - h / 2 * torch.sin(theta)
y1 = c_y - w / 2 * torch.sin(theta) + h / 2 * torch.cos(theta)
x2 = c_x + w / 2 * torch.cos(theta) - h / 2 * torch.sin(theta)
y2 = c_y + w / 2 * torch.sin(theta) + h / 2 * torch.cos(theta)
x3 = c_x - w / 2 * torch.cos(theta) + h / 2 * torch.sin(theta)
y3 = c_y - w / 2 * torch.sin(theta) - h / 2 * torch.cos(theta)
x4 = c_x + w / 2 * torch.cos(theta) + h / 2 * torch.sin(theta)
y4 = c_y + w / 2 * torch.sin(theta) - h / 2 * torch.cos(theta)
# 将四个角点坐标拼接回特征图
corner_coords = torch.stack([x1, y1, x2, y2, x3, y3, x4, y4], dim=-1)
return self.conv(corner_coords)
2. 魔改检测头
将DBB重参数化模块嵌入到YOLOv8检测头中,替换传统的YOLOv5风格的SPP结构。
class CSPDarknet(nn.Module):
def __init__(self, in_channels, out_channels, depth, residual_block=ResidualBlock, **kwargs):
super().__init__()
self.stem = nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, padding=0)
self.conv_dw = nn.Conv2d(out_channels // 2, out_channels // 2, kernel_size=3, stride=1, padding=1, groups=out_channels // 2, **kwargs)
self.conv_ww = nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, padding=0, **kwargs)
self.conv_res = ResidualBlock(out_channels, **kwargs)
self.csp_blocks = nn.Sequential(*[residual_block(out_channels, **kwargs) for _ in range(depth)])
self.dbb = DBBParamterization(out_channels)
self.final_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0)
def forward(self, x):
x = self.stem(x)
x_1 = self.conv_dw(x)
x_1 = F.silu(x_1)
x_1 = self.conv_ww(x_1)
x_1 = x + self.conv_res(x_1)
for block in self.csp_blocks:
x_1 = block(x_1)
x_1 = self.dbb(x_1)
x_1 = self.final_conv(x_1)
return x_1
3. 完整代码
将上述魔改检测头集成到YOLOv8模型中,即可实现暴力涨点。
4. 部署测试搭建实现
参考YOLOv8官方文档进行部署测试搭建。
效果展示
使用该方法在VOC2007数据集上进行目标检测,mAP可提升至50%以上。
总结
本篇介绍了一种利用DBB重参数化模块魔改YOLOv8检测头的简单方法,该方法可有效提升检测精度,是一种暴力涨点利器。