1.总结
第三部分主要介绍网络输出的数据和我们标注的标签之间怎么求loss,然后通过反向传播把梯度传给网络,去训练网络。但在此之前,我们要先研究loss到底需要什么数据。
2.标签分配策略 task aligned assigner
ATSS
https://blog.csdn.net/qq_39592053/article/details/127972923
https://blog.csdn.net/u012863603/article/details/128816715
在代码中就是TaskAlignedAssigner
我们主要研究forward函数
首先我用大白话梳理一遍,然后再看细节。假设这组数据 batch=1,且只有一个类(cls=1),这张图片只有2个检测物(见上一篇的托盘孔洞图片)。anchor总数 5376 = 64×64 + 32×32 + 16×16(例如512×512输入在步长8、16、32下的三个特征图;如果是640×640输入,则是 80×80 + 40×40 + 20×20 = 8400)。
1.get_pos_mask
(1)get_box_metrics
用预测的bboxes和2个gt_bboxes分别计算CIoU(小于0的截断为0),得到overlaps (1,2,5376),其中1是batch,2是两个gt_box。然后为每个gt取出它所属类别的预测得分(相当于给每个gt复制一份得分),按 score^alpha × overlaps^beta 相乘,得到align_metric (1,2,5376)。
(2)select_candidates_in_gts
用生成的anchors和gt_box做计算,筛选中心点落在gt_box内部的anchor:计算anchor中心到gt_box 4条边的距离,如果最小值大于0就保存最小值,否则保存0,得到 mask_in_gts (1,2,5376),其中1是batch,2是两个gt_box。
(3)select_topk_candidates
输入是 align_metric预测得分*mask_in_gts( anchor距离最小值 或0)=metrics
筛选metrics得分最高的10个topk_idxs 得到他们的index (1,2,10) 然后把这个数据做one-hot 得到(1,2,5376) mask_topk
最后返回3个数据:mask_pos(其实就是 mask_topk * mask_in_gts * mask_gt,对应20个anchor位置)、align_metric、overlaps
2.select_highest_overlaps
fg_mask (1,5376) 20个anchor的位置为1 其他为0
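target_gt_idx (1,5376):20个正样本anchor对应2个gt_box,5376个位置上保存的是每个anchor对应哪一个box。理论上应该是10个0表示box0、10个1表示box1,其他的都是背景,应该没有值;但数据上背景也用0表示,所以就变成了5366个0和10个1,个人感觉有点问题。不过背景和box0可以通过fg_mask区分:get_targets里target_scores会用fg_mask把背景位置置0,box loss也只取fg_mask为1的位置,所以背景上的0并不会参与loss计算。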
mask_pos 没变
3.get_targets
因为只有一个cls 所以gt_label 为( 0,0 ) target_labels 全为0
target_bboxes 是用 target_gt_idx 去索引 gt_bboxes 得到的(是按索引取值,不是相乘),也就是5366个box1、10个box2
target_scores 是对 target_labels 做 one-hot(只有一个类,所以全部是1),然后只保留 fg_scores_mask>0 的位置、其余置0,所以只有20个1
返回 target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
class TaskAlignedAssigner(nn.Module):
    """Task-aligned assigner: matches ground-truth (gt) boxes to anchors for training.

    For each gt box, an anchor's alignment metric is ``score ** alpha * CIoU ** beta``,
    where ``score`` is the predicted score for that gt's class. For every gt box, the
    ``topk`` anchors with the highest metric among those accepted by
    ``select_candidates_in_gts`` (helper defined elsewhere) become positive candidates.
    ``select_highest_overlaps`` (also defined elsewhere) then produces one gt index per anchor.

    Args:
        topk (int): number of candidate anchors kept per gt box.
        num_classes (int): number of classes; also used as the background label.
        alpha (float): exponent applied to the classification score.
        beta (float): exponent applied to the IoU.
        eps (float): small constant used in divisions and thresholds.
        roll_out_thr (int): when non-zero and the number of gt boxes exceeds it, box
            metrics and the top-k one-hot are computed per image in a Python loop
            instead of fully vectorized (presumably to limit peak memory).
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9, roll_out_thr=0):
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes  # label returned for every anchor when there are no gt boxes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
        self.roll_out_thr = roll_out_thr

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """Assign gt targets to every anchor.

        This code referenced to
        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py

        Args:
            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            anc_points (Tensor): shape(num_total_anchors, 2)
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1); gates which gt rows may receive
                anchors (presumably 1 for real boxes, 0 for padding).

        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors)
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes); one-hot
                labels of positive anchors scaled by the normalized alignment metric.
            fg_mask (Tensor): shape(bs, num_total_anchors); bool foreground mask
                (a float zero tensor when there are no gt boxes).
            target_gt_idx (Tensor): shape(bs, num_total_anchors); index of the assigned
                gt box. Background anchors also hold 0, so use fg_mask to tell them apart.
        """
        # Per-call state read by the helper methods below.
        self.bs = pd_scores.size(0)
        self.n_max_boxes = gt_bboxes.size(1)
        self.roll_out = self.n_max_boxes > self.roll_out_thr if self.roll_out_thr else False

        if self.n_max_boxes == 0:
            # No gt boxes anywhere in the batch: every anchor is background.
            device = gt_bboxes.device
            return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
                    torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
                    torch.zeros_like(pd_scores[..., 0]).to(device))

        mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
                                                             mask_gt)

        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # assigned targets
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize: within each gt, rescale the positive anchors' metrics so the largest
        # equals that gt's largest IoU among its positives; this becomes the soft class target.
        align_metric *= mask_pos  # keep only the metrics of positive anchors
        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # (b, max_num_obj, 1) best metric per gt
        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # (b, max_num_obj, 1) best IoU per gt
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """Build the positive-candidate mask.

        Returns:
            mask_pos: (b, max_num_obj, h*w), product of the top-k, in-gt and valid-gt masks.
            align_metric: (b, max_num_obj, h*w) alignment metric.
            overlaps: (b, max_num_obj, h*w) CIoU between gt and predicted boxes.
        """
        # get anchor_align metric, (b, max_num_obj, h*w)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
        # get in_gts mask, (b, max_num_obj, h*w)
        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes, roll_out=self.roll_out)
        # get topk_metric mask, (b, max_num_obj, h*w); padded gt rows get an all-False topk mask
        mask_topk = self.select_topk_candidates(align_metric * mask_in_gts,
                                                topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
        # merge all masks into the final mask, (b, max_num_obj, h*w)
        mask_pos = mask_topk * mask_in_gts * mask_gt
        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
        """Compute the alignment metric and CIoU for every (gt, anchor) pair.

        Returns:
            align_metric: (b, max_num_obj, h*w), ``score ** alpha * overlaps ** beta``.
            overlaps: (b, max_num_obj, h*w), CIoU clamped to be non-negative.
        """
        if self.roll_out:
            align_metric = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device)
            overlaps = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device)
            ind_0 = torch.empty(self.n_max_boxes, dtype=torch.long)
            for b in range(self.bs):
                ind_0[:], ind_2 = b, gt_labels[b].squeeze(-1).long()
                # predicted score of each anchor for each gt's class
                bbox_scores = pd_scores[ind_0, :, ind_2]  # max_num_obj, h*w
                overlaps[b] = bbox_iou(gt_bboxes[b].unsqueeze(1), pd_bboxes[b].unsqueeze(0), xywh=False,
                                       CIoU=True).squeeze(2).clamp(0)
                align_metric[b] = bbox_scores.pow(self.alpha) * overlaps[b].pow(self.beta)
        else:
            ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
            ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
            ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
            # predicted score of each anchor for each gt's class
            bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
            overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False,
                                CIoU=True).squeeze(3).clamp(0)
            align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
        """Mark the top-k anchors of every gt box.

        Args:
            metrics: (b, max_num_obj, h*w).
            largest: pick the largest (True) or smallest (False) metrics.
            topk_mask: (b, max_num_obj, topk) bool or None. When None, a gt keeps its
                top-k only if its best metric exceeds ``eps``.

        Returns:
            (b, max_num_obj, h*w) 0/1 mask in ``metrics.dtype``.
        """
        num_anchors = metrics.shape[-1]  # h*w
        # (b, max_num_obj, topk)
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
        if topk_mask is None:
            # Tensor.max(dim) returns a (values, indices) namedtuple; compare the values only.
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).tile([1, 1, self.topk])
        # Masked-out entries all point at anchor 0; they are filtered below.
        topk_idxs[~topk_mask] = 0
        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
        if self.roll_out:
            is_in_topk = torch.empty(metrics.shape, dtype=torch.long, device=metrics.device)
            for b in range(len(topk_idxs)):
                is_in_topk[b] = F.one_hot(topk_idxs[b], num_anchors).sum(-2)
        else:
            is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
        # An anchor counted more than once can only come from the zeroed (invalid) indices.
        is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
        return is_in_topk.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """Gather per-anchor labels, boxes and one-hot scores from the assigned gt indices.

        Args:
            gt_labels: (b, max_num_obj, 1)
            gt_bboxes: (b, max_num_obj, 4)
            target_gt_idx: (b, h*w) index of the assigned gt within each image
            fg_mask: (b, h*w) foreground mask

        Returns:
            target_labels: (b, h*w) long labels, negatives clamped to 0.
            target_bboxes: (b, h*w, 4)
            target_scores: (b, h*w, num_classes) one-hot labels, zeroed for background anchors.
        """
        # Offset per-image gt indices so they address the flattened (b * max_num_obj) gt tensors.
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]  # (b, 1)
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

        # assigned target boxes, (b * max_num_obj, 4) -> (b, h*w, 4)
        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]

        # assigned target scores
        # Clamp in place: a bare `clamp(0)` returns a new tensor and would leave negative
        # labels in place, which makes F.one_hot raise.
        target_labels.clamp_(0)
        target_scores = F.one_hot(target_labels, self.num_classes)  # (b, h*w, num_classes)
        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, num_classes)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
        return target_labels, target_bboxes, target_scores
2.loss
2.1 class loss
nn.BCEWithLogitsLoss
因为采用这个损失函数,所以数据需要one-hot
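但是v8对标签做了处理:正样本的目标值不是简单的1,而是用IoU缩放后的对齐得分。具体来说,在TaskAlignedAssigner.forward的normalize部分,每个gt的正样本中最大的align_metric会被缩放成该gt与其正样本预测框之间的最大IoU,其余正样本按同样比例缩放,背景仍为0。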
# Classification loss: BCE-with-logits against the IoU-scaled soft targets from the assigner,
# normalized by the sum of all target scores.
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
因为只有一个cls,所以最后一维是1
pred_scores (1,5376,1)
target_scores(1,5376,1)
2.2 box loss
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
    """Box regression losses over foreground anchors.

    Args:
        pred_dist: raw per-edge distribution logits, viewed as (-1, reg_max + 1) for DFL.
        pred_bboxes: decoded predicted boxes (passed to bbox_iou with xywh=False).
        anchor_points: anchor centers used to turn target boxes into edge distances.
        target_bboxes: boxes assigned to each anchor.
        target_scores: soft class targets; their per-anchor sum weights each loss term.
        target_scores_sum: normalizer for both losses.
        fg_mask: foreground anchor mask.

    Returns:
        (loss_iou, loss_dfl); loss_dfl is a zero tensor when DFL is disabled.
    """
    # Each foreground anchor is weighted by the total of its soft class targets.
    fg_weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)

    # CIoU loss between matched predicted and target boxes.
    ciou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
    loss_iou = ((1.0 - ciou) * fg_weight).sum() / target_scores_sum

    if not self.use_dfl:
        return loss_iou, torch.tensor(0.0).to(pred_dist.device)

    # Distribution Focal Loss on the edge distances of the target boxes.
    ltrb_targets = bbox2dist(anchor_points, target_bboxes, self.reg_max)
    per_anchor_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), ltrb_targets[fg_mask])
    loss_dfl = (per_anchor_dfl * fg_weight).sum() / target_scores_sum
    return loss_iou, loss_dfl
pred_dist 预测的4条边(左、上、右、下)的距离分布,每条边16个bin,共 (1,5376,64)
pred_bboxes 由pred_dist解码得到的预测框(xyxy格式) (1,5376,4)
anchor_points 生成的anchor中心点 (5376,2)
target_bboxes 5366个box1 10个box2
target_scores, 20个anchor 的scores 其他都是0
target_scores_sum, 上一个的和
fg_mask 20个anchor的index为1 其他是0
用fg_mask从预测的box中选出对应的20个box,和从真实的box中选出的20个一起计算CIoU,再用每个anchor的target_scores之和加权,得到loss_iou
计算target框相对anchor中心点的左、上、右、下距离(bbox2dist),作为DFL的目标
def _df_loss(pred_dist, target):
    """Distribution Focal Loss (DFL), from Generalized Focal Loss.

    Reference: https://ieeexplore.ieee.org/document/9792391

    Each continuous target distance lies between two integer bins. The loss is the
    cross-entropy toward both neighbouring bins, weighted by how close the target is
    to each one, averaged over the last dimension of ``target``.

    Args:
        pred_dist: logits of shape (target.numel(), num_bins).
        target: continuous distances; ``target.long() + 1`` must be a valid bin index.

    Returns:
        Tensor of shape target.shape[:-1] + (1,).
    """
    left_bin = target.long()
    right_bin = left_bin + 1
    left_weight = right_bin - target
    right_weight = 1 - left_weight

    ce_left = F.cross_entropy(pred_dist, left_bin.view(-1), reduction="none").view(left_bin.shape)
    ce_right = F.cross_entropy(pred_dist, right_bin.view(-1), reduction="none").view(left_bin.shape)
    return (ce_left * left_weight + ce_right * right_weight).mean(-1, keepdim=True)
https://zhuanlan.zhihu.com/p/149186719
3.修改网络输出数据格式
上一节我们看到了网络输出的数据是什么样子:每个anchor有65个通道,其中第0-63通道是bbox(4条边×16个bin),第64通道是class(这里只有1个类)。
def bbox_decode(self, anchor_points, pred_dist):
    """Decode predicted edge distances into boxes around the anchor points.

    With DFL enabled, ``pred_dist`` is (batch, anchors, channels) and the channels
    are split into 4 edges x (channels // 4) bins. A softmax is taken over the bins and
    the result is projected onto ``self.proj`` (presumably the bin indices, giving
    the expected distance per edge). The distances are then handed to ``dist2bbox``
    with ``xywh=False``.
    """
    if self.use_dfl:
        batch, n_anchors, n_channels = pred_dist.shape
        bins = n_channels // 4
        projection = self.proj.type(pred_dist.dtype)
        # Channel layout assumed here: edge-major, i.e. [edge0 bins..., edge1 bins..., ...].
        pred_dist = pred_dist.view(batch, n_anchors, 4, bins).softmax(3).matmul(projection)
        # Alternatives explored for a bin-major layout ([bin0 of 4 edges, bin1 of 4 edges, ...]):
        #   pred_dist.view(b, a, c // 4, 4).transpose(2, 3).softmax(3).matmul(proj)
        #   (pred_dist.view(b, a, c // 4, 4).softmax(2) * proj.view(1, 1, -1, 1)).sum(2)
    return dist2bbox(pred_dist, anchor_points, xywh=False)
4.修改标注数据格式
def preprocess(self, targets, batch_size, scale_tensor):
    """Pack flat per-object targets into a padded (batch_size, max_objects, 5) tensor.

    Column 0 of ``targets`` is the image index within the batch. The remaining 5
    columns are copied per image (presumably class followed by x, y, w, h). Columns
    1:5 of the output are multiplied by ``scale_tensor`` and converted with
    ``xywh2xyxy``. Images with fewer objects are zero-padded.
    """
    if targets.shape[0] != 0:
        image_idx = targets[:, 0]
        _, objects_per_image = image_idx.unique(return_counts=True)
        packed = torch.zeros(batch_size, objects_per_image.max(), 5, device=self.device)
        for img in range(batch_size):
            selected = image_idx == img
            count = selected.sum()
            if count:
                packed[img, :count] = targets[selected, 1:]
        packed[..., 1:5] = xywh2xyxy(packed[..., 1:5].mul_(scale_tensor))
    else:
        packed = torch.zeros(batch_size, 0, 5, device=self.device)
    return packed