关键点检测(7)——yolov8-head的搭建

  前两节我学习了yolov8的backbone和head操作。这一节就到了head部分。 

  我们知道yolov8在流行的yolov5的架构上进行了扩展。在多个方面提供了改进。尤其是head部分,变化最大。yolov8模型与其前身的主要区别在于使用了无锚点检测(即从原先的耦合头变成了解耦头,并且从YOLOv5的Anchor-Based变成了Anchor-Free),这加速了非极大值抑制的后处理操作。这里废话不多说,继续先看一下其yaml配置文件.

1,yolov8的yaml配置文件

  首先,我们仍然展示一下yolov8-pose.yaml文件。看看其网络的构造:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
  
# Parameters
nc: 1 # number of classes
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]
  s: [0.33, 0.50, 1024]
  m: [0.67, 0.75, 768]
  l: [1.00, 1.00, 512]
  x: [1.00, 1.25, 512]
  
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 12
  
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
  
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
  
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
  
  - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5)

   而对于head部分,实际上就是最后一行。

  - [[15, 18, 21], 1, Pose, [nc, kpt_shape]]  # Pose(P3, P4, P5) :本层是Pose层,[15, 18, 21]代表将第15、18、21层的输出(分别是80*80*256、40*40*512、20*20*1024)作为本层的输入。nc是数据集的类别数。1表示有一个检测头(Pose模块,也可以叫Detect模块)。Pose表示这是一个姿态识别层,用于预测类别,边界框和关键点。

  这个15,18,21,我们展示一下(使用mmyolo画的网络结构图):其输出确实和下面一样:15层,即P3输出的是80*80*256*w, 18层,即P4输出的是40*40*512*w,21层,即P5输出的是20*20*1024*w。

  并且P3,P4,P5检测块对应的检测目的也不一样。其中:

  • 第一个检测块P3:专门用于检测小型物体,来自于15号节点的C2f块的输出,80*80大小的特征图
  • 第二个检测块P4:专门用于检测中型物体,来自于18号节点的C2f块的输出,40*40大小的特征图
  • 第三个检测块P5:专门用于检测大型物体,来自于21号节点的C2f块的输出,20*20大小的特征图

  对应到我们的yaml文件也可以看到相同的结果:

2,yolov8的 head架构图

  在YOLOv8中,Head部分负责将Neck部分输出的特征进行进一步处理,以生成最终的目标检测结果。Head部分的主要功能是将特征图转换为目标检测,分类和关键点检测任务所需要的具体信息,包含一个类别,位置和置信度。即一个检测头和一个分类头。

  • 检测头:包含一系列卷积层和反卷积层,用于生成检测结果。这些层负责预测每个锚框的边界框回归值和目标存在的置信度。
  • 分类头:采用全局平均池化(Global Average Pooling)对每个特征图进行分类,通过减少特征图的维度,输出每个类别的概率分布。分类头的设计使得YOLOv8能够有效的处理多类别分类任务。

2.1 检测块——detect block

  检测块负责检测物体。与之前版本的YOLO不同,YOLOv8是一个无锚点模型,这意味着它直接预测物体的中心,而不是从已知的锚点框的偏移量进行预测。无锚点检测减少了框预测的数量,加快了推理后筛选候选检测结果的复杂后处理步骤。其结构图如下:

  检测块包含两个轨道。第一轨道是用于边界框预测,第二轨道是用于类别预测。这两个轨道都包含两个卷积块,随后是一个单独的Conv2d层,分别给出边界框损失和类别损失。

  而这节课的重点就是YOLOv8的Head层,对于其损失loss,我们下一节课分析。

3,yolo-pose 代码解析

  我们打开ultralytics的代码,直接找到Head部分的Pose的代码:

   然后点击进去Pose,则代码如下:

class Pose(Detect):
    """YOLOv8 Pose head for keypoints models."""

    def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
        """Initialize YOLO network with default parameters and Convolutional Layers."""
        super().__init__(nc, ch)
        self.kpt_shape = kpt_shape  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
        self.nk = kpt_shape[0] * kpt_shape[1]  # number of keypoints total

        c4 = max(ch[0] // 4, self.nk)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)

    def forward(self, x):
        """Perform forward pass through YOLO model and return predictions."""
        bs = x[0].shape[0]  # batch size
        kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
        x = Detect.forward(self, x)
        if self.training:
            return x, kpt
        pred_kpt = self.kpts_decode(bs, kpt)
        return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))

    def kpts_decode(self, bs, kpts):
        """Decodes keypoints."""
        ndim = self.kpt_shape[1]
        if self.export:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
            y = kpts.view(bs, *self.kpt_shape, -1)
            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
            if ndim == 3:
                a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
            return a.view(bs, self.nk, -1)
        else:
            y = kpts.clone()
            if ndim == 3:
                y[:, 2::3] = y[:, 2::3].sigmoid()  # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
            y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
            y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
            return y

   很明显上面代码是用于yolov8模型中关键点检测头,即head的关键. 我们可以看到其Pose类继承了Detect类,用于处理关键点检测任务(keypoint detection). 负责生成关键点的预测。

  他的initial函数,接收参数nc(类别数量), kpt_shape(关键点的形状,例如(17, 3)表示17个关键点,每个关键点有三个维度:x, y, visible) , ch(通道数列表)

  forward函数: 首先计算关键点预测kpt,然后调用父类detect的前向传播方法来获得基本的检测结果,然后拼接检测结果和关键点预测.

  kpts_decode函数: 解码关键点预测,根据关键点的维度(2D或3D)进行不同的解码操作。

  而Pose的forward的关键代码也就是下面这一行:

   然后就是Detect函数,我们也放上Detect的函数:

class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    end2end = False  # end2end
    max_det = 300  # max_det
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

        if self.end2end:
            self.one2one_cv2 = copy.deepcopy(self.cv2)
            self.one2one_cv3 = copy.deepcopy(self.cv3)

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        if self.end2end:
            return self.forward_end2end(x)

        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return x
        y = self._inference(x)
        return y if self.export else (y, x)

    def forward_end2end(self, x):
        """
        Performs forward pass of the v10Detect module.

        Args:
            x (tensor): Input tensor.

        Returns:
            (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
                           If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
        """
        x_detach = [xi.detach() for xi in x]
        one2one = [
            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
        ]
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return {"one2many": x, "one2one": one2one}

        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

    def _inference(self, x):
        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
        if self.end2end:
            for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
                a[-1].bias.data[:] = 1.0  # box
                b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
        """
        Post-processes the predictions obtained from a YOLOv10 model.

        Args:
            preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes).
            max_det (int): The maximum number of detections to keep.
            nc (int, optional): The number of classes. Defaults to 80.

        Returns:
            (torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6),
                including bounding boxes, scores and cls.
        """
        assert 4 + nc == preds.shape[-1]
        boxes, scores = preds.split([4, nc], dim=-1)
        max_scores = scores.amax(dim=-1)
        max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1)
        index = index.unsqueeze(-1)
        boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
        scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))

        # NOTE: simplify but result slightly lower mAP
        # scores, labels = scores.max(dim=-1)
        # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

        scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
        labels = index % nc
        index = index // nc
        boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))

        return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)

   这段代码定义了一个名为Detect的类,它是YOLOv8模型的一个组成部分,用于处理目标检测任务,并实现了YOLO检测头的主要功能,包括预测框的解码,置信度得分的计算以及后处理。

3.1 代码分析

类属性

  • dynamic:布尔标志,指示是否强制重建网格。
  • export:布尔标志,指示是否处于导出模式。
  • end2end:布尔标志,指示是否为端到端模式。
  • max_det:整数,表示最大检测数。
  • shape:检测输入的形状,初始为空。
  • anchors:锚点,初始为空张量。
  • strides:步长,初始为空张量。

初始化方法 (__init__)

  • 接受 nc(类别数量)和 ch(输入通道数列表)作为参数。
  • 设置 nc 和 nl(检测层的数量)。
  • 定义了 reg_max,这是一个用于分布聚焦损失(DFL)的通道数。
  • 定义了 no,即每个锚点的输出数量。
  • 创建了两个 nn.ModuleListcv2 和 cv3,分别用于生成边界框的位置信息和类别概率信息
  • 定义了 dfl,即分布聚焦损失模块。
  • 如果 end2end 为真,则复制 cv2 和 cv3 以用于端到端训练。

前向传播方法 (forward)

  • 接受输入 x 并执行前向传播。
  • 如果 end2end 为真,则调用 forward_end2end 方法。
  • 否则,将 x 传递给 cv2 和 cv3 模块以生成边界框和类别概率。
  • 在训练模式下直接返回这些输出。
  • 如果不是训练模式,则调用 _inference 方法解码边界框,并返回解码后的输出。

端到端前向传播方法 (forward_end2end)

  • 接受输入 x 并执行前向传播。
  • 分离输入以避免梯度回传。
  • 生成 one2one 输出。
  • 如果在训练模式下,返回 one2many 和 one2one 的字典。
  • 否则,解码 one2one 输出并进行后处理,然后返回。

推理方法 (_inference)

  • 将多尺度特征图的输出连接在一起。
  • 如果 dynamic 或者输入形状改变,则重新计算锚点和步长。
  • 分割输出为边界框和类别概率。
  • 解码边界框,并对类别概率应用 Sigmoid 函数。
  • 返回解码后的边界框和类别概率。

偏置初始化方法 (bias_init)

  • 初始化 Detect 模块的偏置项。
  • 为 cv2 和 cv3 的最后一层设置初始偏置。

边界框解码方法 (decode_bboxes)

  • 解码边界框。
  • 使用 dist2bbox 函数进行解码。

后处理方法 (postprocess)

  • 处理模型的预测结果。
  • 分离边界框和类别概率。
  • 计算每个框的最大得分及其对应的索引。
  • 选择最高得分的边界框并返回。

4,手写yolov8的head代码

4.1 根据模型梳理head结构

  正如之前截图所展示的,head的部分其实比较简单,我再截图如下:

   但是这里涉及到了Bbox的Loss和Cls的Loss了。所以我们先将代码贴出来,然后等到loss函数再详细解释。

4.2 head的代码组合

  因为纯head的代码比较简单。就是卷积模块。我们上面也分析了,如下:

   就是三个解耦头的得到的卷积结果进行cat。所以也没啥写的,我这里直接拿到我拆解出来的代码:

import torch.nn as nn
import torch
import copy
import math

from yolov8_blocks import *

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") # if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)

class DFL(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """Initialize a convolutional layer with a given number of input channels."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Applies a transformer layer on input tensor 'x' and returns a tensor."""
        b, _, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
        # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)

class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    end2end = False  # end2end
    max_det = 300  # max_det
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

        if self.end2end:
            self.one2one_cv2 = copy.deepcopy(self.cv2)
            self.one2one_cv3 = copy.deepcopy(self.cv3)

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        if self.end2end:
            return self.forward_end2end(x)

        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return x
        y = self._inference(x)
        return y if self.export else (y, x)

    def forward_end2end(self, x):
        """
        Performs forward pass of the v10Detect module.

        Args:
            x (tensor): Input tensor.

        Returns:
            (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
                           If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
        """
        x_detach = [xi.detach() for xi in x]
        one2one = [
            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
        ]
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return {"one2many": x, "one2one": one2one}

        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

    def _inference(self, x):
        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

        return torch.cat((dbox, cls.sigmoid()), 1)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
        if self.end2end:
            for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
                a[-1].bias.data[:] = 1.0  # box
                b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
        """
        Post-processes the predictions obtained from a YOLOv10 model.

        Args:
            preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes).
            max_det (int): The maximum number of detections to keep.
            nc (int, optional): The number of classes. Defaults to 80.

        Returns:
            (torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6),
                including bounding boxes, scores and cls.
        """
        assert 4 + nc == preds.shape[-1]
        boxes, scores = preds.split([4, nc], dim=-1)
        max_scores = scores.amax(dim=-1)
        max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1)
        index = index.unsqueeze(-1)
        boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
        scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))

        # NOTE: simplify but result slightly lower mAP
        # scores, labels = scores.max(dim=-1)
        # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

        scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
        labels = index % nc
        index = index // nc
        boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))

        return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)


class Pose(Detect):
    """YOLOv8 Pose head for keypoints models."""

    def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
        """Initialize YOLO network with default parameters and Convolutional Layers."""
        super().__init__(nc, ch)
        self.kpt_shape = kpt_shape  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
        self.nk = kpt_shape[0] * kpt_shape[1]  # number of keypoints total

        c4 = max(ch[0] // 4, self.nk)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)

    def forward(self, x):
        """Perform forward pass through YOLO model and return predictions."""
        bs = x[0].shape[0]  # batch size
        kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
        x = Detect.forward(self, x)
        if self.training:
            return x, kpt
        pred_kpt = self.kpts_decode(bs, kpt)
        return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))

    def kpts_decode(self, bs, kpts):
        """Decodes keypoints."""
        ndim = self.kpt_shape[1]
        if self.export:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
            y = kpts.view(bs, *self.kpt_shape, -1)
            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
            if ndim == 3:
                a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
            return a.view(bs, self.nk, -1)
        else:
            y = kpts.clone()
            if ndim == 3:
                y[:, 2::3] = y[:, 2::3].sigmoid()  # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
            y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
            y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
            return y

   这里代码不是我写的,我只是把ultrilytics的代码抽取出来而已。

   然后我们打印一下head部分。打印代码如下:

model = Pose(nc=80, kpt_shape=(17, 3), ch=(15, 18, 21))          
print(model)

   打印结果如下:

Pose(
  (cv2): ModuleList(
    (0): Sequential(
      (0): Conv(
        (conv): Conv2d(15, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): Sequential(
      (0): Conv(
        (conv): Conv2d(18, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): Sequential(
      (0): Conv(
        (conv): Conv2d(21, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (cv3): ModuleList(
    (0): Sequential(
      (0): Conv(
        (conv): Conv2d(15, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (1): Conv(
        (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): Sequential(
      (0): Conv(
        (conv): Conv2d(18, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (1): Conv(
        (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): Sequential(
      (0): Conv(
        (conv): Conv2d(21, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (1): Conv(
        (conv): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (2): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (dfl): DFL(
    (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (cv4): ModuleList(
    (0): Sequential(
      (0): Conv(
        (conv): Conv2d(15, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (1): Conv(
        (conv): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (2): Conv2d(51, 51, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): Sequential(
      (0): Conv(
        (conv): Conv2d(18, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (1): Conv(
        (conv): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (2): Conv2d(51, 51, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): Sequential(
      (0): Conv(
        (conv): Conv2d(21, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (1): Conv(
        (conv): Conv2d(51, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (2): Conv2d(51, 51, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

4.3 yolov8的模型组合

  这里我们将backbone,neck,head部分组合起来。如下:

class PoseModel(nn.Module):
    """YOLOv8 pose model."""

    def __init__(self, base_channels, base_depth, deep_mul, nc=80, kpt_shape=(17, 3)):
        super().__init__()
        self.model = nn.Sequential(
            # backbone
            # stem layer 640*640*3 -> 320*320*64
            Conv(c1=3, c2=base_channels, k=3, s=2, p=1),

            # stage layer1(dark2) 320*320*64 -> 160*160*128 -> 160*160*128
            Conv(c1=base_channels, c2=base_channels * 2, k=3, s=2, p=1),
            C2f(base_channels * 2, base_channels * 2, base_depth, True),

            # stage layer2(dark3) 160*160*128 -> 80*80*256 -> 80*80*256
            Conv(c1=base_channels * 2, c2=base_channels * 4, k=3, s=2, p=1),
            C2f(base_channels * 4, base_channels * 4, base_depth * 2, True),

            # stage layer3(dark4) 80*80*256 -> 40*40*512 -> 40*40*512
            Conv(c1=base_channels * 4, c2=base_channels * 8, k=3, s=2, p=1),
            C2f(base_channels * 8, base_channels * 8, base_depth * 2, True),

            # stage layer4(dark5) 40*40*512 -> 20*20*512 -> 20*20*512
            Conv(c1=base_channels * 8, c2=int(base_channels * 16 * deep_mul), k=3, s=2, p=1),
            C2f(int(base_channels * 16 * deep_mul), int(base_channels * 16 * deep_mul), base_depth, True),
            SPPF(int(base_channels * 16 * deep_mul), int(base_channels * 16 * deep_mul), k=5),
        
            # neck  加强特征提取
            # 1024*deep_mul + 512, 40, 40 --> 512, 40, 40
            nn.Upsample(scale_factor=2, mode="nearest"),
            Concat(),  # cat backbone P4
            C2f(int(base_channels * 16*deep_mul) + base_channels * 8, base_channels * 8, base_depth, False),

            # 768, 80, 80 -> 256, 80, 80
            nn.Upsample(scale_factor=2, mode="nearest"),
            Concat(),  # cat backbone P3
            # 15 (P3/8 - small)
            C2f(base_channels * 8 + base_channels * 4, base_channels * 4, base_depth, False),

            # down_sample 256, 80, 80 -> 256, 40, 40
            Conv(c1=base_channels * 4, c2=base_channels * 4, k=3, s=2, p=1),
            Concat(),  # cat head P4
            # 18 (P4/16 - medium)
            # 512 + 256, 40, 40 ==> 512, 40, 40
            C2f(base_channels * 8 + base_channels * 4, base_channels * 8, base_depth, False),

            # down_sample 512, 40, 40 --> 512, 20, 20
            Conv(c1=base_channels * 8, c2=base_channels * 8, k=3, s=2, p=1),
            Concat(),  # cat head P5
            # 21 (P5/32-large)
            # 1024*deep_mul + 512, 20, 20 --> 1024*deep_mul, 20, 20
            C2f(base_channels * 8 + int(base_channels * 16 * deep_mul), int(base_channels * 16 * deep_mul), base_depth, False),

            Pose(nc=nc, kpt_shape=kpt_shape, ch=(64, 128, 256))            
        )   
        self.nl = 3  # number of detection layers
        self.kpt_shape = kpt_shape  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
        self.nk = kpt_shape[0] * kpt_shape[1]  # number of keypoints total


    def forward(self, x):
        #  backbone  基础卷积提取特征
        feat1 = x
        feat2 = x
        feat3 = x
        for layer in self.model[:5]:
            feat1 = layer(feat1)

        for layer in self.model[:7]:
            feat2 = layer(feat2)

        feat3 = self.model[:10](feat3)
        print("feat1, feat2, feat3 shape is ", feat1.shape, feat2.shape, feat3.shape)

        #------------------------加强特征提取网络------------------------# 
        # 1024 * deep_mul, 20, 20 => 1024 * deep_mul, 40, 40
        P5_upsample = self.model[10:11](feat3)
        # 1024 * deep_mul, 40, 40 cat 512, 40, 40 => 1024 * deep_mul + 512, 40, 40
        P4          = self.model[11:12]([P5_upsample, feat2])
        # 1024 * deep_mul + 512, 40, 40 => 512, 40, 40
        P4          = self.model[12:13](P4)

        # 512, 40, 40 => 512, 80, 80
        P4_upsample = self.model[13:14](P4)
        # 512, 80, 80 cat 256, 80, 80 => 768, 80, 80
        P3          = self.model[14:15]([P4_upsample, feat1])
        # 768, 80, 80 => 256, 80, 80
        P3          = self.model[15:16](P3)

        # 256, 80, 80 => 256, 40, 40
        P3_downsample = self.model[16:17](P3)
        # 512, 40, 40 cat 256, 40, 40 => 768, 40, 40
        P4 = self.model[17:18]([P3_downsample, P4])
        # 768, 40, 40 => 512, 40, 40
        P4 = self.model[18:19](P4)

        # 512, 40, 40 => 512, 20, 20
        P4_downsample = self.model[19:20](P4)
        # 512, 20, 20 cat 1024 * deep_mul, 20, 20 => 1024 * deep_mul + 512, 20, 20
        P5 = self.model[20:21]([P4_downsample, feat3])
        # 1024 * deep_mul + 512, 20, 20 => 1024 * deep_mul, 20, 20
        P5 = self.model[21:22](P5)

        # P3 256, 80, 80 => num_classes + self.reg_max * 4, 80, 80
        # P4 512, 40, 40 => num_classes + self.reg_max * 4, 40, 40
        # P5 1024 * deep_mul, 20, 20 => num_classes + self.reg_max * 4, 20, 20
        x = [P3, P4, P5]
        print("P3, P4, P5 shape is ", P3.shape, P4.shape, P5.shape)

        bs = x[0].shape[0]  # batch size
        print(self.model[22:23][0].cv4)
        kpt = torch.cat([self.model[22:23][0].cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)

        for i in range(self.nl):
            x[i] = torch.cat((self.model[22:23][0].cv2[i](x[i]), self.model[22:23][0].cv3[i](x[i])), 1)

        return x, kpt

   注意,这里只是forward模型得到的特征结果。但是使用了yolov8的Pose函数。所以forward后的结果,包含了一些后处理。这个我们后续继续学习。

  • 32
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值