YOLOV8的Detect head 逐行解读

YOLOV8从不同的特征层,得到不同大小的特征图,然后预测每个特征图的每个格子anchor的类别概率,以及每个格子中物体的边框,即相对于中心点上下左右的偏移量box。

shape为[(1, 144, 80, 80),(1, 144, 40, 40),(1,144,20,20)]。

 输入x为从不同的上采样层得到的结果

x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
 #(1,64,8400),(1,80,8400)

整合这些结果,得到的shape为 (1,144,8400)。其中:
       8400 = 80 * 80+40 * 40+20 * 20,总的预测数
       144 为80个class和4*16个box
       4 为预测的四个边框距离中心点的距离,是Anchor-Free的预测目标,格式为[left,top,right,bottom]。
        self.reg_max = 16,是中心点的最大预测范围,即边框距离中心点的最远距离为16,但并不是16个像素,因为预测值都进行了不同stride的缩放。这个参数也决定了检测物体最大边框为 reg_max * stride*2。

self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
#(2,8400),(1,8400)
self.shape = shape #(1, 144, 80, 80)

def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)


self.anchors[:,:10]
tensor([[0.5000, 1.5000, 2.5000, 3.5000, 4.5000, 5.5000, 6.5000, 7.5000, 8.5000, 9.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000]], device='cuda:0')

self.strides[:,:10]
tensor([[8., 8., 8., 8., 8., 8., 8., 8., 8., 8.]], device='cuda:0')

 make_anchors,主要生成预测的网格点,

其中x 的shape [(1, 144, 80, 80),(1, 144, 40, 40),(1,144,20,20)]

self.stride 的值为:tensor([8., 16., 32.])

对应 80 * 80的特征图,生成   80 * 80的anchor和 80 * 80 的stride,anchor就是每个 1*1 网格的中心点,stride是缩放系数,大的特征图缩放系数小,用来预测小物体。

dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides  
#(1,4,8400),(1,2,8400) => (1,4,8400)

class DFL(nn.Module):
    def __init__(self, c1=16):
        """Initialize a convolutional layer with a given number of input channels."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
    #(1,64,8400) => (1,4,16,8400) => (1,16,4,8400) => (1,1,4,8400) => (1,4,8400)
        """Applies a transformer layer on input tensor 'x' and returns a tensor."""
        b, _, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)

def decode_bboxes(self, bboxes, anchors):
    """Decode bounding boxes."""
    if self.export:
        return dist2bbox(bboxes, anchors, xywh=False, dim=1)
    return dist2bbox(bboxes, anchors, xywh=True, dim=1)

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox

self.dfl(box):计算box偏移量

x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1):先把box从(1,64,8400) => (1,4,16,8400) => (1,16,4,8400),然后对dim=1进行softmax计算,给16个距离对应的权重。

 self.conv的参数requires_grad_(False),等于x = torch.arange(c1, dtype=torch.float),固定为 tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]),再进行 nn.Conv2d(c1, 1, 1, bias=False)计算,就相当于用softmax后的权重乘以对应的数值,得到最终的偏移量。

decode_bboxes 使用 dist2bbox函数,box的格式为[left,top,right,bottom],将box分为两部分,用中心点减去left,top,得到左上角x1y1,用中心点加上 right,bottom,得到右下角的点x2y2,这样就得到了xyxy格式(也可以转换为xywh格式)的坐标点,再乘以对应的stride,得到最终的坐标点。(1,4,8400)

y = torch.cat((dbox, cls.sigmoid()), 1)  #(1,84,8400)

将预测的坐标点和类别合并,得到最终返回结果。


 完整代码:

class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def inference(self, x):#[(1, 144, 80, 80),(1, 144, 40, 40),(1,144,20,20)]
        # Inference path
        shape = x[0].shape  # BCHW  (1, 144, 80, 80)
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)  #(1,144,8400)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) #(2,8400),(1,8400)
            self.shape = shape #(1, 144, 80, 80)

        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) #(1,64,8400),(1,80,8400)

        if self.export and self.format in ("tflite", "edgetpu"):
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides  #(1,4,8400),(1,2,8400) => (1,4,8400)

        y = torch.cat((dbox, cls.sigmoid()), 1)  #(1,84,8400)
        return y if self.export else (y, x)

    def forward_feat(self, x, cv2, cv3):
        y = []
        for i in range(self.nl):
            y.append(torch.cat((cv2[i](x[i]), cv3[i](x[i])), 1))
        return y

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        y = self.forward_feat(x, self.cv2, self.cv3)
        
        if self.training:
            return y

        return self.inference(y)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        if self.export:
            return dist2bbox(bboxes, anchors, xywh=False, dim=1)
        return dist2bbox(bboxes, anchors, xywh=True, dim=1)

参考:

YOLOv8详解:损失函数、Anchor-Free、样本分配策略;以及与v5的对比_yolov8的损失函数为什么大于1-CSDN博客

  • 51
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值