这一部分主要包含几个重要的函数,下面详细介绍:
1.计算bbox回归要用到的几个参数dx,dy,dw,dh
def bbox_transform(ex_rois, gt_rois):
    """Compute the regression targets (dx, dy, dw, dh) mapping each
    example box onto its matched ground-truth box.

    Both inputs are (N, 4) tensors of (x1, y1, x2, y2) corners; the
    result is an (N, 4) tensor with columns (dx, dy, dw, dh).
    """
    # Width/height/center of the example (anchor) boxes.  The +1 follows
    # the pixel-inclusive box convention used throughout this file.
    ex_w = ex_rois[:, 2] - ex_rois[:, 0] + 1.0
    ex_h = ex_rois[:, 3] - ex_rois[:, 1] + 1.0
    ex_cx = ex_rois[:, 0] + 0.5 * ex_w
    ex_cy = ex_rois[:, 1] + 0.5 * ex_h

    # Same quantities for the ground-truth boxes.
    gt_w = gt_rois[:, 2] - gt_rois[:, 0] + 1.0
    gt_h = gt_rois[:, 3] - gt_rois[:, 1] + 1.0
    gt_cx = gt_rois[:, 0] + 0.5 * gt_w
    gt_cy = gt_rois[:, 1] + 0.5 * gt_h

    # Center offsets are normalized by the anchor size; the size terms
    # are log-ratios, so a zero target means "no change".
    return torch.stack(
        (
            (gt_cx - ex_cx) / ex_w,
            (gt_cy - ex_cy) / ex_h,
            torch.log(gt_w / ex_w),
            torch.log(gt_h / ex_h),
        ),
        1,
    )
2. 对原boxes经过平移缩放后得到的pred_boxes
# transform the boxes to the pred_boxes
def bbox_transform_inv(boxes, deltas, batch_size):
widths = boxes[:, :, 2] - boxes[:, :, 0] + 1.0
heights = boxes[:, :, 3] - boxes[:, :, 1] + 1.0
ctr_x = boxes[:, :, 0] + 0.5 * widths
ctr_y = boxes[:, :, 1] + 0.5 * heights
dx = deltas[:, :, 0::4]
dy = deltas[:, :, 1::4]
dw = deltas[:, :, 2::4]
dh = deltas[:, :, 3::4]
pred_ctr_x = dx * widths.unsqueeze(2) + ctr_x.unsqueeze(2)
pred_ctr_y = dy * heights.unsqueeze(2) + ctr_y.unsqueeze(2)
pred_w = torch.exp(dw) * widths.unsqueeze(2)
pred_h = torch.exp(dh) * heights.unsqueeze(2)
pred_boxes = deltas.clone()
# x1
pred_boxes[:, :, 0::4] = pred_ctr_x-0.5 * pred_w
# y1
pred_boxes[:, :, 1::4] = pred_ctr_y - 0.5 * pred_h
# x2
pred_boxes[:, :, 2::4] = pred_ctr_x + 0.5 * pred_w
# y2
pred_boxes[:, :, 3::4] = pred_ctr_y + 0.5 * pred_h
return pred_boxes
3.调整boxes的坐标,使其全部在图像的范围内
def clip_boxes(boxes, im_shape, batch_size):
    """Clamp box coordinates in place so every box lies inside its image.

    boxes: (batch_size, N, 4*C) tensor of repeating (x1, y1, x2, y2) groups.
    im_shape: per-image sizes; im_shape[i, 0] / im_shape[i, 1] are read as
        the i-th image's height / width — presumably (h, w, ...) rows;
        verify against the caller.
    batch_size: number of images to process along dim 0.

    Returns the same ``boxes`` tensor, modified in place.
    """
    for i in range(batch_size):
        # Bug fix: the original called torch.clamp(...) and discarded the
        # result — torch.clamp is out-of-place, so the function was a no-op.
        # clamp_ mutates the slice views, and hence ``boxes``, in place.
        boxes[i, :, 0::4].clamp_(0, im_shape[i, 1] - 1)  # x1 in [0, w-1]
        boxes[i, :, 1::4].clamp_(0, im_shape[i, 0] - 1)  # y1 in [0, h-1]
        boxes[i, :, 2::4].clamp_(0, im_shape[i, 1] - 1)  # x2 in [0, w-1]
        boxes[i, :, 3::4].clamp_(0, im_shape[i, 0] - 1)  # y2 in [0, h-1]
    return boxes
4.计算IOU
def bbox_overlaps(anchors, gt_boxes):
    """Pairwise IoU between every anchor and every ground-truth box.

    anchors:  (N, 4) tensor of (x1, y1, x2, y2) corners.
    gt_boxes: (K, 4) tensor of (x1, y1, x2, y2) corners.
    Returns an (N, K) tensor where entry (n, k) is IoU(anchors[n], gt_boxes[k]).
    """
    N = anchors.size(0)
    K = gt_boxes.size(0)

    # Box areas under the pixel-inclusive (+1) convention, shaped so that
    # (N, 1) and (1, K) broadcast to the full (N, K) grid below.
    anchor_area = ((anchors[:, 2] - anchors[:, 0] + 1) *
                   (anchors[:, 3] - anchors[:, 1] + 1)).view(N, 1)
    gt_area = ((gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
               (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)).view(1, K)

    # Expand both sets to (N, K, 4) so every anchor is paired with every GT.
    a = anchors.view(N, 1, 4).expand(N, K, 4)
    g = gt_boxes.view(1, K, 4).expand(N, K, 4)

    # Intersection extents; a negative value means the boxes are disjoint,
    # so clamp it to zero before computing the intersection area.
    iw = (torch.min(a[:, :, 2], g[:, :, 2]) -
          torch.max(a[:, :, 0], g[:, :, 0]) + 1).clamp(min=0)
    ih = (torch.min(a[:, :, 3], g[:, :, 3]) -
          torch.max(a[:, :, 1], g[:, :, 1]) + 1).clamp(min=0)

    intersection = iw * ih
    union = anchor_area + gt_area - intersection
    return intersection / union