yolov4_u5版复现—5. compute_loss

  1. build_targets 选取正样本,扩充正样本
def build_targets(pred, targets, model):
    """Match targets to anchors on every detection layer and expand the positive samples.

    For each detection layer a (target, anchor) pair becomes a positive sample when
    ``max(w_t / w_a, w_a / w_t, h_t / h_a, h_a / h_t) < model.hyp['anchor_t']``.
    Each positive is then also assigned to up to two neighbouring grid cells:

    - the left/upper neighbour when the target centre lies in the left/upper half of
      its cell (offset_threshold = 0.5);
    - the right/lower neighbour when it lies in the right/lower half.

    Cells outside the grid are skipped. Example: a centre at grid (59.1, 10.4) lies in
    cell (59, 10). It is also assigned to (58, 10) and (59, 9), with regression targets
    (0.1, 0.4), (1.1, 0.4) and (0.1, 1.4) respectively. The box itself is unchanged;
    only the reference cell moves.

    Args:
        pred: list of per-layer predictions. Only ``pred[i].shape`` is read; dims 2 and 3
            are used as the grid height and width of layer ``i``.
        targets: tensor ``[number_targets, 6]`` holding
            ``(image_index_in_batch, class, x, y, w, h)`` for every target of the batch,
            with x, y, w, h normalized to [0, 1]. It is not modified.
        model: the model, optionally wrapped in DataParallel/DistributedDataParallel.
            Its last module is used as the detection layer, and
            ``model.hyp['anchor_t']`` is read. The detection layer's anchors are compared
            against grid-scaled target sizes, so they are assumed to be in grid units of
            each layer — TODO confirm.

    Returns:
        ``(targets_class, targets_box, indices, anchors)``, each a list with one entry
        per detection layer:

        - ``targets_class[i]``: long tensor ``[n]`` of class ids.
        - ``targets_box[i]``: tensor ``[n, 4]`` of ``(x - grid_x, y - grid_y, w, h)``
          in grid units.
        - ``indices[i]``: tuple ``(image, anchor, grid_j, grid_i)`` of long tensors ``[n]``.
        - ``anchors[i]``: tensor ``[n, 2]``, the anchor (w, h) matched to each positive.
    """
    parallel_types = (torch.nn.parallel.DataParallel, torch.nn.parallel.DistributedDataParallel)
    # The detection layer is the last module of the network (unwrapped first when (D)DP-wrapped).
    detect_layer = model.module.model[-1] if isinstance(model, parallel_types) else model.model[-1]
    number_anchor_per_pixel, number_targets = detect_layer.number_anchor_per_pixel, targets.shape[0]
    device = targets.device

    targets_class, targets_box, indices, anchors = [], [], [], []
    # Per-layer gain that maps normalized (x, y, w, h) onto the layer's grid.
    gain = torch.ones(6, device=device)
    # Unit steps towards the left, upper, right and lower neighbouring cells.
    offset_direction = torch.tensor(((-1, 0), (0, -1), (1, 0), (0, 1)), device=device).float()
    # anchor_index[a, t] == a: pairs every anchor of a cell with every target.
    # It lives on targets.device so it can be indexed by masks computed from targets.
    anchor_index = torch.arange(number_anchor_per_pixel, device=device) \
        .reshape(number_anchor_per_pixel, 1).repeat(1, number_targets)

    offset_threshold = 0.5
    for i in range(detect_layer.number_detection_layer):
        anchor = detect_layer.anchors[i]  # [number_anchor_per_pixel, 2]
        grid_h, grid_w = int(pred[i].shape[2]), int(pred[i].shape[3])
        gain[2:] = torch.tensor(pred[i].shape)[[3, 2, 3, 2]]  # (w, h, w, h) of the grid
        # Work on per-layer copies: every layer must start from the caller's normalized targets.
        layer_targets, offsets = targets * gain, 0

        if number_targets:
            # ratio_wh [number_anchor_per_pixel, number_targets, (ratio_w, ratio_h)]
            ratio_wh = layer_targets[None, :, 4:] / anchor[:, None]
            # positive [number_anchor_per_pixel, number_targets], dtype bool
            positive = torch.max(ratio_wh, 1. / ratio_wh).max(2)[0] < model.hyp['anchor_t']
            layer_anchor_index = anchor_index[positive]
            layer_targets = layer_targets.repeat(number_anchor_per_pixel, 1, 1)[positive]

            # Expand each positive to neighbouring cells (the target box itself is unchanged).
            grid_xy = layer_targets[:, 2:4]
            left, up = ((grid_xy % 1. < offset_threshold) & (grid_xy > 1.)).T
            right, down = ((grid_xy % 1. > (1 - offset_threshold)) & (grid_xy < (gain[[2, 3]] - 1.))).T
            masks = (left, up, right, down)

            layer_anchor_index = torch.cat([layer_anchor_index] + [layer_anchor_index[m] for m in masks], 0)
            layer_targets = torch.cat([layer_targets] + [layer_targets[m] for m in masks], 0)
            original_position = torch.zeros_like(grid_xy)
            offsets = torch.cat(
                [original_position]
                + [original_position[m] + offset_direction[k] for k, m in enumerate(masks)],
                0) * offset_threshold
        else:
            layer_anchor_index = anchor_index.reshape(-1)  # empty long tensor [0]

        grid_xy = layer_targets[:, 2:4]
        grid_wh = layer_targets[:, 4:6]
        grid_ij = (grid_xy + offsets).long()  # cell each positive is assigned to
        # Keep cells inside the grid (a centre exactly on the right/bottom border maps to w or h).
        grid_ij[:, 0].clamp_(0, grid_w - 1)
        grid_ij[:, 1].clamp_(0, grid_h - 1)
        grid_i, grid_j = grid_ij.T  # x (column) and y (row) indices

        targets_class.append(layer_targets[:, 1].long())
        targets_box.append(torch.cat((grid_xy - grid_ij, grid_wh), 1))
        indices.append((layer_targets[:, 0].long(), layer_anchor_index, grid_j, grid_i))
        anchors.append(anchor[layer_anchor_index])
    return targets_class, targets_box, indices, anchors
  2. compute_loss 计算损失
def compute_loss(pred, targets, model):
    """Compute the detection loss: box regression + objectness + classification.

    Args:
        pred: list with one raw prediction tensor per detection layer, indexed as
            ``pred_i[image, anchor, grid_y, grid_x, channel]``. Channels 0-3 are box
            terms, channel 4 is objectness and channels 5+ are class logits.
        targets: ground truth of shape ``[number_targets, 6]``, rows
            ``(image_index, class, x, y, w, h)``, forwarded to ``build_targets``.
        model: provides ``hyp`` ('cls_pw', 'obj_pw', 'giou', 'cls', 'obj', plus whatever
            ``build_targets`` reads), ``number_class`` and ``iou_ratio``.

    Returns:
        Tuple ``(loss * batch_size, loss_items)``. ``loss_items`` is the detached tensor
        ``[loss_box, loss_obj, loss_cls, loss]``. Box and class losses are averaged over
        positive samples; the objectness loss is averaged over all predictions.
    """
    device = targets.device
    # Single-element accumulators on the targets' device.
    loss_box = torch.zeros(1, device=device)
    loss_cls = torch.zeros(1, device=device)
    loss_obj = torch.zeros(1, device=device)

    # Match anchors to targets and create (expanded) positive samples for every layer.
    # anchors[i] is expected to hold the anchor (w, h) matched to each positive of layer i.
    targets_class, targets_box, indices, anchors = build_targets(pred, targets, model)

    # pos_weight is a single-element tensor, which BCEWithLogitsLoss broadcasts over the
    # class dimension: every class shares the same positive weight.
    reduction = 'mean'
    BCE_cls = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([model.hyp['cls_pw']], dtype=torch.float32, device=device),
        reduction=reduction)
    BCE_obj = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([model.hyp['obj_pw']], dtype=torch.float32, device=device),
        reduction=reduction)

    # Label smoothing; presumably (positive, negative) class targets — see smooth_BCE.
    class_positive, class_negative = smooth_BCE(eps=0.0)

    # Per-layer weights applied to the objectness loss.
    num_layers = len(pred)
    balance = [4.0, 1.0, 0.4] if num_layers == 3 else [4.0, 1.0, 0.4, 0.1]
    for i, pred_i in enumerate(pred):  # layer index, layer predictions
        img_index, anchor_index, grid_y, grid_x = indices[i]
        number_targets = img_index.shape[0]

        targets_obj = torch.zeros_like(pred_i[..., 4]).to(device)

        if number_targets:
            # Predictions at the positive (image, anchor, cell) locations.
            pred_positive = pred_i[img_index, anchor_index, grid_y, grid_x]
            # Box regression. xy lies in (-0.5, 1.5), enough to reach centres of targets
            # assigned from a neighbouring cell; wh lies in (0, 4) x the matched anchor.
            pred_xy = pred_positive[:, 0:2].sigmoid() * 2 - 0.5
            pred_wh = (pred_positive[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pred_box = torch.cat((pred_xy, pred_wh), 1).to(device)
            iou = bbox_iou(pred_box.t(), targets_box[i], x1y1x2y2=False, iou_type='MIOU-C')
            loss_box += (1.0 - iou).sum() if reduction == 'sum' else (1.0 - iou).mean()

            # Classification loss (only when there are multiple classes).
            if model.number_class > 1:
                targets_cls = torch.full_like(pred_positive[:, 5:], class_negative).to(device)
                targets_cls[range(number_targets), targets_class[i]] = class_positive
                loss_cls += BCE_cls(pred_positive[:, 5:], targets_cls)

            # Objectness target at positives: blend of 1 and the detached, clamped IoU.
            targets_obj[img_index, anchor_index, grid_y, grid_x] = \
                (1.0 - model.iou_ratio) + model.iou_ratio * iou.detach().clamp(0).type(targets_obj.dtype)
        loss_obj += BCE_obj(pred_i[..., 4], targets_obj) * balance[i]

    s = 3 / num_layers  # output count scaling
    batch_size = pred[0].shape[0]
    loss_box *= model.hyp['giou'] * s
    loss_cls *= model.hyp['cls'] * s
    # Extra objectness weight for 4-output models.
    loss_obj *= model.hyp['obj'] * s * (1.4 if num_layers == 4 else 1.)
    loss = loss_box + loss_cls + loss_obj
    return loss * batch_size, torch.cat((loss_box, loss_obj, loss_cls, loss)).detach()
