关于build_rpn_targets函数解析
def build_rpn_targets(anchors, gt_class_ids, gt_boxes, config):
    """Compute RPN training targets for one image.

    Given the anchors and GT boxes, compute overlaps, label each anchor as
    positive / negative / neutral, and compute the (normalized) deltas that
    refine each positive anchor onto its best-matching GT box.

    Args:
        anchors: [num_anchors, (y1, x1, y2, x2)]
        gt_class_ids: [num_gt_boxes] Integer class IDs. Negative IDs mark
            COCO "crowd" boxes, which are excluded from training.
        gt_boxes: [num_gt_boxes, (y1, x1, y2, x2)]
        config: object providing RPN_TRAIN_ANCHORS_PER_IMAGE (sampling budget
            and number of rows in rpn_bbox) and BBOX_STD_DEV (divisor used to
            normalize the deltas).

    Returns:
        rpn_match: [num_anchors] (int32) matches between anchors and GT boxes.
            1 = positive anchor, -1 = negative anchor, 0 = neutral
        rpn_bbox: [RPN_TRAIN_ANCHORS_PER_IMAGE, (dy, dx, log(dh), log(dw))]
            Normalized deltas for the positive anchors, in the order the
            positives appear in rpn_match; unused rows stay zero.
    """
    num_anchors = anchors.shape[0]
    # RPN Match: 1 = positive anchor, -1 = negative anchor, 0 = neutral
    rpn_match = np.zeros([num_anchors], dtype=np.int32)
    # RPN bounding boxes: [max anchors per image, (dy, dx, log(dh), log(dw))]
    rpn_bbox = np.zeros((config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4))

    # Handle COCO crowds.
    # A crowd box in COCO is a bounding box around several instances. Exclude
    # them from training. A crowd box is given a negative class ID.
    crowd_ix = np.where(gt_class_ids < 0)[0]
    if crowd_ix.shape[0] > 0:
        # Filter out crowds from ground truth class IDs and boxes
        non_crowd_ix = np.where(gt_class_ids > 0)[0]
        crowd_boxes = gt_boxes[crowd_ix]
        gt_class_ids = gt_class_ids[non_crowd_ix]
        gt_boxes = gt_boxes[non_crowd_ix]
        # Compute overlaps with crowd boxes [anchors, crowds]
        crowd_overlaps = utils.compute_overlaps(anchors, crowd_boxes)
        crowd_iou_max = np.amax(crowd_overlaps, axis=1)
        # Anchors touching a crowd (IoU >= 0.001) must not become negatives.
        no_crowd_bool = (crowd_iou_max < 0.001)
    else:
        # All anchors don't intersect a crowd
        no_crowd_bool = np.ones([num_anchors], dtype=bool)

    # Match anchors to GT boxes:
    # - IoU >= 0.7 with some GT box -> positive.
    # - max IoU < 0.3 (and not in a crowd area) -> negative.
    # - everything else is neutral and does not influence the loss.
    # Additionally, every GT box gets its closest anchor as a positive, even
    # if that IoU is < 0.3, so no GT box is left unmatched.
    anchor_iou_argmax = None
    if len(gt_boxes) == 0:
        # No non-crowd GT boxes: argmax over an empty axis would raise, and
        # every anchor effectively has IoU 0, so all non-crowd anchors are
        # negative and there are no positives.
        rpn_match[no_crowd_bool] = -1
    else:
        # Compute overlaps [num_anchors, num_gt_boxes]
        overlaps = utils.compute_overlaps(anchors, gt_boxes)
        # 1. Set negative anchors first. They get overwritten below if a GT
        #    box is matched to them. Skip anchors in crowd areas.
        anchor_iou_argmax = np.argmax(overlaps, axis=1)
        anchor_iou_max = overlaps[np.arange(num_anchors), anchor_iou_argmax]
        rpn_match[(anchor_iou_max < 0.3) & no_crowd_bool] = -1
        # 2. Set an anchor for each GT box (regardless of IoU value).
        # TODO: If multiple anchors have the same IoU match all of them
        gt_iou_argmax = np.argmax(overlaps, axis=0)
        rpn_match[gt_iou_argmax] = 1
        # 3. Set anchors with high overlap as positive.
        rpn_match[anchor_iou_max >= 0.7] = 1

    # Subsample to balance positive and negative anchors.
    # Don't let positives be more than half the anchors.
    pos_ids = np.where(rpn_match == 1)[0]
    extra = len(pos_ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE // 2)
    if extra > 0:
        # Reset the extra ones to neutral
        pos_ids = np.random.choice(pos_ids, extra, replace=False)
        rpn_match[pos_ids] = 0
    # Same for negatives: positives + negatives <= RPN_TRAIN_ANCHORS_PER_IMAGE.
    neg_ids = np.where(rpn_match == -1)[0]
    extra = len(neg_ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE -
                            np.sum(rpn_match == 1))
    if extra > 0:
        # Reset the extra ones to neutral
        neg_ids = np.random.choice(neg_ids, extra, replace=False)
        rpn_match[neg_ids] = 0

    # For positive anchors, compute shift and scale needed to transform them
    # to match the corresponding GT boxes. Positives exist only when GT boxes
    # exist, so anchor_iou_argmax is set whenever this branch runs.
    ids = np.where(rpn_match == 1)[0]
    if ids.size > 0:
        pos_anchors = anchors[ids]
        # Closest GT box for each positive anchor (it might have IoU < 0.7)
        matched_gt = gt_boxes[anchor_iou_argmax[ids]]
        # Convert coordinates to center plus height/width.
        gt_h = matched_gt[:, 2] - matched_gt[:, 0]
        gt_w = matched_gt[:, 3] - matched_gt[:, 1]
        gt_center_y = matched_gt[:, 0] + 0.5 * gt_h
        gt_center_x = matched_gt[:, 1] + 0.5 * gt_w
        a_h = pos_anchors[:, 2] - pos_anchors[:, 0]
        a_w = pos_anchors[:, 3] - pos_anchors[:, 1]
        a_center_y = pos_anchors[:, 0] + 0.5 * a_h
        a_center_x = pos_anchors[:, 1] + 0.5 * a_w
        # Bbox refinement the RPN should predict.
        deltas = np.stack([
            (gt_center_y - a_center_y) / a_h,
            (gt_center_x - a_center_x) / a_w,
            np.log(gt_h / a_h),
            np.log(gt_w / a_w),
        ], axis=1)
        # Normalize every positive row. The previous per-row loop incremented
        # the row index before dividing, so no stored delta was normalized.
        # Positives were capped at half the budget above, so they fit.
        rpn_bbox[:ids.size] = deltas / config.BBOX_STD_DEV
    return rpn_match, rpn_bbox
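这里输入生成的所有anchors(大约6万个), gt_class_ids 和 gt_boxes 分别是ground truth的类别id和对应的bbox. 1. 处理COCO中的crowd框(一个框住多个实例的框, 其class id为负数): 如果存在crowd框, 先把它们从gt_class_ids和gt_boxes中剔除, 再计算每个anchor与crowd框的最大IoU, 只有最大IoU小于0.001的anchor的no_crowd_bool才为True(落在crowd区域的anchor不会被标为负样本); 如果没有crowd框, 则所有anchors的no_crowd_bool都为True.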
2. utils.compute_overlaps(anchors, gt_boxes)
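计算所有anchors和gt_boxes的IoU, 也可以理解为一种重合打分. 如果某个anchor与所有gt_box的最大IoU小于0.3(且不在crowd区域), 则认为它和gt_box为negative关系, 可以认为不相关, 标记为-1; 如果最大IoU大于等于0.7则认为positive, 理解为强相关, 标记为1. 另外, 对每个gt_box, 无论IoU是多少, 都会把与它IoU最大的那个anchor标记为1, 保证每个gt_box至少匹配到一个anchor. 这里很好理解, 就是找到生成的anchors和ground truth标注框的关系, 找到我们关注的那些anchors, 即和ground truth强相关或者强不相关的anchors, 方便后面反向传播计算RPN前景/背景分类分支的Loss. 剩下的那些(IoU介于0.3和0.7之间)价值不大, 都标注为0(neutral, 不参与loss). 最后还会做正负样本均衡: 正样本最多保留RPN_TRAIN_ANCHORS_PER_IMAGE的一半, 正负样本总数不超过RPN_TRAIN_ANCHORS_PER_IMAGE, 多出来的随机重置为0. 所有这些结果存进rpn_match里面.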
3. 最后找到有效anchors(rpn_match里标注为1的anchor)所对应的ground truth box(即与该anchor IoU最大的那个gt_box), 分别求它们的中心点平移(dy, dx)和边框缩放值(log(dh), log(dw)), 再除以config.BBOX_STD_DEV做归一化, 保存在rpn_bbox里.
4. 返回rpn_match, rpn_bbox