yolox核心部分复现

yolox核心部分复现

yolox讲解
重要策略OTA讲解,虽然yolox用的是简单的k动态分配
OTA另一个讲解

样本标记原则:

预处理:对于每个苗匡,和真实框,先进行预处理筛选出可能作为有物体的苗匡。依据:如果中心在真实框内,或者中心在,以真实框中心点为中心,向外扩展center_radius的正方形内

然后动态分配,见底下的代码

然后再删除重复的点

import torch
import numpy as np
import math
import torch.nn as nn
import torch.nn.functional as F
class YoloHead(nn.Module):
    def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
        # 这个函数,输入为a:[苗匡的数量,每个苗匡的四个信息,有多少张图片]
        # b:[真实框数量,每个苗匡的四个信息,有多少张图片]
        # 输出得到的是[苗匡的数量,真实框的数量,iou值]
        if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
            raise IndexError
        if xyxy:
            tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
            # 写个None,相当于加上了一维
            br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
            area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
            # 将这一维全部都乘起来的意思
            area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)

        else:
            tl = torch.max(
                (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
            )
            br = torch.min(
                (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
                (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
            )

            area_a = torch.prod(bboxes_a[:, 2:], 1)
            area_b = torch.prod(bboxes_b[:, 2:], 1)
        en = (tl < br).type(tl.type()).prod(dim=2)
        area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
        return area_i / (area_a[:, None] + area_b - area_i)
    def get_assignments(
            self,
            batch_idx,
            num_gt,
            total_num_anchors,
            gt_bboxes_per_image,
            gt_classes,
            bboxes_preds_per_image,
            expanded_strides,
            x_shifts,
            y_shifts,
            cls_preds,
            bbox_preds,
            obj_preds,
            labels,
            imgs,
            mode="gpu",
    ):
        # 这个函数用于给已经得到的苗匡划分标签


        fg_mask,is_in_boxes_and_center=self.get_in_boxes_info(
            gt_bboxes_per_image,
            expanded_strides,
            x_shifts,
            y_shifts,
            total_num_anchors,
            num_gt,
        )
        bboxes_preds_per_image=bboxes_preds_per_image[fg_mask]
        cls_preds=cls_preds[batch_idx][fg_mask]
        obj_preds=cls_preds[batch_idx][fg_mask]
        num_in_boxes_anchor=bboxes_preds_per_image.shape[0]

        pair_wise_ious=self.bboxes_iou(gt_bboxes_per_image,bboxes_preds_per_image,False)
        # 获得每个苗匡,针对于每个真实框的iou

        gt_cls_per_image=(
            F.one_hot(gt_classes.to(torch.int64),self.num_classes)
            .float()
            .unsqueeze(1)
            .repeat(1,num_in_boxes_anchor,1)
        )
        # 针对真实框,将每个真实框对应的类别,分类成为onehot编码,便于后面计算置信度
        # 同时再加一维然后复制(苗匡的数量)次,让他变成,每个真实框对应每个苗匡的形式

        pair_wise_ious_loss=-torch.log(pair_wise_ious+1e-8)
        with torch.cuda.amp.autocast(enabled=False):
            cls_preds_ = (
                    cls_preds.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                    * obj_preds.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            )
            # 对于苗匡,这个苗匡是该类物体的概率,等于苗匡物理的概率乘上他是物体的概率
            pair_wise_cls_loss=F.binary_cross_entropy(
                cls_preds_.sqrt_(),gt_cls_per_image,reduction='none'
            ).sum(-1)
            # sum是把所有类别的交叉熵都统计上来,因为前面已经转成onehot类型了
        del cls_preds_
        cost=(
            pair_wise_cls_loss
            +3.0*pair_wise_ious_loss
            +100000.0*(~is_in_boxes_and_center)
        )
        # 这里乘上一个100000是因为,如果某个苗匡根本就没有被选入,就直接取消掉

        (
            num_fg,
            gt_matched_classes,
            pred_ious_this_matching,
            matched_gt_inds,

        )=self.dynamic_k_matching(cost,pair_wise_ious,gt_classes,num_gt,fg_mask)
        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss




    def get_in_boxes_info(
        self,
        gt_bboxes_per_image,
        expended_strides,
        x_shifts,
        y_shifts,
        total_num_anchors,
        num_gt
    ):
        expanded_strides_per_image=expended_strides[0]
        x_shifts_per_image=x_shifts[0]*expanded_strides_per_image
        y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
        # 将图片先变成该有的原样子
        x_centers_per_image=(
            (x_shifts_per_image+0.5*expanded_strides_per_image)
            .unsqueeze(0)
            .repeat(num_gt,1)
        )
        # unsqueeze表示在0的位置添加上1维
        # repeat沿着第一维复制数据numgt次
        y_centers_per_image = (
            (y_shifts_per_image + 0.5 * expanded_strides_per_image)
                .unsqueeze(0)
                .repeat(num_gt, 1)
        )

        # 以下代码为计算真实框的左上角和右下角的位置数据
        gt_bboxes_per_image_l = (
            (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
                .unsqueeze(1)
                .repeat(1, total_num_anchors)
        )
        gt_bboxes_per_image_r = (
            (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
                .unsqueeze(1)
                .repeat(1, total_num_anchors)
        )
        gt_bboxes_per_image_t = (
            (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
                .unsqueeze(1)
                .repeat(1, total_num_anchors)
        )
        gt_bboxes_per_image_b = (
            (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
                .unsqueeze(1)
                .repeat(1, total_num_anchors)
        )
        """
        上面的数据处理,是让候选框和真实框,拥有相同维度的表示方式,便于后面计算差值
        第一维为真实框的数量,第二维是候选框的数量,
        如此一来,每一个候选框对应于所有的候选框
        """
        # 相减计算差值,判断一下候选框是否在这个
        b_l = x_centers_per_image - gt_bboxes_per_image_l
        b_r = gt_bboxes_per_image_r - x_centers_per_image
        b_t = y_centers_per_image - gt_bboxes_per_image_t
        b_b = gt_bboxes_per_image_b - y_centers_per_image
        bbox_deltas=torch.stack([b_l,b_r,b_t,b_b],2)

        # 只有同时满足这四个值都大于零,才能说明,候选框的中心在真实框里面

        is_in_boxes=bbox_deltas.min(dim=-1).values>0.0
        # 查看对于某个候选框,他是否被真实框选中了,要给予正标记
        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
        center_radius = 2.5
        """
        # 在以真实框中心点为中心,向外扩展center_radius的正方形
        # 里面的苗匡(真实框里面的苗匡)
        """

        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) - center_radius * expanded_strides_per_image.unsqueeze(0)

        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) + center_radius * expanded_strides_per_image.unsqueeze(0)

        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) - center_radius * expanded_strides_per_image.unsqueeze(0)

        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
            1, total_num_anchors
        ) + center_radius * expanded_strides_per_image.unsqueeze(0)

        c_l = x_centers_per_image - gt_bboxes_per_image_l
        c_r = gt_bboxes_per_image_r - x_centers_per_image
        c_t = y_centers_per_image - gt_bboxes_per_image_t
        c_b = gt_bboxes_per_image_b - y_centers_per_image
        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
        is_in_centers = center_deltas.min(dim=-1).values > 0.0
        is_in_centers_all = is_in_centers.sum(dim=0) > 0

        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all

        is_in_boxes_and_center = (
                is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
        )
        # 得到的结果是被选中的苗匡,和同时被两种选法选中的苗匡
        return is_in_boxes_anchor, is_in_boxes_and_center


    def dynamic_k_matching(
        self,
        cost,
        pair_wise_ious,
        gt_classes,
        num_gt,
        fg_mask):
        """
        注意,这里面所有的数据,都是一个苗匡对应着一个真实框的
        :param cost:
        :param pair_wise_ious:
        :param gt_classes:
        :param num_gt:
        :param fg_mask:
        :return:
        """
        matching_matrix=torch.zeros_like(cost)
        # 给每个真实框,对应十个候选框。相当于十个候选框对应于这个真实框
        # 第一维是真实框,第二维是候选框
        ious_in_boxes_matrix=pair_wise_ious
        n_candidate_k=min(10,ious_in_boxes_matrix.size(1))

        topk_ious,_=torch.topk(ious_in_boxes_matrix,n_candidate_k,dim=1)
        # 返回沿指定维度dim的,最大的k个值.返回结果为元组,第一个为value,第二个为indics

        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
        # 这里相当于将其进行数据压缩,即,让数据最后范围是在min和max之间
        # 取出有多少个,这里的k根据前面的取出来预选的iou最大的10个,所有iou加起来的值
        # 再对iou值进行转化

        for gt_idx in range(num_gt):
            _,pos_idx=torch.topk(
                cost[gt_idx],k=dynamic_ks[gt_idx].item(),largest=False
            )
            matching_matrix[gt_idx][pos_idx]=1.0
            # 选取cost最小的k个,进行正标记
        del topk_ious,dynamic_ks,pos_idx
        anchor_matching_gt=matching_matrix.sum(0)
        # 处理共用的情况,即,获取每个苗匡,被多少个真实框给标记了
        if(anchor_matching_gt>1).sum()>0:
            # 有共存的情况,即,每个苗匡,被超过1个真实框给标记了
            _,cost_argmin=torch.min(cost[:,anchor_matching_gt>1],dim=0)
            # 获取前面已经被超过一个真实框标记的每个苗匡,
            # 所对应的所有真实框里面,cost值最小的那个真实框的索引
            matching_matrix[:,anchor_matching_gt>1]*=0.0
            matching_matrix[cost_argmin,anchor_matching_gt>1]=1.0
        fg_mask_inboxes=matching_matrix.sum(0)>0.0
        # 记录被分配为正样本的苗匡
        num_fg=fg_mask_inboxes.sum().item()
        # 求出有多少个苗匡被分配为了正样本
        fg_mask[fg_mask.clone()]=fg_mask_inboxes
        # 对于一开始就被分配为正样本的苗匡,
        # 再根据后来的筛选判断现在还是不是正样本
        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        # 找到,所有被标记为正样本的苗匡的,所对应的真实框的索引

        gt_matched_classes = gt_classes[matched_gt_inds]
        # 由于苗匡的数量大于真实框的数量,顾产生的新数组
        # 相当于将原来的真实框的类别,按照苗匡的索引顺序,分发给苗匡。
        # 这些数据对应的基础部分都是一维
        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
            fg_mask_inboxes
        ]
        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds



  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值