前两节我学习了yolov8的backbone和head操作。这一节就到了head部分。
我们知道yolov8在流行的yolov5的架构上进行了扩展。在多个方面提供了改进。尤其是head部分,变化最大。yolov8模型与其前身的主要区别在于使用了无锚点检测(即从原先的耦合头变成了解耦头,并且从YOLOv5的Anchor-Based变成了Anchor-Free),这加速了非极大值抑制的后处理操作。这里废话不多说,继续先看一下其yaml配置文件.
1,yolov8的yaml配置文件
首先,我们仍然展示一下yolov8-pose.yaml文件。看看其网络的构造:
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
# Parameters
nc: 1 # number of classes
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 768]
l: [1.00, 1.00, 512]
x: [1.00, 1.25, 512]
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 12
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 18 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)
- [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5)
而对于head部分,实际上就是最后一行。
- [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5) :本层是Pose层,[15, 18, 21]代表将第15、18、21层的输出(分别是80*80*256、40*40*512、20*20*1024)作为本层的输入。nc是数据集的类别数。1表示有一个检测头(Pose模块,也可以叫Detect模块)。Pose表示这是一个姿态识别层,用于预测类别,边界框和关键点。
这个15,18,21,我们展示一下(使用mmyolo画的网络结构图):其输出确实和下面一样:15层,即P3输出的是80*80*256*w, 18层,即P4输出的是40*40*512*w,21层,即P5输出的是20*20*1024*w。
并且P3,P4,P5检测块对应的检测目的也不一样。其中:
- 第一个检测块P3:专门用于检测小型物体,来自于15号节点的C2f块的输出,80*80大小的特征图
- 第二个检测块P4:专门用于检测中型物体,来自于18号节点的C2f块的输出,40*40大小的特征图
- 第三个检测块P5:专门用于检测大型物体,来自于21号节点的C2f块的输出,20*20大小的特征图
对应到我们的yaml文件也可以看到相同的结果:
2,yolov8的 head架构图
在YOLOv8中,Head部分负责将Neck部分输出的特征进行进一步处理,以生成最终的目标检测结果。Head部分的主要功能是将特征图转换为目标检测,分类和关键点检测任务所需要的具体信息,包含一个类别,位置和置信度。即一个检测头和一个分类头。
- 检测头:包含一系列卷积层和反卷积层,用于生成检测结果。这些层负责预测每个锚框的边界框回归值和目标存在的置信度。
- 分类头:采用全局平均池化(Global Average Pooling)对每个特征图进行分类,通过减少特征图的维度,输出每个类别的概率分布。分类头的设计使得YOLOv8能够有效的处理多类别分类任务。
2.1 检测块——detect block
检测块负责检测物体。与之前版本的YOLO不同,YOLOv8是一个无锚点模型,这意味着它直接预测物体的中心,而不是从已知的锚点框的偏移量进行预测。无锚点检测减少了框预测的数量,加快了推理后筛选候选检测结果的复杂后处理步骤。其结构图如下:
检测块包含两个轨道。第一轨道是用于边界框预测,第二轨道是用于类别预测。这两个轨道都包含两个卷积块,随后是一个单独的Conv2d层,分别给出边界框损失和类别损失。
而这节课的重点就是YOLOv8的Head层,对于其损失loss,我们下一节课分析。
3,yolo-pose 代码解析
我们打开ultralytics的代码,直接找到Head部分的Pose的代码:
然后点击进去Pose,则代码如下:
class Pose(Detect):
"""YOLOv8 Pose head for keypoints models."""
def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
"""Initialize YOLO network with default parameters and Convolutional Layers."""
super().__init__(nc, ch)
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
c4 = max(ch[0] // 4, self.nk)
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
def forward(self, x):
"""Perform forward pass through YOLO model and return predictions."""
bs = x[0].shape[0] # batch size
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
x = Detect.forward(self, x)
if self.training:
return x, kpt
pred_kpt = self.kpts_decode(bs, kpt)
return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
def kpts_decode(self, bs, kpts):
"""Decodes keypoints."""
ndim = self.kpt_shape[1]
if self.export: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
y = kpts.view(bs, *self.kpt_shape, -1)
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
if ndim == 3:
a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
return a.view(bs, self.nk, -1)
else:
y = kpts.clone()
if ndim == 3:
y[:, 2::3] = y[:, 2::3].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
return y
很明显上面代码是用于yolov8模型中关键点检测头,即head的关键. 我们可以看到其Pose类继承了Detect类,用于处理关键点检测任务(keypoint detection). 负责生成关键点的预测。
他的initial函数,接收参数nc(类别数量), kpt_shape(关键点的形状,例如(17, 3)表示17个关键点,每个关键点有三个维度:x, y, visible) , ch(通道数列表)
forward函数: 首先计算关键点预测kpt,然后调用父类detect的前向传播方法来获得基本的检测结果,然后拼接检测结果和关键点预测.
kpts_decode函数: 解码关键点预测,根据关键点的维度(2D或3D)进行不同的解码操作。
而Pose的forward的关键代码也就是下面这一行:
然后就是Detect函数,我们也放上Detect的函数:
class Detect(nn.Module):
"""YOLOv8 Detect head for detection models."""
dynamic = False # force grid reconstruction
export = False # export mode
end2end = False # end2end
max_det = 300 # max_det
shape = None
anchors = torch.empty(0) # init
strides = torch.empty(0) # init
def __init__(self, nc=80, ch=()):
"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""
super().__init__()
self.nc = nc # number of classes
self.nl = len(ch) # number of detection layers
self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
self.no = nc + self.reg_max * 4 # number of outputs per anchor
self.stride = torch.zeros(self.nl) # strides computed during build
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
self.cv2 = nn.ModuleList(
nn.Sequential(Conv(x, c2,