在上一篇博文中,我们讲解了YOLOv8
实例分割的训练过程,已将前向传播过程分析完毕,那么,接下来便是损失计算过程了。
文章目录
训练整体流程
获得预测结果与真值后,即可计算损失。
整体流程如下:
这里的batch我们使用的是单张图像,所以在计算损失时,真值我们可以直接选出。
详细结构图如下:
这里我们主要是对TaskAlignedAssigner样本匹配策略进行详细解释
预测结果preds
如下:
真值batch
如下:
分割整体损失函数
v8SegmentationLoss
的计算过程如下,从最终的结果来看,其计算了四个损失,分别是目标预测框损失、mask
损失、类别损失以及DEL
损失,博主已将每段代码的结果标注在对应的代码位置。同时,在损失计算过程中不可避免的需要使用其他方法,博主将一些较为重要的方法也罗列出来了。
def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
loss = torch.zeros(4, device=self.device) # box, cls, dfl,mask,共 4 个
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] #feature是目标检测头输出的三个特征图list类型,pred_masks为torch.Size([4, 32, 8400]) ,oroto为torch.Size([4, 32, 160, 160])
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)#拆分目标检测特征图结果,并融合在一起,得到(4,64,8400)与(4,80,8400)
# B, grids, .. 维度转换
pred_scores = pred_scores.permute(0, 2, 1).contiguous()#(4,8400,80)
pred_distri = pred_distri.permute(0, 2, 1).contiguous()#(4,8400,64)
pred_masks = pred_masks.permute(0, 2, 1).contiguous()#(4,8400,32)
dtype = pred_scores.dtype #torch.float16
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
try:
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)#torch.Size([22, 6]) 即batchid 类别id x y w h 共6个数
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])#(batch,最大目标数量,5) 5=1+4
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy (4,7,1)(4,7,4)
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
# 这行代码的作用是创建一个掩码张量 mask_gt,用于标识哪些目标是有效的(即有实际的边界框),哪些目标是无效的(即填充的零)。
# sum(2, keepdim=True):沿着第 2 维(即每个边界框的坐标维度)进行求和,保持维度不变。
# 这样会将每个边界框的 4 个坐标值相加,如果该目标是有效的边界框,和将大于 0;如果该目标是填充的零,和将等于 0。
except RuntimeError as e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,#(4,7,1)7指这个四个batch中
gt_bboxes,#(4,7,4)
mask_gt,
)
"""
self.assigner即根据分类与回归的分数加权的分数选择正样本 输入(共6个输入值)
1.pred_scores:表示模型预测的每个锚点位置的分类分数 形状为:(batch_size, 8400, 80)
2.(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype)
pred_bboxes:解码后的边界框坐标,形状为 [batch_size, 8400, 4]。
detach():从计算图中分离出预测的边界框,避免反向传播时更新它们。
* stride_tensor:将边界框坐标乘以步幅张量 stride_tensor,恢复到原图尺度。
type(gt_bboxes.dtype):将边界框的类型转换为与 gt_bboxes 相同的类型。
3.anchor_points * stride_tensor
anchor_points:锚点位置,形状为 [num_anchors, 2]
将锚点位置乘以步幅张量 stride_tensor,恢复到原图尺度
4.gt_labels:真实目标的类别标签
5.gt_bboxes:真实目标的边界框坐标
6.mask_gt:掩码张量,标识哪些目标是有效的
"""
"""
输出(共5个返回值)
target_labels, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点的目标标签
target_bboxes:形状为 [batch_size, num_anchors, 4],包含每个锚点的目标边界框
target_scores:形状为 [batch_size, num_anchors, num_classes],包含每个锚点的目标得分。
fg_mask:形状为 [batch_size, num_anchors],标识哪些锚点是前景(即有效的目标, 正样本)。
fg_mask作用:标识哪些锚点是前景(正样本),哪些是背景(负样本)。
正样本:锚点被分配给一个真实的目标,表示这个锚点负责检测这个目标。
负样本:锚点未被分配给任何目标,表示这个锚点不负责检测任何目标
target_gt_idx, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点对应的真实目标索引。
"""
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE 即求分类损失
if fg_mask.sum():
# Bbox loss
loss[0], loss[3] = self.bbox_loss( #求Box损失和DEL损失
pred_distri,
pred_bboxes,
anchor_points,
target_bboxes / stride_tensor,
target_scores,
target_scores_sum,
fg_mask,
)
# Masks loss
masks = batch["masks"].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
loss[1] = self.calculate_segmentation_loss(
fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
)
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
loss[0] *= self.hyp