从零开始实现yolox四:模型的训练(一)损失函数与标签分配


我们今天来介绍一下YOLOX模型的训练,我们前面已经完成了模型的推理和输出解码,现在就差损失函数了。一张图片输入模型后,会输出8400个预测框(解码只起到转化的作用,起不到过滤的作用),如何将这8400个预测框与真实目标框整合到一个损失函数里面去?YOLOv1~v5的策略是将8400个网格区域按照有没有目标划分成正负样本,然后与8400个预测框求损失,具体公式可以看代码。YOLOX摒弃这种策略,取而代之的是使用动态标签分配策略,也就是为每个标签分配若干个预测框,通过这些预测框来求损失。我们今天要介绍的就是这个标签分配策略和损失函数。
个人认为,这篇博客是“从零开始实现yolox”系列中最难,因为原作者将标签分配的很多函数都写到一个类里面了,而且像套娃一样层层调用,很难像前面的几篇文章一样,每写完一个模块就写一个测试函数去测试,只能把整个类写完后再统一写测试函数。

1 IOU损失

yolox中损失函数的公式为:

也就是说,yolox的输出中,边框的中心点和高宽使用回归损失(IOU损失),目标置信度和类别使用分类损失。这里我们定义一个名为IOUloss的类来求IOU损失。

在yolox_from_scratch/nets下新建一个名为 yolo_training.py的文件
在这里插入图片描述
在里面加上一个IOUloss函数

import torch
import torch.nn as nn
import torch.nn.functional as F


class IOUloss(nn.Module):
    def __init__(self, reduction="none", loss_type="iou"):
        """

        Args:
            reduction:
            loss_type:损失类型
        """
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        """

        Args:
            pred:预测框,维度为 (num_boxes, 4)
            target:真实框,维度为 (num_boxes, 4)

        Returns:

        """
        assert pred.shape[0] == target.shape[0]

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)

        # 预测框与真实框左上角的交点(重叠部分左上角点),维度为(num_boxes, 2)
        tl = torch.max(
            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
        )

        # 预测框与真实框右下角的交点(重叠部分右下角点),维度为(num_boxes, 2)
        br = torch.min(
            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
        )

        # 求预测框与真实框的面积
        area_p = torch.prod(pred[:, 2:], 1)     # 在dim=1的方向上求乘积,pred[:, 2:]是每个预测框的w和h,这里是求面积
        area_g = torch.prod(target[:, 2:], 1)
        # area_p和area_g的维度都是(num_boxes, 2)

        # 交集有效性序列(我也不知道en该解释为什么),维度为(num_boxes, )
        en = (tl < br).type(tl.type()).prod(dim=1)
        # 只有br的横坐标和纵坐标同时大于t1的横纵坐标时,pred与target的所在行才有交集,即(tl < br)为True
        # (tl < br)的维度是(num_boxes, 2),.type(tl.type())将其转化为0-1序列,.prod(dim=1)获得乘积
        # en的维度为(num_boxes, ),都是0/1序列,如果左上角的横纵坐标同时小于右下角则是1,否则则为0,类似与布尔索引

        # 求交集面积,维度为(num_boxes, )
        area_i = torch.prod(br - tl, 1) * en
        # br - tl是重叠部分右下角点减左上角点的横纵坐标,torch.prod是求乘积,因为类似于布尔索引,所以 * en表示只考虑有效面积

        # 求交并比,1e-16是为了防止分母为0
        iou = (area_i) / (area_p + area_g - area_i + 1e-16)     # iou的维度为(num_boxes, )

        if self.loss_type == "iou":     # 返回最普通的IOU损失
            loss = 1 - iou ** 2
        elif self.loss_type == "giou":
            c_tl = torch.min(
                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
            )
            c_br = torch.max(
                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
            )
            area_c = torch.prod(c_br - c_tl, 1)
            giou = iou - (area_c - area_i) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)

        # 是否对损失求均值或求和
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss

代码中默认是普通的IOU损失

写一个测试脚本,yolox_from_scratch下写一个名为iou_loss_test.py的脚本
在这里插入图片描述
加入下面的代码:

import torch

from nets.yolo_training import IOUloss

pred = torch.tensor([150, 150, 200, 120]).reshape(-1, 4)
target = torch.tensor([300, 170, 300, 400]).reshape(-1, 4)

Iou_loss = IOUloss()
print(Iou_loss(pred, target))

输出:

tensor([0.9917])

我们可以自己画图言这个一下这个答案

2 YOLOX的损失函数类

我们定义一个名为YOLOLoss的类来求损失函数。这个类是最难的,耗费了巨大的时间,才读懂源代码。这篇文章的精华,都在函数的注释里。需要注意的是,上一篇文章,我们是使用官方预训练权重,相应的模型是对80个类别进行检测,而我们自己的数据集只有4个类别,这个区别会体现在求损失函数的过程中,各个中间量的维度里。

(1)初始化与forward函数

在 yolo_training.py中写一个名为YOLOLoss的类及其初始化方法

class YOLOLoss(nn.Module):
    def __init__(self, num_classes, strides=[8, 16, 32]):
        super().__init__()
        self.num_classes = num_classes      # 类数
        self.strides = strides              # 步长列表,即输出特征图中的一个点相当于原图片中多少个像素

        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
        # BCEWithLogitsLoss先做一次sigmoid(logits函数),然后再求BCE损失
        # nn.BCEWithLogitsLoss的使用,见这篇文章:https://zhuanlan.zhihu.com/p/170558960
        self.iou_loss = IOUloss(reduction="none")                       # IOU损失
        self.grids = [torch.zeros(1)] * len(strides)                    # 先生成若干个0(后面再填充),这是列表的扩展
        # [torch.zeros(1)]的结果是[tensor([0.])]

下面是YOLOLoss类的forward方法:

    def forward(self, inputs, labels=None):
        """

        Args:
            inputs:模型的输出,一个列表,列表中的元素,为各个检测头的输出,维度为(batch_size, anchor_attr, grid_w, grid_h)
                    各个检测头输出的张量维度,只有grid_w和grid_h不一样
            labels:每次从dataloader获得的一个batch的标签,这是一个列表,每个元素对应一张图片,元素个数为batch_size,
                    列表中的每个元素为(num_gt, 5),如果对应的图片中没有GT,那么labels中相应的元素就是一个空张量,否则是(num_objs, 5)
                    5列表示(x, y, w, h, cls_index),cls_index是类别索引

        Returns:

        """
        outputs = []
        x_shifts = []
        y_shifts = []
        expanded_strides = []

        for k, (stride, output) in enumerate(zip(self.strides, inputs)):
            output, grid = self.get_output_and_grid(output, k, stride)
            x_shifts.append(grid[:, :, 0])                              # 获得第k个检测头每个网格点在x方向上的偏移
            y_shifts.append(grid[:, :, 1])                              # 获得第k个检测头每个网格点在y方向上的偏移
            expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride)    # 表示每个网格的步长,维度和x_shifts相同
            outputs.append(output)

        loss = self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1))

        return loss

(2)边框调整与网格生成函数self.get_output_and_grid

def forward中出现了两个成员函数,self.get_output_and_grid和self.get_losses,我们先来介绍第一个:

    def get_output_and_grid(self, output, k, stride):
        """
        在这个函数之前,第k个检测头对应的grid为tensor([0.]),预测框的中心点坐标和高宽都是相对于网格的
        本函数的目的是生成一个张量来表示grid,使其表示每个网格左上角点在输出特征图中的位置
        并让output的中心点坐标和高宽变成letterbox图像中的数据,这一功能和解码是一样的

        下面的注释,是假设调用这个方法是(80, 80)的检测头(除了这个检测头之外,还有还有(40, 40)和(20, 20)两个检测头)
        Args:
            output:当前检测头的输出,维度为(batch_size, anchor_attr, grid_w, grid_h)
            k:当前检测头的序号,如果只有三个检测头,那么k=0~2
            stride:当前检测头输出的特征层的步长

        Returns:output:预测框数据调整后的结果,已将预测框的中心点坐标和高宽转化成letterbox中的数据,
                        维度为(batch_size, anchor_attr, grid_w, grid_h)
                grid:表示每个网格左上角点在输出特征图中的位置

        """
        grid = self.grids[k]  # 第k个检测头的grid
        hsize, wsize = output.shape[-2:]  # 特征图的高宽
        if grid.shape[2:4] != output.shape[2:4]:  # 最开始的时候,grid是tensor([0.]),自然能进入循环
            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])  # yv和xv的维度都是(80, 80)
            # xv表示输出特征图中每个网格点的横坐标,维度为torch.Size([80, 80])
            # yv表示输出特征图中每个网格点的纵坐标,维度为torch.Size([80, 80])

            # 生成一个满足要求的grid
            grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type())
            # torch.stack((xv, yv), 2)是在dim=2的维度上堆叠,堆叠之后的维度为torch.Size([80, 80, 2])
            # .view(1, hsize, wsize, 2)之后,维度为torch.Size([1, 80, 80, 2]),
            # 之所以要在最前面加上一个维度,是为了能和output相加时进行广播

            # 更新第k个检测头对应的grid
            self.grids[k] = grid

        grid = grid.view(1, -1, 2)  # 该操作之后,grid的维度为torch.Size([1, 6400, 2])

        output = output.flatten(start_dim=2).permute(0, 2, 1)
        # output.flatten返回张量的维度为(16, 9, 6400)
        # 经过permute调整后的张量维度为(16, 6400, 9)

        # 将预测框中心点坐标转化为letterbox图像中的坐标
        output[..., :2] = (output[..., :2] + grid) * stride
        # 在此之前,预测框中心点坐标仅仅是相对于网格归一化后的坐标
        # output[..., :2]+grid是将中心点坐标转化到[0, 80)的范围内,也就是特征图中的坐标
        # * stride 是将中心点坐标转化为letterbox图片中预测框的中心点坐标

        # 将预测框的宽高转化为letterbox图像中的宽高
        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride

        return output, grid

(3)损失计算函数self.get_losses(上)

接下来是self.get_losses,个人认为这个是最复杂的,因为从这开始调用了很多其他函数,而且是层层调用,我花了一个星期才吃透这个函数。下面这段代码,只需要先看到torch.cuda.empty_cache()

    def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs):
        """
            计算当前batch的损失,以下注释都是假设 batch_size=16,input_size=(640, 640),num_cls=4 的情况下的注释
        Args:
            x_shifts:一个列表,包括了三个元素,每个元素代表了对应检测头的各个网格在x方向上的偏移,三个元素的维度分别为:(1, 6400), (1, 1600), (1, 400)
            y_shifts:一个列表,包括了三个元素,每个元素代表了对应检测头的各个网格在y方向上的偏移,每个元素的维度和x_shifts一致
            expanded_strides:一个包括了三个元素的列表,每个元素的维度和x_shifts对应的元素相同,表示每个网格(anchor)的步长
            labels:每次从dataloader获得的一个batch的标签,这是一个列表,每个元素对应一张图片,元素个数为batch_size,
                    列表中的每个元素为(num_gt, 5),如果对应的图片中没有GT,那么labels中相应的元素就是一个空张量,否则是(num_objs, 5)
                    5列表示(x, y, w, h, cls_index),cls_index是类别索引
            outputs:模型输出后经过torch.cat的结果,维度为 torch.Size([16, 8400, 9]),16是batch_size,8400是anchor的数量
                    9列表示(x, y, w, h, obj_preds, cls_preds),cls_preds表示4个类别的概率,共4列

        Returns:

        """
        bbox_preds = outputs[:, :, :4]      # 预测框中心点坐标及宽高,维度为(16, 8400, 4)
        obj_preds = outputs[:, :, 4:5]      # 目标置信度,维度为(16, 8400, 1)
        cls_preds = outputs[:, :, 5:]       # 各个类别的概率,维度为(16, 8400, 4)
        total_num_anchors = outputs.shape[1]        # 三个检测头的anchor总数

        x_shifts = torch.cat(x_shifts, 1)           # 维度为torch.Size([1, 8400])
        y_shifts = torch.cat(y_shifts, 1)           # 维度为torch.Size([1, 8400])
        expanded_strides = torch.cat(expanded_strides, 1)   # 维度为torch.Size([1, 8400])

        cls_targets = []
        reg_targets = []
        obj_targets = []
        fg_masks = []

        num_fg = 0.0        # 用来记录当前batch中,总共有多少个anchor被
        for batch_idx in range(outputs.shape[0]):       # batch_idx是当前图片在batch中的索引
            num_gt = len(labels[batch_idx])             # 当前图片的GT数目,即真实框数目

            if num_gt == 0:
                # 如果第batch_idx张图片中,GT的数目为0,那么就新建几个空张量,pytorch0.4之后,空张量也是有维度的
                cls_target = outputs.new_zeros((0, self.num_classes))       # 类别概率,维度为(0, 4)
                # .new_zeros表示新建一个与outputs类型相同的零张量

                reg_target = outputs.new_zeros((0, 4))                      # 预测框的中心点坐标和高宽,维度为(0, 4)
                obj_target = outputs.new_zeros((total_num_anchors, 1))      # 目标置信度,维度为torch.Size([8400, 1])
                fg_mask = outputs.new_zeros(total_num_anchors).bool()       # 能和GT匹配的预测框索引,维度为torch.Size([8400])
                # 当前图片中没有目标,因此全为False
            else:
                gt_bboxes_per_image = labels[batch_idx][..., :4]    # GT的中心点坐标及宽高,维度为(num_gt, 4)
                gt_classes = labels[batch_idx][..., 4]              # GT的类别索引,维度为(num_gt,)
                bboxes_preds_per_image = bbox_preds[batch_idx]      # 预测框的中心点坐标及宽高,维度为(8400, 4)
                cls_preds_per_image = cls_preds[batch_idx]          # 预测框的各个类别概率,维度为(8400, 4)
                obj_preds_per_image = obj_preds[batch_idx]          # 预测框的目标置信度,维度为(8400, 1)

                # 标签分配,即8400个anchor中,哪些作为正样本,哪些作为负样本
                gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(
                    num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image,
                    cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts
                )
                # gt_matched_classes:       第二轮筛选后得到的anchor对应GT的索引,维度为(len_sg, ),len_sg是经过第二轮筛选后得到的anchor数量
                # fg_mask:                  第二轮筛选后得到的anchor在8400个anchor中的布尔索引,维度为(8400, )
                # pred_ious_this_matching:  第二轮筛选得到的anchor,与其对应的GT的iou,维度为(len_sg, )
                # matched_gt_inds:          第二轮筛选得到的anchor能和哪些GT匹配,维度为(len_sg, )
                # num_fg_img:               当前图片中,经过两轮筛选后,所有GT的正样本总数,即能与任意一个GT匹配的anchor总数,一个纯数字

                torch.cuda.empty_cache()
                # 其实上面这一句可以省略的,当显存中的数据没有任何变量引用时,会自动释放显存,
                # 但释放的显存在Nvidia中看不到,只有加上这一句,才会在Nvidia-smi中释放

                num_fg += num_fg_img

                # 分类目标
                cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)
                # F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes)返回的张量维度为(len_sg, 4)
                # pred_ious_this_matching.unsqueeze(-1)返回的张量维度为(len_sg, 1)
                # 上述两个张量相乘,得到的张量维度为(len_sg, 4)
                # TODO 上述两个张量相乘的目的是什么?为何类型要乘以IOU?

                # 置信度目标
                obj_target = fg_mask.unsqueeze(-1)                  # 维度为(8400, 1)
                # GT的置信度是1
                # obj_target非0则1,这里之所以还有0,是因为负样本(即没有和GT匹配的anchor)也是参与置信度损失计算的

                # 回归目标
                reg_target = gt_bboxes_per_image[matched_gt_inds]   # 维度为(len_sg, 4)
                # 通常,len_sg是大于num_gt的,可以认为,在计算损失函数的时候,有些GT被使用了多次

            cls_targets.append(cls_target)
            reg_targets.append(reg_target)
            obj_targets.append(obj_target.type(cls_target.type()))
            fg_masks.append(fg_mask)

        cls_targets = torch.cat(cls_targets, 0)     # 分类标签
        reg_targets = torch.cat(reg_targets, 0)     # 回归标签
        obj_targets = torch.cat(obj_targets, 0)     # 置信度标签
        fg_masks = torch.cat(fg_masks, 0)

        num_fg = max(num_fg, 1)
        loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()     # iou损失
        # fg_mask是布尔索引,bbox_preds.view(-1, 4)[fg_masks]是将能与GT匹配的anchor取出,然后和reg_targets求IOU损失

        loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()         # 置信度损失
        loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()    # 分类损失

        reg_weight = 5.0                            # 回归权重
        loss = reg_weight * loss_iou + loss_obj + loss_cls      # 损失函数计算,这个才是真正的损失函数

        return loss / num_fg

(4)标签分配函数self.get_assignments(上)

def get_losses调用了self.get_assignments,它是将8400个anchor划分成正负样本,正样本就是能和GT进行匹配的anchor,负样本就是不能和GT进行匹配的anchor,正样本可以和GT计算分类、回归、置信度损失,负样本只能计算置信度损失。这个函数的代码如下(先展示一部分,讲完第二轮筛选后会讲第二部分):

    @torch.no_grad()
    def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes,bboxes_preds_per_image,
                       cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):
        """
            一张图片输入到模型后,三个检测头会得到8400个anchor,这些anchor只有一部分会当成正样本,与标签进行损失函数计算
            这个函数就是把这些anchor给找出来
        Args:
            num_gt:当前图片中GT的数量,纯数字
            total_num_anchors:三个检测头的anchor总数,纯数字,这里是8400
            gt_bboxes_per_image:当前图片中GT的中心点坐标及宽高,维度为(num_gt, 4)
            gt_classes:当前图片中,所有GT的类别索引,维度为(num_gt,)
            bboxes_preds_per_image:当前图片预测框的中心点坐标及宽高,维度为(8400, 4)
            cls_preds_per_image:当前图片预测目标的类别,维度为(8400, 4)
            obj_preds_per_image:当前图片预测目标的置信度(目标置信度),维度为(8400, 1)
            expanded_strides:各个anchor与输入图片中网格的尺寸比例,即步长,维度为(1, 8400)
            x_shifts:各个anchor在特征图中的横坐标,维度为(1, 8400)
            y_shifts:各个anchor在特征图中的纵坐标,维度为(1, 8400)

        Returns:gt_matched_classes:第二轮筛选后得到的anchor对应GT的索引,维度为(len_sg, ),len_sg是经过第二轮筛选后得到的anchor数量
                fg_mask:第二轮筛选后得到的anchor在8400个anchor中的布尔索引,维度为(8400, )
                pred_ious_this_matching:第二轮筛选得到的anchor,与其对应的GT的iou,维度为(len_sg, )
                matched_gt_inds:第二轮筛选得到的anchor能和哪些GT匹配,维度为(len_sg, )
                num_fg:当前图片中,经过两轮筛选后,所有GT的正样本总数,即能与任意一个GT匹配的anchor总数,一个纯数字

        """
        """第一轮筛选"""
        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts,
                                                                 y_shifts, total_num_anchors, num_gt)
        # self.get_in_boxes_info是使用两种方法,对8400个anchor进行筛选,筛选出能和GT进行匹配的anchor
        # fg_mask: 能通过两种方法之一的anchor的布尔索引,维度为(8400, )
        # is_in_boxes_and_center: 这也是一个布尔索引,表示第一种筛选方法得到的anchor中,能通过第二种筛选方法的anchor,
        #                         维度为(num_gt, len_first),len_first是is_in_boxes_anchor中True的数量
        # is_in_boxes_and_center[i, j]表示什么呢?假设8400个anchor中,有50个通过了第一种方法筛选,
        # i表示第i个GT,但j并不表示8400个anchor中的第j个anchor,而是50个anchor中的第j个

(5)第一轮筛选self.get_in_boxes_info

def get_assignments出现了self.get_in_boxes_info函数,它对8400个anchor做第一轮筛选。第一轮筛选使用了两种方法,任意一个anchor只要通过其中一种筛选方法,就可以认为其通过了第一轮筛选。

第一种方法是先把网格的各个中心点坐标求出来,判断其是否在GT的内部,如果在GT的内部,那么就认为该网格对应的anchor与GT匹配。如下图所示:
在这里插入图片描述
第二种方法是以每个GT的中心点为中心,生成一个边长为5的正方形,判断各个网格的中心点是否在这个网格的内部
在这里插入图片描述
看懂了上面两幅图,就很好理解self.get_in_boxes_info函数了,现在我们来看看self.get_in_boxes_info的代码:

    def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5):
        """
            三个检测头,共有8400个anchor,但这些anchor只有部分能和标签进行匹配,本函数就是筛选出能和标签进行匹配的anchor
            本函数中使用两种方法对anchor进行筛选
        Args:
            gt_bboxes_per_image:当前图片中,各个真实框的中心点坐标及宽高,维度为(num_gt, 4)
            expanded_strides:每个网格的步长,维度为torch.Size([1, 8400])
            x_shifts:维度为torch.Size([1, 8400])
            y_shifts:维度为torch.Size([1, 8400])
            total_num_anchors:网格点总数,纯数字,例如8400
            num_gt:真实框总数,纯数字
            center_radius:半径,纯数字

        Returns:is_in_boxes_anchor 能通过两种方法之一的anchor的布尔索引,维度为(8400, )
                is_in_boxes_and_center 这也是一个布尔索引,表示第一种筛选方法得到的anchor中,能通过第二种筛选方法的anchor,
                维度为(num_gt, len_first),len_first是is_in_boxes_anchor中True的数量

        """

        # 获得每个网格的步长
        expanded_strides_per_image = expanded_strides[0]            # 维度为(8400, )

        # 获得各个网格的中心点横坐标
        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
        # (x_shifts[0] + 0.5) * expanded_strides_per_image的维度为(8400, ),unsqueeze(0)之后为(1, 8400)
        # repeat(num_gt, 1)之后为(num_gt, 8400)

        # 获得各个网格的中心点纵坐标
        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)

        """第一种筛选方式:筛选出中心点在GT内部的网格,所对应的anchor"""
        """各个GT的上下左右边缘"""
        # gt_bboxes_per_image_l当前图片,每个真实框的左边缘横坐标,l表示left,同样的,r、t、b分别 表示右、上、下
        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        # gt_bboxes_per_image[:, 0]的维度是(num_gt, ),unsqueeze(1)后是(num_gt, 1),repeat(1, total_num_anchors)后是(num_gt, 8400)

        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)

        """计算各个网格中心点与GT各个边缘的距离"""
        b_l = x_centers_per_image - gt_bboxes_per_image_l           # 如果b_l>0,表示对应的网格中心在左边缘的右边,维度为(num_gt, 8400)
        b_r = gt_bboxes_per_image_r - x_centers_per_image           # 如果b_r>0,表示对应的网格中心在右边缘的左边
        b_t = y_centers_per_image - gt_bboxes_per_image_t           #
        b_b = gt_bboxes_per_image_b - y_centers_per_image
        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)          # 新增加一个维度,stack之后,返回的张量维度为(num_gt, 8400, 4)

        """获得各个anchor的匹配情况"""
        # 获得GT和anchor的匹配矩阵
        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0          # dim=-1表示对最后一个维度求最小值
        # 只有当最后一个维度的4个数都大于0,才说明对应网格的中心点在GT的内部
        # 获得一个布尔索引,维度为(num_gt, 8400),如果is_in_boxes[i, j]为True,表示第i个GT和第j个网格对应的anchor能匹配上

        # 获得正样本的索引
        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
        # is_in_boxes.sum(dim=0)是计算每个网格能与多少个GT进行匹配,>0表示对应的对应的anchor至少存在一个GT与之匹配
        # 返回值的维度为(8400, )

        """第二种筛选方式:以GT的中心为中心,生成一个边长为5个stride的正方形(这里简称GT方框),
        将中心点落在这个正方形内的网格所对应的anchor,作为与GT匹配的正样本"""
        # 获得GT方框的左右上下边缘
        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
        # gt_bboxes_per_image[:, 0]的维度为(num_gt, ),.unsqueeze(1)的维度为(num_gt, 1),
        # .repeat(1, total_num_anchors)的维度为(num_gt, 8400)
        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)

        """计算各个网格中心点与GT方框各个边缘的距离"""
        c_l = x_centers_per_image - gt_bboxes_per_image_l       # 维度为(num_gt, 8400)
        c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)

        """获得各个anchor的匹配情况"""
        # 获得GT和anchor的匹配矩阵
        is_in_centers = center_deltas.min(dim=-1).values > 0.0      # 维度为(num_gt, 8400)
        # 获得正样本的布尔索引
        is_in_centers_all = is_in_centers.sum(dim=0) > 0            # 维度为(8400, )

        """将上述两种方法综合起来"""
        # anchor按照上述两种方法,只要有一种能和标签匹配上,就认为其是正样本
        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all        # 维度为(8400, )
        # 从第一种筛选方法得到的anchor中,筛选出能够通过第二种方法的anchor
        is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
        # 注意,参与&的两个矩阵,他们的列索引都是 is_in_boxes_anchor,也就说,是从第一种的方法的结果中进行筛选
        # 维度为(num_gt, len_first),len_first是is_in_boxes_anchor中True的数量

        return is_in_boxes_anchor, is_in_boxes_and_center

(6)第二轮筛选与self.bboxes_iou、self.dynamic_k_matching(标签分配函数(下))

第一轮筛选后,就要做第二轮筛选了。第二轮筛选使用简化的OTA算法,即simOTA算法,它的过程如下:
(1)计算每个anchor(经过第一轮筛选后得到的anchor)与每个GT的分类损失和iou损失,然后求和得到cost矩阵(成本函数);
(2)在经过第一轮筛选后得到的anchor中,为每个GT找到与其有最大IOU的10个anchor,将这10个anchor对应的IOU值求和取整,即为当前GT所匹配到的anchor数量,即dynamic_k,IOU排名前dynamic_k的anchor即为和当前GT匹配的anchor。可以用如下例子理解这一过程:
在这里插入图片描述
(3)部分anchor可能和多个GT匹配,这是不允许的,要对这些anchor进行处理。假设经过上述操作后,某些anchor各自能和2个及2个以上的GT匹配,对每一个这样的anchor,寻找cost最小的GT作为与其匹配的GT,这里的cost是第(1)步中求得的cost。个人认为这里还是有些漏洞的,因为可能出现这种情况:经过第(2)步后,第j个anchor同时与第m和n个GT匹配,但第cost中,第j列最小的元素在第i行,i≠m且i≠n,也就是说,第i个GT的前dynamic_k个anchor中,并不包含第j个anchor。

理解了以上过程,就能更好地看懂第二轮筛选的代码了。让我们回到def get_assignments函数中,添加以下代码,做第二轮筛选:

        """下面是第二轮筛选"""
        """获得筛选后的anchor的边框、类别概率和置信度"""
        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]        # 维度为(len_fg, 4),len_fg是fg_mask中True的个数
        cls_preds_ = cls_preds_per_image[fg_mask]                       # 维度为(len_fg, 4),因为是4个类别,所以第二个维度是4
        obj_preds_ = obj_preds_per_image[fg_mask]                       # 维度为(len_fg, 1)
        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]           # len_fg,纯数字,即通过第一轮筛选后的anchor数量

        """计算IOU损失"""
        # 计算GT和第一轮得到的anchor的交并比
        pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)    # 维度为(num_gt, len_fg)
        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)                                 # iou损失

        """计算分类损失"""
        # 正样本anchor的预测分类
        cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
        # cls_preds_.float().unsqueeze(0)返回的维度为(1, len_fg, 4),.repeat(num_gt, 1, 1)返回的维度为(num_gt, len_fg, 4)
        # .sigmoid_()对每个类别的概率做二分类
        # obj_preds_.unsqueeze(0)返回的维度为(1, len_fg, 1),.repeat(num_gt, 1, 1)返回的维度为(num_gt, len_fg, 1)
        #
        # 之所以这么操作,是为了方便做广播。上述命令执行后,cls_preds_为每个类别的置信度,维度为(num_gt, len_fg, 4)
        # 个人认为可以改成 (cls_preds_.float().sigmoid_() * obj_preds_.sigmoid_())..unsqueeze(0).repeat(num_gt, 1, 1)

        # 将正样本anchor的预测分类做成one-hot编码
        gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
        # F.one_hot(gt_classes.to(torch.int64), self.num_classes)返回的张量的维度为(num_gt, 4)
        # .unsqueeze(1)返回的维度为(num_gt, 1, 4),.repeat(1, num_in_boxes_anchor, 1)返回的维度为(num_gt, len_fg, 4)
        # 上述命令执行后,gt_cls_per_image表示每个GT的类别one-hot编码

        # 计算分类损失
        pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
        # F.binary_cross_entropy(A, B, reduction="none")的维度为(num_gt, len_fg, num_classes)
        # .sum(-1)的维度为(num_gt, len_fg)
        # pair_wise_cls_loss表示第i个GT和第j个(经过第一轮筛选后的anchor的第j个)anchor的分类损失

        del cls_preds_  # 释放内存

        # 计算成本函数
        cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()
        # 维度为(num_gt, len_fg)
        # TODO 为何要加上 100000.0 * (~is_in_boxes_and_center).float()

        """SimOTA求解"""
        num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost,
                                                                                                       pair_wise_ious,
                                                                                                       gt_classes,
                                                                                                       num_gt, fg_mask)
        # num_fg:                   当前图片中,所有GT的正样本总数,即能与任意一个GT匹配的anchor总数,一个纯数字
        # gt_matched_classes:       第二轮筛选后得到的anchor对应GT的索引,维度为(len_sg, ),len_sg是经过第二轮筛选后得到的anchor数量
        # pred_ious_this_matching:  第二轮筛选得到的anchor,与其对应的GT的iou,维度为(len_sg, )
        # matched_gt_inds:          第二轮筛选得到的anchor能和哪些GT匹配,维度为(len_sg, )
        # 这个函数还对fg_mask进行了更新,但没有将其作为返回值,而是以传引用的方式进行了更新,
        # 经过这个函数操作之后,fg_mask变成了第二轮筛选后得到的anchor在8400个anchor中的索引

        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss       # 释放内存

        return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg

这里出现了self.bboxes_iou和self.dynamic_k_matching函数,前者用来计算GT和预测框的IOU,后者用来求每个GT的dynamic_k。我们先讲一下self.bboxes_iou,其代码如下:

    def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):        # 求两个边框的交并比
        """
            求GT与预测框(anchor)的交并比
        Args:
            bboxes_a: GT,维度为(num_gt, 4)
            bboxes_b: 预测框,维度为(len_fg, 4),len_fg是经过第一轮筛选后得到的anchor数量
            xyxy:GT和预测框,是否为边框上下角点的坐标

        Returns: iou  GT和预测框的交并比,维度为(num_gt, len_fg),例如iou[i, j]表示第i个GT和第j个预测框的交并比

        """
        if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
            raise IndexError

        if xyxy:
            tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])      # 重合区域的左边缘和上边缘
            # bboxes_a[:, None, :2]等价于先获得切片bboxes_a[:, :2],再使用 unsqueeze(1),得到的张量维度为(num_gt, 1, 2)
            # bboxes_b[:, :2]的维度为(len_fg, 2),而bboxes_a[:, None, :2]的维度是(num_gt, 1, 2),两者可以进行广播
            # torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])的结果为(num_gt, len_fg, 2)

            br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])      # 重合区域的下边缘和右边缘

            area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)       # 方框a的面积,维度为(num_gt, )
            area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)       # 方框b的面积,维度为(len_fg, )
        else:
            tl = torch.max(                                     # 重合区域的左边缘和上边缘,维度为(num_gt, len_fg, 2)
                (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
            )
            br = torch.min(                                     # 重合区域的下边缘和右边缘,维度为(num_gt, len_fg, 2)
                (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
            )

            area_a = torch.prod(bboxes_a[:, 2:], 1)             # 方框a的面积,维度为(num_gt, )
            area_b = torch.prod(bboxes_b[:, 2:], 1)             # 方框b的面积,维度为(len_fg, )

        # 获得左上小于右下的索引
        en = (tl < br).type(tl.type()).prod(dim=2)
        # (tl < br)得到布尔索引,.type(tl.type())将其转化为数值,.prod(dim=2)表示将第二个维度的元素相乘
        # 如果相乘之后还是1,那么说明“左<右”和“上<下”同时满足,即GT和预测框存在交集
        # 维度为(num_gt, len_fg)

        # 计算交集面积
        area_i = torch.prod(br - tl, 2) * en
        # 先求交集面积,然后通过 * en,对存在交集的面积进行筛选,area_i的维度为(num_gt, len_fg)

        # 返回交并比
        return area_i / (area_a[:, None] + area_b - area_i)

接下来是self.dynamic_k_matching函数的代码:

    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
        """
            一个GT能和多个anchor进行匹配,但一个anchor只能和一个GT进行匹配,也就是说GT和anchor是一对多的关系
            这个函数先进行第二轮筛选,获得若干个anchor,然后求这些anchor与对应GT、GT的目标类别、与所匹配GT的IOU
            本函数还以传引用的方式对fg_mask进行了更新,更新后的fg_mask变成了第二轮筛选后得到的anchor在8400个anchor中的布尔索引
        Args:
            cost:第一轮筛选得到的anchor与GT的成本函数,维度为(num_gt, len_fg),len_fg是8400个anchor经过第一轮筛选后得到的数量
            pair_wise_ious:GT和第一轮得到的anchor的交并比,维度为(num_gt, len_fg)
            gt_classes:当前图片中,所有GT的类别索引,维度为(num_gt,)
            num_gt:当前图片中GT的数量,纯数字
            fg_mask:第一轮筛选得到的anchor在8400个anchor中的布尔索引,维度为(8400, )

        Returns: num_fg:当前图片中,所有GT的正样本总数,即能与任意一个GT匹配的anchor总数,一个纯数字
                gt_matched_classes:第二轮筛选后得到的anchor对应GT的索引,维度为(len_sg, ),len_sg是经过第二轮筛选后得到的anchor数量
                pred_ious_this_matching:第二轮筛选得到的anchor,与其对应的GT的iou,维度为(len_sg, )
                matched_gt_inds:第二轮筛选得到的anchor能和哪些GT匹配,维度为(len_sg, )

        假如len_fg=50,len_sg=20,
        若 matched_gt_inds[5]=3 则表示第5个anchor(20中的第5个)匹配的GT的索引是3
        gt_matched_classes[5]=2 则表示与第5个anchor(20中的第5个)匹配的GT(即索引为3的GT),其类别索引是2
        pred_ious_this_matching[5]=0.53,则表示第5个anchor(20中的第5个),与其匹配的GT(即索引为3的GT)的iou为0.53

        """
        """初始化匹配矩阵"""
        matching_matrix = torch.zeros_like(cost)        # 维度为(num_gt, len_fg)

        """确定每个GT能匹配的anchor数量"""
        n_candidate_k = min(10, pair_wise_ious.size(1))                  # 看10和len_fg哪个小
        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)  # 对每个GT,寻找最大的10个(或len_fg个)IOU
        # topk_ious的维度为(num_gt, n_candidate_k)

        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)  # 获得每个GT的k,即每个GT,该与多少个网格进行匹配
        # 上一步中,对各个GT求最大的k个IOU,这里topk_ious.sum(1)是将这k个最大的IOU进行求和操作,返回张量的维度为(num_gt, )
        # 可能某个GT,其最大的k个IOU相加都不不超过1,那么对这个GT,它的k就为1

        """给每个真实框选取k个标签进行匹配"""
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)  # largest=False表示取最小
            # pos_idx是损失函数最小的k个预测框(anchor)对应的索引

            matching_matrix[gt_idx][pos_idx] = 1.0
            # 给匹配矩阵的对应的索引赋1

        del topk_ious, dynamic_ks, pos_idx          # 释放内存

        """有些anchor可能同时和多个GT匹配,需要在matching_matrix中,对这些anchor进行处理"""
        anchor_matching_gt = matching_matrix.sum(0)         # 维度为(len_fg, ),表示每个anchor能和多少个GT进行匹配
        if (anchor_matching_gt > 1).sum() > 0:
            # anchor_matching_gt>1 的返回值是一个维度为(len_fg, )的布尔索引
            # 在len_fg个anchor中,如果存在某个anchor能和多个GT匹配,那么这个anchor对应的索引就是True
            # .sum()用来求有多少个这样的特征点

            #   当某一个anchor指向多个GT的时候,选取cost最小的GT作为与其匹配的GT
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            # cost[:, anchor_matching_gt > 1] 是将能与多个GT匹配的anchor取出,维度为(num_gt, match_mul),
            # 某些anchor能与多个GT匹配,这样的anchor数量为match_mul,即match_mul是能与多个GT匹配的anchor的数量
            # torch.min dim=0表示对每列求最小值
            # cost_argmin每列最小值所对应的索引(GT的索引),维度为(match_mul, )
            # 若cost_argmin[2]为4,则表示在cost矩阵中,第2个anchor(match_mul中的第2个anchor)所在列中,与第4个GT的损失函数最小

            # 在matching_matrix中,先把这样的anchor所在列全部设为0,再把每个这样的anchor列的最小值所对应的GT设为1
            matching_matrix[:, anchor_matching_gt > 1] *= 0.0
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0

        """第二轮筛选整理"""
        fg_mask_inboxes = matching_matrix.sum(0) > 0.0
        # 返回一个布尔索引,代表第一轮筛选得到的anchor是否能通过第二轮筛选,即是否为正样本,维度为(len_fg, )
        num_fg = fg_mask_inboxes.sum().item()           # 当前图片中,所有GT的正样本总数,即能与任意一个GT匹配的anchor总数,即len_sg

        """对fg_mask进行更新"""
        fg_mask[fg_mask.clone()] = fg_mask_inboxes
        # fg_mask本身代表8400个anchor中,通过第一轮筛选的anchor所对应的布尔索引
        # fg_mask[fg_mask.clone()],布尔索引的布尔索引,即把所有为True的元素筛选出来,对这些元素进行重新赋值,
        # 赋值之后,fg_mask代表第二轮筛选后得到的anchor在8400个anchor中的索引
        # 至此,第二轮筛选结束,fg_mask的维度为(8400, )

        """获得第二轮筛选后得到的anchor,其所对应GT、GT的目标类别、与所匹配GT的IOU"""
        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        # 获得anchor对应GT的索引,维度为(len_sg, ),sg表示 second GT
        # matching_matrix[:, fg_mask_inboxes]返回的是GT与第二轮筛选得到的anchor的匹配矩阵,维度为(num_fg, len_sg)
        # .argmax(0)是求各列的最大值,因为各列只有一个值为1,其余都为0,由于每个anchor最多只能和一个GT匹配,
        # 所以这里是求各个anchor能和哪些GT匹配,维度为(len_sg, )
        # 假设len_sg=20,即通过第二轮筛选后还剩20个anchor,若matched_gt_inds[5]的值为3,
        # 那么意思是第5个anchor(20中的第5个)匹配的GT的索引是3

        gt_matched_classes = gt_classes[matched_gt_inds]  # 根据GT的索引,获得特征点对应的GT的类别
        # gt_matched_classes维度为(len_sg, )
        # 若gt_matched_classes[5]的值为2,那么意思是第5个anchor(20中的第5个)匹配的GT,其类别索引为2

        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
        # matching_matrix * pair_wise_ious的维度是(num_gt, len_fg),表示经第二轮筛选后得到的anchor与GT的iou,
        # 每列最多只有一个元素有值,有可能一个都没有,所以.sum(0)是将这些iou给取出来,变成一个维度为(fg_mask, )的张量,
        # [fg_mask_inboxes]是从中取出经过第二轮筛选后得到的anchor与对应的GT的iou
        # 最后得到的pred_ious_this_matching,其维度为(len_sg, ),表示第二轮筛选得到的anchor,与其对应的GT的iou
        # 若pred_ious_this_matching[5]的值为0.53,则表示第5个anchor(20中的第5个),与其匹配的GT的iou为0.53
        # 第5个anchor与哪一个GT匹配呢,这个要看 matched_gt_inds 才知道

        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

(7)损失计算函数self.get_losses(下)

最后,我们回到def get_losses函数中,我们之前只看到torch.cuda.empty_cache(),接下来要构建目标和损失函数了,代码如下:

                num_fg += num_fg_img

                # 分类目标
                cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)
                # F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes)返回的张量维度为(len_sg, 4)
                # pred_ious_this_matching.unsqueeze(-1)返回的张量维度为(len_sg, 1)
                # 上述两个张量相乘,得到的张量维度为(len_sg, 4)
                # TODO 上述两个张量相乘的目的是什么?为何类型要乘以IOU?

                # 置信度目标
                obj_target = fg_mask.unsqueeze(-1)                  # 维度为(8400, 1)
                # GT的置信度是1
                # obj_target非0则1,这里之所以还有0,是因为负样本(即没有和GT匹配的anchor)也是参与置信度损失计算的

                # 回归目标
                reg_target = gt_bboxes_per_image[matched_gt_inds]   # 维度为(len_sg, 4)
                # 通常,len_sg是大于num_gt的,可以认为,在计算损失函数的时候,有些GT被使用了多次

            cls_targets.append(cls_target)
            reg_targets.append(reg_target)
            obj_targets.append(obj_target.type(cls_target.type()))
            fg_masks.append(fg_mask)

        cls_targets = torch.cat(cls_targets, 0)     # 分类标签
        reg_targets = torch.cat(reg_targets, 0)     # 回归标签
        obj_targets = torch.cat(obj_targets, 0)     # 置信度标签
        fg_masks = torch.cat(fg_masks, 0)

        num_fg = max(num_fg, 1)
        loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()     # iou损失
        # fg_mask是布尔索引,bbox_preds.view(-1, 4)[fg_masks]是将能与GT匹配的anchor取出,然后和reg_targets求IOU损失

        loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()         # 置信度损失
        loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()    # 分类损失

        reg_weight = 5.0                            # 回归权重
        loss = reg_weight * loss_iou + loss_obj + loss_cls      # 损失函数计算,这个才是真正的损失函数

        return loss / num_fg

(8)YOLOLoss类的测试

最后,我们写一个测试脚本。在yolox_from_scratch下新建一个名为yolo_loss_test.py的脚本,建立后项目结构如下:
在这里插入图片描述

import random

import torch
import numpy as np
from torch.utils.data import DataLoader

from nets.yolo import YoloBody
from nets.yolo_training import YOLOLoss
from utils.dataloader import YoloDataset, yolo_dataset_collate

if __name__ == '__main__':
    # 设置种子
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    """获得数据集类的相关初始化参数"""
    train_annotation_path = '2007_train.txt'
    with open(train_annotation_path) as f:
        train_lines = f.readlines()  # train_lines将是一个列表

    input_shape = [640, 640]
    num_classes = 4
    mosaic = False
    mixup = False

    """建立数据集类对象"""
    train_dataset = YoloDataset(train_lines, input_shape, num_classes, is_train=True, mosaic=mosaic, mixup=mixup)

    """建立导入器对象"""
    gen = DataLoader(train_dataset, batch_size=16, pin_memory=True, collate_fn=yolo_dataset_collate)

    """建立模型对象"""
    model = YoloBody(4, 's')  # 新建模型,'s'表示新建的为yolox_s模型

    """建立损失函数计算器"""
    yolo_loss = YOLOLoss(4)

    for iteration, batch in enumerate(gen):
        images, targets = batch[0], batch[1]
        images = torch.from_numpy(images).type(torch.FloatTensor)
        targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
        outputs = model(images)
        loss = yolo_loss(outputs, targets)
        break

    print(loss)

输出为

tensor(3596.6357, grad_fn=<DivBackward0>)

能正常输出,说明我们写的损失函数类,没有bug。

  • 17
    点赞
  • 60
    收藏
    觉得还不错? 一键收藏
  • 13
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值