在学习obb_head之前,请务必理解了hbb_head的原理,然后再来看本文。
class OBB(Detect):
"""YOLOv8 OBB detection head for detection with rotation models."""
def __init__(self, nc=80, ne=1, ch=()):
"""Initialize OBB with number of classes `nc` and layer channels `ch`."""
super().__init__(nc, ch)
self.ne = ne # number of extra parameters
self.detect = Detect.forward
c4 = max(ch[0] // 4, self.ne)
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
bs = x[0].shape[0] # batch size
angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
# NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
# angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
if not self.training:
self.angle = angle
x = self.detect(self, x)
if self.training:
return x, angle
return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
def decode_bboxes(self, bboxes, anchors):
"""Decode rotated bounding boxes."""
return dist2rbox(bboxes, self.angle, anchors, dim=1)
可以看到,obb继承自detect,也就是hbb检测头,但它主要增加了一个新的卷积块,用来预测角度angle.最后前向传播时多了一个角度信息。
在样本分配环节,obb同样继承自hbb的样本分配:
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
def iou_calculation(self, gt_bboxes, pd_bboxes):
"""IoU calculation for rotated bounding boxes."""
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes):
"""
Select the positive anchor center in gt for rotated bounding boxes.
Args:
xy_centers (Tensor): shape(h*w, 2)
gt_bboxes (Tensor): shape(b, n_boxes, 5)
Returns:
(Tensor): shape(b, n_boxes, h*w)
"""
# (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
corners = xywhr2xyxyxyxy(gt_bboxes)
# (b, n_boxes, 1, 2)
a, b, _, d = corners.split(1, dim=-2)
ab = b - a
ad = d - a
# (b, n_boxes, h*w, 2)
ap = xy_centers - a
norm_ab = (ab * ab).sum(dim=-1)
norm_ad = (ad * ad).sum(dim=-1)
ap_dot_ab = (ap * ab).sum(dim=-1)
ap_dot_ad = (ap * ad).sum(dim=-1)
return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box
可以看到,原来的iou计算方法被重写了,这里使用的是probiou。
然后选择候选框方法也被重写了,首先将坐标转换为多边形表示法,然后通过split方法,得到矩形的四个顶点a, b, _, d,它们分别代表了每个顶点的(x,y),通过相减,我们可以得到矩形的两条边向量ab,ad.
ap = xy_centers - a
利用了广播机制,计算点a到各个锚框中心点的向量。
norm_ab
和 norm_ad
分别是边界框边向量的模长的平方,用于后续计算。
这里判断锚点是否落在了真实框内,用到了向量的知识,如果顶点a到锚点p的向量ap与真实框的两条边的向量点乘均为正,说明在同侧,同时如果点乘大小小于两条边的长度,说明确实落在了真实框内部。
这里了解一下如何将θ角表示法转换为多边形表示法:
主要涉及了平面坐标变换的数学知识,如果一个二维平面坐标旋转了θ角度,则坐标变换公式是
具体数学推导可以参考:
【数学——二维旋转矩阵的解释】https://www.bilibili.com/video/BV1T4411H79Q?vd_source=8720b96192c0895a7043e7fe6f6b6565
这里我们的真实框的宽高的一半是h/2, w/2,所以我们很自然可以得到在旋转后的坐标系下,这些宽和高的表示:
vec1 = [w / 2 * cos_value, w / 2 * sin_value]
vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
由于旋转框是基于中心点旋转的,所以中心点实际上是不动的,只需要变换顶点坐标就可以了。
我们得到的vec相当于从中心点出发,到旋转后的坐标的顶点。故以此类推,可以得到四个顶点的计算方法,那就是分别加上/减去对应xy坐标的偏移量即可:
pt1 = ctr + vec1 + vec2
pt2 = ctr + vec1 - vec2
pt3 = ctr - vec1 - vec2
pt4 = ctr - vec1 + vec2
最后是损失计算部分:
class RotatedBboxLoss(BboxLoss):
"""Criterion class for computing training losses during training."""
def __init__(self, reg_max, use_dfl=False):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
super().__init__(reg_max, use_dfl)
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
if self.use_dfl:
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
return loss_iou, loss_dfl
可以看到,还是继承了hbb的损失计算方法,变化的只是iou计算方法。
相应的dfl部分也是 简单粗暴的直接舍去了旋转角度,直接去计算hbb的四个点的损失。