Overview
- Mask/Faster R-CNN performs two very similar matching passes between anchors/proposals and ground-truth targets
- Both have the same goal: associate each anchor/proposal with a GT box/label so that losses can be computed
- One pass happens in the RPN and one in RoIHeads, to compute the RPN's two losses and RoIHeads' multiple losses respectively
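Both matching passes are driven by the same overlap measure, IoU. A minimal pure-Python sketch (boxes as `(x1, y1, x2, y2)` tuples; `iou` here is an illustrative helper, not the torchvision API):

```python
def iou(a, b):
    """Intersection-over-union of two axis-aligned boxes (x1, y1, x2, y2)."""
    # Intersection rectangle; width/height clamp to 0 when there is no overlap
    x1 = max(a[0], b[0]); y1 = max(a[1], b[1])
    x2 = min(a[2], b[2]); y2 = min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0
```

torchvision computes the same quantity in batch as a `[num_gt, num_anchors]` matrix (`box_ops.box_iou`), which the matcher then thresholds per anchor/proposal.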
Signatures / calls
The two methods are similar but not identical; their signatures are shown below.
Note that proposals are the filtered results of the more precise candidate regions obtained after the anchors pass through the RPN's box regression.

```python
def assign_targets_to_anchors(self, anchors, targets):
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
```

Call sites:

```python
labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
```
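The different second arguments describe the same annotations. A shape-only sketch (plain lists stand in for Tensors; the box and label values are made up):

```python
# `targets` is a list with one dict per image. assign_targets_to_anchors
# takes the whole dict but reads only "boxes", while RoIHeads unpacks
# boxes and labels into the two separate lists it passes on.
targets = [
    {"boxes": [[10, 10, 50, 50], [60, 20, 90, 80]], "labels": [3, 7]},  # image 0: two objects
    {"boxes": [], "labels": []},                                        # image 1: background only
]
gt_boxes = [t["boxes"] for t in targets]
gt_labels = [t["labels"] for t in targets]
```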
Different loss targets
- assign_targets_to_anchors feeds the RPN losses, used to update the RPN's parameters
- assign_targets_to_proposals feeds the RoIHeads losses
Return values

Method | labels | matched_gt_boxes / matched_idxs
---|---|---
assign_targets_to_anchors | contains only [-1, 0, 1]: 1 marks a positive sample, 0 a negative sample, -1 a discarded one | the matched GT box coordinates; for non-positive labels the entry is meaningless (the clamped index just picks GT box 0)
assign_targets_to_proposals | contains [0, 1, 2, …]: the class each proposal belongs to, with 0 meaning background (with the default 0.5/0.5 thresholds the -1 ignore band is empty) | the matched GT index; for label 0 the entry is meaningless (clamped to 0)
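The two label conventions in the table can be sketched in pure Python (illustrative helpers, not the torchvision API; the sentinel values follow torchvision's Matcher: -1 below the low threshold, -2 between thresholds):

```python
# Sentinel values produced by torchvision's Matcher for unmatched candidates
BELOW_LOW_THRESHOLD = -1
BETWEEN_THRESHOLDS = -2

def anchor_labels(matched_idxs):
    """RPN convention: 1 = foreground, 0 = background, -1 = discard."""
    out = []
    for m in matched_idxs:
        if m >= 0:
            out.append(1.0)
        elif m == BELOW_LOW_THRESHOLD:
            out.append(0.0)
        else:  # BETWEEN_THRESHOLDS
            out.append(-1.0)
    return out

def proposal_labels(matched_idxs, gt_labels):
    """RoIHeads convention: class id of the matched GT, 0 = background, -1 = ignore."""
    out = []
    for m in matched_idxs:
        if m >= 0:
            out.append(gt_labels[m])
        elif m == BELOW_LOW_THRESHOLD:
            out.append(0)
        else:
            out.append(-1)
    return out
```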
Parameters

Method | anchors/proposals | targets / (gt_boxes, gt_labels) | Config: fg_iou_thresh / bg_iou_thresh
---|---|---|---
assign_targets_to_anchors | box coordinates (anchors) | targets actually contains both gt_boxes and gt_labels, but only gt_boxes is used: in the RPN stage we do not care which class a GT box belongs to, only whether an object is there at all, i.e. whether each anchor is foreground or background | matching only checks whether an anchor overlaps a GT box (IoU): IoU above the upper threshold (fg_iou_thresh) makes a positive sample, below the lower threshold (bg_iou_thresh) a negative one, and anything in between is ignored. Hence the defaults of 0.7 and 0.3, and the returned labels include -1
assign_targets_to_proposals | box coordinates (proposals) | here we must determine which specific class a proposal is, so gt_labels is required: when the IoU exceeds the threshold the proposal is foreground (contains an object) and takes the class of the matched GT; otherwise it is background, recorded as class 0 | normally nothing needs to be ignored and the foreground criterion is looser, so both thresholds default to 0.5; the returned labels start from 0
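The threshold logic in the last column can be sketched as a simplified matcher (an illustrative stand-in for torchvision's Matcher, keeping its sentinel values but omitting the low-quality-match allowance that the RPN additionally enables):

```python
BELOW_LOW_THRESHOLD = -1
BETWEEN_THRESHOLDS = -2

def match(quality, low, high):
    """quality[g][a] = IoU between GT g and anchor/proposal a.

    For each candidate, take the best-overlapping GT, then apply the
    two thresholds: >= high -> that GT's index, < low -> background,
    in between -> ignore.
    """
    result = []
    for a in range(len(quality[0])):
        overlaps = [row[a] for row in quality]
        best = max(overlaps)
        if best >= high:
            result.append(overlaps.index(best))
        elif best < low:
            result.append(BELOW_LOW_THRESHOLD)
        else:
            result.append(BETWEEN_THRESHOLDS)
    return result
```

With the RPN defaults (0.7/0.3) a middle band is ignored; with the RoIHeads defaults (0.5/0.5) the band collapses and every candidate is either matched or background.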
Appendix: the PyTorch (torchvision) implementations
```python
def assign_targets_to_anchors(
    self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
) -> Tuple[List[Tensor], List[Tensor]]:
    labels = []
    matched_gt_boxes = []
    for anchors_per_image, targets_per_image in zip(anchors, targets):
        gt_boxes = targets_per_image["boxes"]
        if gt_boxes.numel() == 0:
            # Background image (negative example)
            device = anchors_per_image.device
            matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
            labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
        else:
            match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
            matched_idxs = self.proposal_matcher(match_quality_matrix)
            # get the targets corresponding GT for each proposal
            # NB: need to clamp the indices because we can have a single
            # GT in the image, and matched_idxs can be -2, which goes
            # out of bounds
            matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

            labels_per_image = matched_idxs >= 0
            labels_per_image = labels_per_image.to(dtype=torch.float32)

            # Background (negative examples)
            bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
            labels_per_image[bg_indices] = 0.0

            # discard indices that are between thresholds
            inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
            labels_per_image[inds_to_discard] = -1.0

        labels.append(labels_per_image)
        matched_gt_boxes.append(matched_gt_boxes_per_image)
    return labels, matched_gt_boxes
```
```python
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
    # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    matched_idxs = []
    labels = []
    for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
        if gt_boxes_in_image.numel() == 0:
            # Background image
            device = proposals_in_image.device
            clamped_matched_idxs_in_image = torch.zeros(
                (proposals_in_image.shape[0],), dtype=torch.int64, device=device
            )
            labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
        else:
            # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
            match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
            matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

            clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

            labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
            labels_in_image = labels_in_image.to(dtype=torch.int64)

            # Label background (below the low threshold)
            bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
            labels_in_image[bg_inds] = 0

            # Label ignore proposals (between low and high thresholds)
            ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
            labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

        matched_idxs.append(clamped_matched_idxs_in_image)
        labels.append(labels_in_image)
    return matched_idxs, labels
```