2023-简单点-yolox-pytorch代码解析(三)-nets/yolo_training.py

仓库

https://github.com/bubbliiiing/yolox-pytorch
仓库

yolox网络结构

这里是引用

yolox-pytorch目录

在这里插入图片描述

在这里插入图片描述

今天详细注释yolo_training.py

import math
from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

def is_parallel(model):  
    # 如果模型是DP或DDP类型,则返回True  
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)  
  
def de_parallel(model):  
    # 去并行化一个模型:如果模型是DP或DDP类型,则返回单GPU模型  
    return model.module if is_parallel(model) else model  
  
def copy_attr(a, b, include=(), exclude=()):  
    # 从b复制属性到a,可以选择只包含[...]和排除[...]  
    for k, v in b.__dict__.items():  
        # 如果属性名在include中并且不在exclude中,并且不是私有属性(不以'_'开头),则进行复制  
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:  
            continue  
        else:  
            setattr(a, k, v)

#############################################################################
class IOUloss(nn.Module):  # 定义一个名为IOUloss的类,它继承自nn.Module。  
    def __init__(self, reduction="none", loss_type="iou"):  # 初始化函数,设置默认参数。  
        super(IOUloss, self).__init__()  # 调用父类的初始化方法。  
        self.reduction = reduction  # 设置要应用的损失减少方法。  
        self.loss_type = loss_type  # 设置要使用的损失类型。  
  
    def forward(self, pred, target):  # 定义前向传播函数,接收预测和目标作为输入。  
        assert pred.shape[0] == target.shape[0]  # 确保预测和目标的批次大小相同。  
  
        pred = pred.view(-1, 4)  # 将预测张量重塑为两个维度,其中最后一个维度代表着边界框的坐标。  
        target = target.view(-1, 4)  # 将目标张量重塑为相同的形状。  
  
        # 计算预测边界框和目标边界框的交集区域的左上角和右下角坐标。  
        tl = torch.max(  
            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)  
        )  
        br = torch.min(  
            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)  
        )  
  
        area_p = torch.prod(pred[:, 2:], 1)  # 计算预测边界框的面积。  
        area_g = torch.prod(target[:, 2:], 1)  # 计算目标边界框的面积。  
  
        en = (tl < br).type(tl.type()).prod(dim=1)  # 计算交集区域是否存在。  
        area_i = torch.prod(br - tl, 1) * en  # 计算交集区域的面积。  
        area_u = area_p + area_g - area_i  # 计算预测边界框和目标边界框的并集区域的面积。  
        iou = (area_i) / (area_u + 1e-16)  # 计算IoU值。  
  
        if self.loss_type == "iou":  # 如果选择的损失类型是IoU。  
            loss = 1 - iou ** 2  # 计算IoU损失。  
        elif self.loss_type == "giou":  # 如果选择的损失类型是GIoU。  
            c_tl = torch.min(  
                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)  
            )  
            c_br = torch.max(  
                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)  
            )  
            area_c = torch.prod(c_br - c_tl, 1)  # 计算包含预测和目标边界框的最小边界框的面积。  
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)  # 计算GIoU值。  
            loss = 1 - giou.clamp(min=-1.0, max=1.0)  # 计算GIoU损失。  
  
        # 根据所选的减少方法应用损失减少。  
        if self.reduction == "mean":  
            loss = loss.mean()  # 取平均值。  
        elif self.reduction == "sum":  
            loss = loss.sum()  # 求和。  
  
        return loss  # 返回计算得到的损失值。
####################################################################################

class YOLOLoss(nn.Module):  
    """  
    YOLO损失函数类。  
    该类继承了PyTorch的nn.Module,用于计算YOLO模型的损失。  
    """  
      
    def __init__(self, num_classes, fp16, strides=[8, 16, 32]):  
        """  
        初始化函数。  
          
        参数:  
        - num_classes: 类别数量。  
        - fp16: 是否使用16位浮点数。  
        - strides: YOLO模型中的步长列表,默认为[8, 16, 32]。  
        """  
        super().__init__()  # 调用父类的初始化方法。  
        self.num_classes = num_classes  # 存储类别数量。  
        self.strides = strides  # 存储步长列表。  
  
        # 定义损失函数。  
        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")  # 二分类交叉熵损失,不使用任何减少。  
        self.iou_loss = IOUloss(reduction="none")  # 交并比损失,不使用任何减少。  
        self.grids = [torch.zeros(1)] * len(strides)  # 根据步长数量创建网格列表,初始值均为0。  
        self.fp16 = fp16  # 是否使用16位浮点数。  
  
    def forward(self, inputs, labels=None):  
        """  
        前向传播函数。  
          
        参数:  
        - inputs: YOLO模型的输出,形状如注释所示。  
        - labels: 真实标签,可选参数。  
          
        返回:  
        - losses: 损失值。  
        """  
        outputs = []  # 存储处理后的模型输出。  
        x_shifts = []  # 存储x方向的偏移量。  
        y_shifts = []  # 存储y方向的偏移量。  
        expanded_strides = []  # 存储扩展的步长。  
  
        # 遍历每个尺度的输出和步长。  
        for k, (stride, output) in enumerate(zip(self.strides, inputs)):  
            # 获取当前尺度的输出和网格。  
            output, grid = self.get_output_and_grid(output, k, stride)  
            x_shifts.append(grid[:, :, 0])  # 添加x方向偏移量到列表。  
            y_shifts.append(grid[:, :, 1])  # 添加y方向偏移量到列表。  
            expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride)  # 创建与网格同形状的步长张量并添加到列表。  
            outputs.append(output)  # 添加处理后的输出到列表。  
  
        # 获取损失值。  
        return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1))

	def get_output_and_grid(self, output, k, stride):  
	    """  
	    从YOLO模型的输出中获取处理后的输出和网格。  
	      
	    参数:  
	    - output: YOLO模型在某个尺度的输出。  
	    - k: 当前尺度的索引。  
	    - stride: 当前尺度的步长。  
	      
	    返回:  
	    - output: 处理后的输出。  
	    - grid: 网格。  
	    """  
	    grid = self.grids[k]  # 获取当前尺度的网格。  
	    hsize, wsize = output.shape[-2:]  # 获取输出的高和宽。  
	  
	    # 如果网格的尺寸与输出的尺寸不匹配,则重新计算网格。  
	    if grid.shape[2:4] != output.shape[2:4]:  
	        yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])  # 创建坐标网格。  
	        grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type())  # 将坐标网格转换为与输出同类型的张量。  
	        self.grids[k] = grid  # 更新网格。  
	    grid = grid.view(1, -1, 2)  # 将网格展平。  
	  
	    # 处理输出。  
	    output = output.flatten(start_dim=2).permute(0, 2, 1)  # 将输出的形状变换为[batch_size, num_anchors, num_classes + 5]。  
	    output[..., :2] = (output[..., :2] + grid.type_as(output)) * stride  # 更新输出的x和y坐标。  
	    output[..., 2:4] = torch.exp(output[..., 2:4]) * stride  # 对输出的宽和高进行指数变换并更新。  
	    
	    return output, grid  # 返回处理后的输出和网格。
	

	def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs):  
	    # 从输出中获取边界框预测  
	    bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]  
	  
	    # 从输出中获取目标预测  
	    obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]  
	  
	    # 从输出中获取类别预测  
	    cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]  
	  
	    # 获取总的锚点数量  
	    total_num_anchors = outputs.shape[1]  
	  
	    # 对x_shifts、y_shifts和expanded_strides进行拼接和类型转换  
	    x_shifts = torch.cat(x_shifts, 1).type_as(outputs)  # [1, n_anchors_all]  
	    y_shifts = torch.cat(y_shifts, 1).type_as(outputs)  # [1, n_anchors_all]  
	    expanded_strides = torch.cat(expanded_strides, 1).type_as(outputs)  # [1, n_anchors_all]  
	  
	    # 初始化列表,用于存储每个批次的目标、回归目标、目标掩码和前景掩码  
	    cls_targets = []  
	    reg_targets = []  
	    obj_targets = []  
	    fg_masks = []  
	  
	    # 初始化前景锚点的数量  
	    num_fg = 0.0  
	  
	    # 对每个批次进行处理  
	    for batch_idx in range(outputs.shape[0]):  
	        # 获取当前批次的真实标签数量  
	        num_gt = len(labels[batch_idx])  
	          
	        # 如果当前批次没有真实标签,则为其分配零张量  
	        if num_gt == 0:  
	            # 初始化类别目标为零张量,形状为[0, self.num_classes]  
	            cls_target = outputs.new_zeros((0, self.num_classes))  
	              
	            # 初始化回归目标为零张量,形状为[0, 4]  
	            reg_target = outputs.new_zeros((0, 4))  
	              
	            # 初始化目标预测为零张量,形状为[total_num_anchors, 1]  
	            obj_target = outputs.new_zeros((total_num_anchors, 1))  
	              
	            # 初始化前景掩码为零张量,形状为total_num_anchors,类型为bool  
	            fg_mask = outputs.new_zeros(total_num_anchors).bool()
	        else:  
			    # 如果当前批次存在真实标签  
			    # 获取当前批次的真实边界框和类别  
			    gt_bboxes_per_image = labels[batch_idx][..., :4].type_as(outputs)  # [num_gt, 4]  
			    gt_classes = labels[batch_idx][..., 4].type_as(outputs)  # [num_gt]  
			  
			    # 获取当前批次的预测结果  
			    bboxes_preds_per_image = bbox_preds[batch_idx]  # [n_anchors_all, 4]  
			    cls_preds_per_image = cls_preds[batch_idx]  # [n_anchors_all, num_classes]  
			    obj_preds_per_image = obj_preds[batch_idx]  # [n_anchors_all, 1]  
			  
			    # 通过调用get_assignments函数,获取匹配的类别、前景掩码、预测的IoU、匹配的GT索引和前景锚点数量  
			    gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(  
			        num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image,  
			        expanded_strides, x_shifts, y_shifts,   
			    )  
			      
			    # 清空CUDA缓存,释放不再使用的变量占用的显存  
			    torch.cuda.empty_cache()  
			  
			    # 更新前景锚点的数量  
			    num_fg += num_fg_img  
			  
			    # 根据匹配的类别生成类别目标,使用one-hot编码,并乘以预测的IoU  
			    cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)  
			      
			    # 生成目标预测的目标,直接使用前景掩码  
			    obj_target = fg_mask.unsqueeze(-1)  
			      
			    # 根据匹配的GT索引生成回归目标  
			    reg_target = gt_bboxes_per_image[matched_gt_inds]  
			  
			    # 将当前批次的目标添加到列表中  
			    cls_targets.append(cls_target)  
			    reg_targets.append(reg_target)  
			    obj_targets.append(obj_target.type(cls_target.type()))  
			    fg_masks.append(fg_mask)  
			  
			# 将所有批次的目标拼接起来  
			cls_targets = torch.cat(cls_targets, 0)  
			reg_targets = torch.cat(reg_targets, 0)  
			obj_targets = torch.cat(obj_targets, 0)  
			fg_masks = torch.cat(fg_masks, 0)  
			  
			# 确保前景锚点的数量至少为1  
			num_fg = max(num_fg, 1)
			# 计算IoU损失,即预测的边界框与真实边界框之间的差距  
			loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()  
			  
			# 计算目标预测损失,即模型预测的前景/背景与真实前景/背景之间的差距  
			loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()  
			  
			# 计算类别损失,即模型预测的类别与真实类别之间的差距,只在前景锚点上计算  
			loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()  
			  
			# 设置回归损失的权重,这是为了平衡不同损失之间的相对重要性  
			reg_weight = 5.0  
			  
			# 计算总损失,它是回归损失、目标预测损失和类别损失的加权和  
			loss = reg_weight * loss_iou + loss_obj + loss_cls  
			  
			# 返回平均损失,这是通过总损失除以前景锚点的数量得到的  
			return loss / num_fg
	
			
	@torch.no_grad()  # 指示下面的函数在执行时不计算梯度,通常用于推理阶段  
	def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):  
	    # 获取前景掩码和每个真实边界框是否在预测的边界框内以及中心是否在预测边界框的中心区域的信息  
	    fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)  
	  
	    # 通过前景掩码筛选预测的边界框、类别和对象性得分,只保留前景区域的预测结果  
	    bboxes_preds_per_image  = bboxes_preds_per_image[fg_mask]  
	    cls_preds_              = cls_preds_per_image[fg_mask]  
	    obj_preds_              = obj_preds_per_image[fg_mask]  
	    num_in_boxes_anchor     = bboxes_preds_per_image.shape[0]  
	  
	    # 计算每个真实边界框与所有预测边界框之间的交并比(IoU)  
	    pair_wise_ious      = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)  
	    # 计算交并比的对数损失  
	    pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)  
	      
	    # 扩展预测的类别和真实的类别的维度,方便后续计算类别损失  
	    # [num_gt, fg_mask, num_classes]
		# 检查是否使用了混合精度训练(fp16)  
		if self.fp16:  
		    # 如果使用了混合精度训练,则禁用自动类型转换  
		    with torch.cuda.amp.autocast(enabled=False):  
		        # 将类别预测和目标性预测转换为float类型,并调整它们的形状以匹配真实标签的形状  
		        # 使用sigmoid激活函数,并与目标性预测相乘  
		        cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()  
		          
		        # 将真实类别标签转换为one-hot编码形式,并调整其形状以匹配预测标签的形状  
		        gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)  
		          
		        # 计算类别预测的二元交叉熵损失,不使用任何形式的损失减少(reduction="none")  
		        pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)  
		# 如果没有使用混合精度训练  
		else:  
		    # 执行与上面相同的操作,但不需要禁用自动类型转换  
		    cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()  
		    gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)  
		    pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)  
		    del cls_preds_  # 删除不再需要的变量以释放内存  
		  
		# 计算总体损失,包括类别损失、交并比损失和中心区域损失  
		cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()  
		  
		# 使用动态K匹配算法为预测的边界框分配真实标签,并返回匹配结果和相关信息  
		num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)  
		  
		# 删除不再需要的变量以释放内存  
		del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss  
		  
		# 返回匹配结果和相关信息  
		return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg
	
	def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):  
	    # 检查边界框是否具有正确的形状  
	    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:  
	        raise IndexError  
	  
	    # 如果边界框是以xyxy格式给出的  
	    if xyxy:  
	        # 计算两组边界框的左上角和右下角坐标  
	        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])  # 左上角坐标  
	        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])  # 右下角坐标  
	          
	        # 计算第一组和第二组边界框的面积  
	        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)  # A组面积  
	        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)  # B组面积  
	    else:  
	        # 如果边界框是以中心点和宽高格式给出的  
	        tl = torch.max(  
	            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),  # A组左上角坐标  
	            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),              # B组左上角坐标  
	        )  
	        br = torch.min(  
	            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),  # A组右下角坐标  
	            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),              # B组右下角坐标  
	        )  
	  
	        # 计算第一组和第二组边界框的面积  
	        area_a = torch.prod(bboxes_a[:, 2:], 1)  # A组面积  
	        area_b = torch.prod(bboxes_b[:, 2:], 1)  # B组面积  
	  
	    # 判断边界框是否有交集,并计算交集区域的面积  
	    en = (tl < br).type(tl.type()).prod(dim=2)  # 判断是否有交集的掩码  
	    area_i = torch.prod(br - tl, 2) * en  # 交集区域的面积  
	      
	    # 计算并返回交并比(IoU)  
	    return area_i / (area_a[:, None] + area_b - area_i)  # IoU值




	def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5):  
	    # 获取每张图像上扩展的步长  
	    expanded_strides_per_image  = expanded_strides[0]  
	      
	    # 根据x_shifts和y_shifts计算每个预测框的中心点坐标  
	    x_centers_per_image         = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)  
	    y_centers_per_image         = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)  
	  
	    # 根据ground truth的边界框计算其左上角和右下角的坐标  
	    gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)  
	    gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)  
	    gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)  
	    gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)  
	  
	    # 计算预测框中心点与ground truth框的左上角和右下角之间的偏移量  
	    b_l = x_centers_per_image - gt_bboxes_per_image_l  
	    b_r = gt_bboxes_per_image_r - x_centers_per_image  
	    b_t = y_centers_per_image - gt_bboxes_per_image_t  
	    b_b = gt_bboxes_per_image_b - y_centers_per_image  
	      
	    # 将这四个偏移量堆叠起来,形成一个[num_gt, n_anchors_all, 4]的张量,方便后续处理  
	    bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
	    # 判断预测框是否在真实框内,得到一个[num_gt, n_anchors_all]的布尔张量  
		is_in_boxes     = bbox_deltas.min(dim=-1).values > 0.0  
		  
		# 判断所有预测框中是否有任何一个在真实框内,得到一个[n_anchors_all]的布尔张量  
		is_in_boxes_all = is_in_boxes.sum(dim=0) > 0  
		  
		# 根据中心半径和扩展的步长,调整真实框的左边、右边、上边和下边界  
		gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)  
		gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)  
		gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)  
		gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)  
		  
		# 计算预测框中心点与调整后的真实框边界之间的差异,得到一个[num_gt, n_anchors_all, 4]的张量  
		c_l = x_centers_per_image - gt_bboxes_per_image_l  
		c_r = gt_bboxes_per_image_r - x_centers_per_image  
		c_t = y_centers_per_image - gt_bboxes_per_image_t  
		c_b = gt_bboxes_per_image_b - y_centers_per_image  
		center_deltas       = torch.stack([c_l, c_t, c_r, c_b], 2)  
		  
		# 判断预测框的中心点是否在调整后的真实框内,得到一个[num_gt, n_anchors_all]的布尔张量  
		is_in_centers       = center_deltas.min(dim=-1).values > 0.0  
		  
		# 判断所有预测框的中心点中是否有任何一个在调整后的真实框内,得到一个[n_anchors_all]的布尔张量  
		is_in_centers_all   = is_in_centers.sum(dim=0) > 0
		def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):  
		    # cost: 损失值矩阵,大小为 [num_gt, fg_mask]  
		    # pair_wise_ious: 真实框与预测框之间的IoU矩阵,大小为 [num_gt, fg_mask]  
		    # gt_classes: 真实框的类别标签,大小为 [num_gt]  
		    # num_gt: 真实框的数量  
		    # fg_mask: 前景掩码,大小为 [n_anchors_all]  
		    # matching_matrix: 匹配矩阵,大小为 [num_gt, fg_mask]  
		    matching_matrix = torch.zeros_like(cost)  
		  
		    # 选取IoU最大的n_candidate_k个点  
		    # 通过求和来判断应该有多少点用于该框预测  
		    # topk_ious: IoU最大的前n_candidate_k个点,大小为 [num_gt, n_candidate_k]  
		    # dynamic_ks: 动态K值,即每个真实框选取的点的数量,大小为 [num_gt]  
		    n_candidate_k = min(10, pair_wise_ious.size(1))  
		    topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)  
		    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)  
		  
		    for gt_idx in range(num_gt):  
		        # 给每个真实框选取最小的动态k个点  
		        # pos_idx: 选取的点的索引  
		        _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)  
		        matching_matrix[gt_idx][pos_idx] = 1.0  
		    del topk_ious, dynamic_ks, pos_idx  # 释放变量,节省内存  
		  
		    # anchor_matching_gt: 每个特征点匹配的真实框数量,大小为 [fg_mask]  
		    anchor_matching_gt = matching_matrix.sum(0)  
		    if (anchor_matching_gt > 1).sum() > 0:  
		        # 当某一个特征点指向多个真实框的时候  
		        # 选取cost最小的真实框。  
		        # 当某个特征点与多个真实框匹配时,选取cost最小的真实框  
				_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)  
				# 将匹配矩阵中对应位置的值设为0  
				matching_matrix[:, anchor_matching_gt > 1] *= 0.0  
				# 将匹配矩阵中对应位置设为1,表示该特征点与cost最小的真实框匹配  
				matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0  
		  
			# fg_mask_inboxes: 在真实框内的特征点掩码,大小为 [fg_mask]  
			# num_fg: 正样本的特征点个数  
			fg_mask_inboxes = matching_matrix.sum(0) > 0.0  
			num_fg = fg_mask_inboxes.sum().item()  
			  
			# 更新fg_mask的值,只保留在真实框内的特征点  
			fg_mask[fg_mask.clone()] = fg_mask_inboxes  
			  
			# 获得与特征点匹配的真实框的索引  
			matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)  
			# 获得与特征点匹配的真实框的物种类别  
			gt_matched_classes = gt_classes[matched_gt_inds]  
			  
			# 计算与特征点匹配的真实框的预测IoU  
			pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]  
			  
			return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
###################################
class ModelEMA:  
    """  
    更新的指数移动平均(EMA)类,从 https://github.com/rwightman/pytorch-image-models 获取。  
    保存模型中state_dict(参数和缓冲区)的移动平均值。  
    EMA的详细信息参见 https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage  
    """  
  
    def __init__(self, model, decay=0.9999, tau=2000, updates=0):  
        """  
        初始化函数。  
        :param model: 要应用EMA的模型。  
        :param decay: EMA的衰减率,默认为0.9999。  
        :param tau: 用于计算衰减率的参数,默认为2000。  
        :param updates: EMA的更新次数,默认为0。  
        """  
        # 创建一个EMA模型的深度拷贝,并设置为评估模式。此处的模型是FP32精度。  
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA  
          
        # 如果模型的下一个参数不在CPU上,则将EMA模型设置为FP16精度。  
        # if next(model.parameters()).device.type != 'cpu':  
        #     self.ema.half()  # FP16 EMA  
          
        # 记录EMA的更新次数。  
        self.updates = updates  # number of EMA updates  
          
        # 定义衰减函数,该函数根据当前的更新次数计算衰减率。  
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)  
          
        # 将EMA模型的所有参数设置为不需要梯度。  
        for p in self.ema.parameters():  
            p.requires_grad_(False)  
  
    def update(self, model):  
        """  
        更新EMA模型的参数。  
        :param model: 用于更新EMA模型的原始模型。  
        """  
        # 在不需要计算梯度的情况下更新EMA参数。  
        with torch.no_grad():  
            self.updates += 1  # 更新次数加1  
            d = self.decay(self.updates)  # 计算当前的衰减率  
  
            # 获取原始模型的state_dict。  
            msd = de_parallel(model).state_dict()  # model state_dict  
              
            # 遍历EMA模型的state_dict,并根据原始模型和当前的衰减率更新每个参数。  
            for k, v in self.ema.state_dict().items():  
                if v.dtype.is_floating_point:  # 如果参数是浮点数类型  
                    v *= d  # 先乘以衰减率  
                    v += (1 - d) * msd[k].detach()  # 然后加上原始模型的对应参数乘以(1 - 衰减率)  
  
    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):  
        """  
        更新EMA模型的属性。  
        :param model: 用于更新EMA属性的原始模型。  
        :param include: 需要更新的属性列表,默认为空。  
        :param exclude: 不需要更新的属性列表,默认为'process_group'和'reducer'。  
        """  
        # 使用copy_attr函数更新EMA模型的属性,根据include和exclude列表确定需要更新的属性。  
        copy_attr(self.ema, model, include, exclude)


# 定义一个函数weights_init,该函数用于初始化网络的权重  
# net: 要初始化的神经网络模型  
# init_type: 初始化类型,可选'normal'、'xavier'、'kaiming'、'orthogonal'  
# init_gain: 初始化的增益值,用于控制权重的初始大小  
def weights_init(net, init_type='normal', init_gain=0.02):  
      
    # 定义内部函数init_func,该函数用于根据传入的初始化类型和增益值来初始化网络中的每个层  
    def init_func(m):  
          
        # 获取当前层的类名  
        classname = m.__class__.__name__  
          
        # 如果当前层有weight属性且类名中包含'Conv',则表示该层是卷积层  
        if hasattr(m, 'weight') and classname.find('Conv') != -1:  
              
            # 根据初始化类型初始化权重  
            if init_type == 'normal':  
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)  # 正态分布初始化  
            elif init_type == 'xavier':  
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)  # Xavier正态分布初始化  
            elif init_type == 'kaiming':  
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')  # Kaiming正态分布初始化  
            elif init_type == 'orthogonal':  
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)  # 正交初始化  
            else:  
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)  # 如果初始化类型不在上述类型中,则抛出异常  
          
        # 如果类名中包含'BatchNorm2d',则表示该层是批量归一化层  
        elif classname.find('BatchNorm2d') != -1:  
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)  # 正态分布初始化,均值为1,标准差为0.02  
            torch.nn.init.constant_(m.bias.data, 0.0)  # 常数初始化,偏置设为0  
      
    # 打印初始化的类型信息  
    print('initialize network with %s type' % init_type)  
      
    # 对网络中的每一层应用初始化函数  
    net.apply(init_func)


# 定义一个函数,用于获取学习率调度器  
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio=0.05, warmup_lr_ratio=0.1, no_aug_iter_ratio=0.05, step_num=10):  
      
    # 定义一个函数,实现YOLOX的warmup cosine学习率策略  
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):  
        if iters <= warmup_total_iters:  
            # 如果当前迭代次数小于等于warmup的总迭代次数,则使用二次函数进行warmup  
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start  
        elif iters >= total_iters - no_aug_iter:  
            # 如果当前迭代次数大于等于总迭代次数减去不进行数据增强的迭代次数,则学习率降为最小学习率  
            lr = min_lr  
        else:  
            # 否则,使用cosine退火策略更新学习率  
            lr = min_lr + 0.5 * (lr - min_lr) * (  
                1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))  
            )  
        return lr  
  
    # 定义一个函数,实现阶梯式学习率策略  
    def step_lr(lr, decay_rate, step_size, iters):  
        if step_size < 1:  
            raise ValueError("step_size必须大于1。")  
        n = iters // step_size  
        out_lr = lr * decay_rate ** n  
        return out_lr  
  
    # 根据传入的学习率衰减类型选择对应的学习率策略  
    if lr_decay_type == "cos":  
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)  
        warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)  
        no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)  
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)  
    else:  
        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))  
        step_size = total_iters / step_num  
        func = partial(step_lr, lr, decay_rate, step_size)  
  
    return func


# 定义一个函数,该函数用于设置优化器的学习率  
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):  
    # 通过传入的lr_scheduler_func函数和当前的epoch计算出一个新的学习率  
    lr = lr_scheduler_func(epoch)  
      
    # 遍历优化器的所有参数组  
    for param_group in optimizer.param_groups:  
        # 将每个参数组的学习率设置为上面计算出的新学习率  
        param_group['lr'] = lr

tips

使用IoU的平方(IoU^2)作为损失函数,

而不是直接使用IoU,主要是出于稳定性和优化考虑。具体原因如下:

  • 梯度稳定性:当IoU值很低时,直接使用IoU作为损失可能会导致梯度不稳定。因为IoU的值范围是[0, 1],当IoU接近0时,其梯度也可能非常大,这可能导致训练过程中的不稳定。通过对IoU取平方,可以降低这种不稳定性。
  • 优化方便:平方损失函数(例如均方误差MSE)在优化时具有一些便利性,例如它具有明确的导数表达式,便于梯度下降等优化算法的计算。此外,平方损失函数通常也更容易收敛。
    因此,使用1 - iou ** 2作为损失函数,而不是1 - iou,可以在一定程度上提高训练的稳定性和效率。
  • 13
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

万物琴弦光锥之外

给个0.1,恭喜老板发财

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值