_predict分析
def _predict(self, molded_images, proposal_count,
             mode='training', gt=None):
    """Run the detection forward pass on a batch of molded images.

    The images go through ``self._foreground_background_layer`` (feature
    maps + RPN outputs), the RPN outputs are turned into proposals by
    ``proposal_layer``, and then either the inference path or the
    per-image training heads (classifier + mask) are run.

    Args:
        molded_images: Batch of preprocessed images, passed unchanged to
            ``self._foreground_background_layer``.
        proposal_count: Forwarded to ``proposal_layer``; presumably the
            maximum number of proposals kept per image.
        mode: ``'training'`` or ``'inference'``.
        gt: Ground truth, required when ``mode == 'training'``. Its
            ``boxes``, ``class_ids`` and ``masks`` attributes are indexed
            per image. ``gt`` itself is not modified.

    Returns:
        In inference mode, whatever ``self._inference`` returns. In
        training mode, a tuple ``(rpn_out, mrcnn_targets, mrcnn_outs)``
        where both lists hold one entry per image in the batch.

    Raises:
        ValueError: If ``mode`` is not recognised, or if ``mode`` is
            ``'training'`` and ``gt`` is ``None``.
    """
    if mode not in ('inference', 'training'):
        raise ValueError(f"mode {mode} not accepted.")
    # Fail fast, before the expensive backbone pass, instead of hitting an
    # AttributeError on ``None.boxes`` further down.
    if mode == 'training' and gt is None:
        raise ValueError("gt is required when mode='training'.")

    mrcnn_feature_maps, rpn_out = \
        self._foreground_background_layer(molded_images)
    batch_size = rpn_out.classes.shape[0]
    # NOTE(review): self.anchors appears to carry a leading batch dimension;
    # for a single image only its first entry is kept - confirm this matches
    # how the anchors are built.
    anchors = (
        self.anchors if batch_size > 1 else self.anchors[0].unsqueeze(0)
    )
    # Proposal generation is excluded from autograd.
    with torch.no_grad():
        rpn_rois = proposal_layer(
            rpn_out.classes,
            rpn_out.deltas,
            proposal_count=proposal_count,
            nms_threshold=self.config.RPN_NMS_THRESHOLD,
            anchors=anchors,
            config=self.config)

    if mode == 'inference':
        return self._inference(mrcnn_feature_maps, rpn_rois)

    # Training. Normalize the GT box coordinates into a local tensor rather
    # than reassigning gt.boxes: mutating the caller's object would make a
    # second call with the same gt divide by RPN_NORM again.
    gt_boxes = gt.boxes / self.config.RPN_NORM
    mrcnn_targets, mrcnn_outs = [], []
    for img_idx in range(batch_size):
        # Target assignment (sampling rois + building their targets) is not
        # differentiated through.
        with torch.no_grad():
            rois, mrcnn_target = detection_target_layer(
                rpn_rois[img_idx], gt.class_ids[img_idx],
                gt_boxes[img_idx], gt.masks[img_idx], self.config)

        if rois.nelement() == 0:
            # No sampled rois for this image: emit an empty placeholder so
            # the output lists stay aligned with the batch.
            mrcnn_out = MRCNNOutput().to(self.config.DEVICE)
            logging.debug('Rois size is empty')
        else:
            # Network heads run one image at a time, so restore a batch
            # dimension of 1 on both the rois and the feature maps.
            rois = rois.unsqueeze(0)
            # detach(): gradients from the head losses do not flow back
            # into the feature maps through these tensors.
            # NOTE(review): confirm that cutting this gradient path is
            # intended.
            image_feature_maps = [fmap[img_idx].unsqueeze(0).detach()
                                  for fmap in mrcnn_feature_maps]
            # Proposal classifier and bbox regressor heads.
            class_logits, _, deltas = \
                self.classifier(image_feature_maps, rois)
            # Mask head.
            masks = self.mask(image_feature_maps, rois)
            mrcnn_out = MRCNNOutput(class_logits, deltas, masks)

        mrcnn_outs.append(mrcnn_out)
        mrcnn_targets.append(mrcnn_target)

    return rpn_out, mrcnn_targets, mrcnn_outs
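- 先把 molded_images 传入 `self._foreground_background_layer(molded_images)` 函数。该函数做 FPN 分析,返回 FPN 每一层的 feature maps(mrcnn_feature_maps)和 RPN 的结果对象 rpn_out。rpn_out 可以理解为 class_logits、classes 和 deltas 三个量的集合。
- 再进入 `proposal_layer`:根据 rpn_out 的 classes、deltas 以及 anchors 生成候选框 rpn_rois,以 `config.RPN_NMS_THRESHOLD` 为阈值做非极大值抑制,并把候选框数量限制在 proposal_count 以内(这里是 2000)。这一步在 `torch.no_grad()` 下执行,不回传梯度。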
- 然后根据 mode 参数进入训练模式或推断模式(推断模式直接调用 `self._inference(mrcnn_feature_maps, rpn_rois)` 并返回)。这里只讲训练模式。
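- 先把 gt.boxes 除以 `config.RPN_NORM` 做坐标归一化,然后对 batch 中的每张图片,把这 2000 个候选框和该图片的 gt 交给 `detection_target_layer`。该函数根据 gt 的标注框找出与其 IoU 大于 0.5 的候选框。比如,其中有 24 个候选框与 gt 标注框的 IoU 大于 0.5;再根据 `config.TRAIN_ROIS_PER_IMAGE * config.ROI_POSITIVE_RATIO` 的值(这里是 66),把这 24 个候选框补足为 66 个候选框。不足的部分从 IoU 小于 0.5 的框中选取,并把这些框的类别标为 0,即背景。最后该函数返回这 66 个候选框 rois 及其对应的 target,即 class id、delta 和 masks。这一步同样在 `torch.no_grad()` 下执行。
- 把这 66 个候选框 rois 交给 `classifier` 函数,做类别分类和边框回归。
- 然后再把这 66 个候选框 rois 交给 `mask` 函数,做像素级的 mask 预测。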
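- 最后把 mrcnn_class_logits、mrcnn_deltas、mrcnn_mask 打包成 mrcnn_out。如果某张图片的 rois 为空,则跳过 classifier 和 mask,直接用一个空的 `MRCNNOutput` 占位。
- 每张图片的 mrcnn_out 和 mrcnn_target 分别收集到 mrcnn_outs 和 mrcnn_targets 列表中,最后返回 rpn_out, mrcnn_targets, mrcnn_outs。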