一、loss定义
位于 engine.py 文件的 train_before_loop() 函数中：
# Excerpt from engine.py, train_before_loop(): create the loss object used by the trainer.
# The class count comes from the dataset dict; DFL usage, reg_max and the IoU variant come
# from the model-head config. All other ComputeLoss arguments keep their defaults.
# NOTE(review): warmup_epoch is not passed, so it defaults to 0 and the ATSS warmup
# assigner branch (epoch_num < warmup_epoch) is never taken for non-negative epochs.
from yolov6.models.loss import ComputeLoss
self.compute_loss = ComputeLoss(num_classes=self.data_dict['nc'],
ori_img_size=self.img_size,
use_dfl=self.cfg.model.head.use_dfl,
reg_max=self.cfg.model.head.reg_max,
iou_type=self.cfg.model.head.iou_type)
二、loss调用
# Forward pass: the model returns the head outputs plus a second value (s_featmaps),
# which this excerpt does not use.
preds, s_featmaps = self.model(images)
# preds is the (feats, pred_scores, pred_distri) tuple consumed by ComputeLoss.
# Shapes noted while debugging (presumably a 640x640 input; channel counts depend on the model):
#   preds[0] (feats):       [bs,64,80,80], [bs,128,40,40], [bs,256,20,20]
#   preds[1] (pred_scores): [bs, 8400, nc]
#   preds[2] (pred_distri): [bs, 8400, 4]
#   targets:                [n, 6]
# compute_loss returns the weighted total loss and the detached per-term losses.
total_loss, loss_items = self.compute_loss(preds, targets, epoch_num)
三、ComputeLoss 类（yolov6/models/loss.py）
class ComputeLoss:
"""Loss computation func."""
def __init__(self, fpn_strides=None, grid_cell_size=5.0, grid_cell_offset=0.5, num_classes=80,
             ori_img_size=640, warmup_epoch=0, use_dfl=True, reg_max=16, iou_type='giou',
             loss_weight=None):
    """Configure anchor generation, label assigners and the loss terms.

    Args:
        fpn_strides: Strides of the feature levels, forwarded to ``generate_anchors``.
            ``None`` means ``[8, 16, 32]``; a fresh list is created for every instance.
        grid_cell_size: Anchor cell size forwarded to ``generate_anchors``.
        grid_cell_offset: Anchor cell offset forwarded to ``generate_anchors``.
        num_classes: Number of object classes.
        ori_img_size: Image size; ``__call__`` builds a (1, 4) scale tensor from it and
            passes it to ``preprocess`` (presumably to scale gt boxes to pixels).
        warmup_epoch: While ``epoch_num < warmup_epoch`` the ATSS assigner is used,
            afterwards the task-aligned assigner.
        use_dfl: Forwarded to ``BboxLoss`` (distribution focal loss on/off).
        reg_max: Forwarded to ``BboxLoss``; also sets the size of ``self.proj``.
        iou_type: IoU variant forwarded to ``BboxLoss``.
        loss_weight: Weights for the ``'class'``, ``'iou'`` and ``'dfl'`` terms.
            ``None`` means ``{'class': 1.0, 'iou': 2.5, 'dfl': 0.5}``. Missing keys
            fall back to those defaults.
    """
    # Mutable default arguments are evaluated once and shared by every call. A default
    # dict stored on self would be shared between instances, so build fresh containers.
    self.fpn_strides = [8, 16, 32] if fpn_strides is None else fpn_strides
    self.grid_cell_size = grid_cell_size
    self.grid_cell_offset = grid_cell_offset
    self.num_classes = num_classes
    self.ori_img_size = ori_img_size
    self.warmup_epoch = warmup_epoch
    # ATSS (top-9 candidates) during warmup, task-aligned assignment afterwards.
    self.warmup_assigner = ATSSAssigner(9, num_classes=self.num_classes)
    self.formal_assigner = TaskAlignedAssigner(topk=13, num_classes=self.num_classes, alpha=1.0, beta=6.0)
    self.use_dfl = use_dfl
    self.reg_max = reg_max
    # Values 0, 1, ..., reg_max (reg_max + 1 bins). Presumably used by bbox_decode to
    # take the expectation over DFL bins; TODO confirm.
    # NOTE(review): ComputeLoss is not an nn.Module, so this Parameter is not registered.
    self.proj = nn.Parameter(torch.linspace(0, self.reg_max, self.reg_max + 1), requires_grad=False)
    self.iou_type = iou_type
    # NOTE(review): .cuda() hard-requires a CUDA device at construction time.
    self.varifocal_loss = VarifocalLoss().cuda()
    self.bbox_loss = BboxLoss(self.num_classes, self.reg_max, self.use_dfl, self.iou_type).cuda()
    default_weights = {'class': 1.0, 'iou': 2.5, 'dfl': 0.5}
    # Copy into a new dict so later edits never touch the caller's dict or the defaults.
    self.loss_weight = {**default_weights, **(loss_weight or {})}
def __call__(self, outputs, targets, epoch_num):
    """Compute the weighted YOLOv6 training loss for one batch.

    Args:
        outputs: ``(feats, pred_scores, pred_distri)`` from the model head. Shapes noted
            in the caller: feats is a list of per-level maps (e.g. [bs,64,80,80],
            [bs,128,40,40], [bs,256,20,20]); pred_scores is [bs, n_anchors, nc];
            pred_distri is [bs, n_anchors, 4]. It is presumably
            [bs, n_anchors, 4 * (reg_max + 1)] when use_dfl is on; TODO confirm.
            n_anchors is 8400 for a 640x640 input with strides 8/16/32.
        targets: Ground-truth tensor of shape [n, 6]. ``self.preprocess`` turns it into
            [bs, max_n, 5] rows of ``[cls, x1, y1, x2, y2]``.
        epoch_num: Current epoch. Epochs below ``self.warmup_epoch`` use the ATSS assigner.

    Returns:
        ``(loss, loss_items)``:
            - ``loss``: the weighted sum of the class, IoU and DFL terms.
            - ``loss_items``: a detached concatenation of the weighted
              ``[iou, dfl, class]`` terms (length 3 when each term is a scalar).
    """
    feats, pred_scores, pred_distri = outputs

    # Anchor boxes, anchor centre points, per-level anchor counts and per-anchor strides
    # (n_anchors x 4, n_anchors x 2, one entry per level, n_anchors x 1).
    anchors, anchor_points, n_anchors_list, stride_tensor = \
        generate_anchors(feats, self.fpn_strides, self.grid_cell_size, self.grid_cell_offset,
                         device=feats[0].device)

    assert pred_scores.type() == pred_distri.type()
    gt_bboxes_scale = torch.full((1, 4), self.ori_img_size).type_as(pred_scores)  # [1, 4]
    batch_size = pred_scores.shape[0]

    # targets [n, 6] -> [bs, max_n, 5]; each row is [cls, x1, y1, x2, y2].
    targets = self.preprocess(targets, batch_size, gt_bboxes_scale)
    gt_labels = targets[:, :, :1]  # [bs, max_n, 1]
    gt_bboxes = targets[:, :, 1:]  # xyxy, [bs, max_n, 4]
    # Rows whose coordinates sum to 0 (presumably padding from preprocess) get mask 0.
    mask_gt = (gt_bboxes.sum(-1, keepdim=True) > 0).float()  # [bs, max_n, 1]

    # Decode predicted boxes in grid units: anchor centres divided by their level's stride.
    anchor_points_s = anchor_points / stride_tensor
    pred_bboxes = self.bbox_decode(anchor_points_s, pred_distri)  # xyxy, [bs, n_anchors, 4]

    # Label assignment uses detached predictions rescaled by the stride, so the assigners
    # see the same scale as anchors / anchor_points.
    if epoch_num < self.warmup_epoch:
        target_labels, target_bboxes, target_scores, fg_mask = \
            self.warmup_assigner(anchors, n_anchors_list, gt_labels, gt_bboxes, mask_gt,
                                 pred_bboxes.detach() * stride_tensor)
    else:
        target_labels, target_bboxes, target_scores, fg_mask = \
            self.formal_assigner(pred_scores.detach(), pred_bboxes.detach() * stride_tensor,
                                 anchor_points, gt_labels, gt_bboxes, mask_gt)
    # target_labels [bs, n_anchors], target_bboxes [bs, n_anchors, 4],
    # target_scores [bs, n_anchors, nc], fg_mask [bs, n_anchors]

    # Bring the assigned boxes into the same grid units as pred_bboxes.
    target_bboxes /= stride_tensor

    # Classification loss. Background anchors get the extra index num_classes; dropping
    # the last one-hot column turns them into all-zero rows.
    target_labels = torch.where(fg_mask > 0, target_labels,
                                torch.full_like(target_labels, self.num_classes))
    one_hot_label = F.one_hot(target_labels.long(), self.num_classes + 1)[..., :-1]  # [bs, n_anchors, nc]
    loss_cls = self.varifocal_loss(pred_scores, target_scores, one_hot_label)

    target_scores_sum = target_scores.sum()
    # A batch with no foreground anchors (e.g. no gt boxes at all) has target_scores_sum == 0.
    # Dividing by it would make the loss inf/NaN and poison the optimizer step, so only
    # normalize when it is positive.
    if target_scores_sum > 0:
        loss_cls /= target_scores_sum

    # Box regression losses (IoU and DFL). fg_mask is passed so BboxLoss can select
    # the foreground anchors.
    loss_iou, loss_dfl = self.bbox_loss(pred_distri, pred_bboxes, anchor_points_s, target_bboxes,
                                        target_scores, target_scores_sum, fg_mask)

    # Weight each term once and reuse it for both the total and the logged items.
    weighted_cls = self.loss_weight['class'] * loss_cls
    weighted_iou = self.loss_weight['iou'] * loss_iou
    weighted_dfl = self.loss_weight['dfl'] * loss_dfl
    loss = weighted_cls + weighted_iou + weighted_dfl

    loss_items = torch.cat((weighted_iou.unsqueeze(0),
                            weighted_dfl.unsqueeze(0),
                            weighted_cls.unsqueeze(0))).detach()
    return loss, loss_items