MaskRcnn(七)源码解读之DetectionTarget层

一、原理解读

DetectionTarget层作用:
1、把之前Proposal筛选后的几千个候选框中用0做padding的去掉。
2、把数据集中一个框包含多个物体的去掉,影响训练的结果。
3、判断正负样本。
4、设置负样本和正样本之间样本数量的比例。
5、得到每一个正样本的类别。
6、得到每一个候选框与实际正样本的偏移量。
7、得到与实际正样本最接近的候选框对应的Mask。
8、返回所有的结果,负样本偏移量和mask都用0进行填充。

三、代码解读

def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config):
    
    asserts = [
        tf.Assert(tf.greater(tf.shape(proposals)[0], 0), [proposals],
                  name="roi_assertion"),
    ]
    with tf.control_dependencies(asserts):
        proposals = tf.identity(proposals)
    # 剔除掉0填充的
    proposals, _ = trim_zeros_graph(proposals, name="trim_proposals")
    gt_boxes, non_zeros = trim_zeros_graph(gt_boxes, name="trim_gt_boxes")
    gt_class_ids = tf.boolean_mask(gt_class_ids, non_zeros,
                                   name="trim_gt_class_ids")
    gt_masks = tf.gather(gt_masks, tf.where(non_zeros)[:, 0], axis=2,
                         name="trim_gt_masks")
    #进行筛选一些重叠等情况的标签。
    crowd_ix = tf.where(gt_class_ids < 0)[:, 0]
    non_crowd_ix = tf.where(gt_class_ids > 0)[:, 0]
    crowd_boxes = tf.gather(gt_boxes, crowd_ix)
    crowd_masks = tf.gather(gt_masks, crowd_ix, axis=2)
    gt_class_ids = tf.gather(gt_class_ids, non_crowd_ix)
    gt_boxes = tf.gather(gt_boxes, non_crowd_ix)
    gt_masks = tf.gather(gt_masks, non_crowd_ix, axis=2)
    #进行比较,计算iou值
    overlaps = overlaps_graph(proposals, gt_boxes)
    #剔除掉一个框多个物体的标签,
    crowd_overlaps = overlaps_graph(proposals, crowd_boxes)
    crowd_iou_max = tf.reduce_max(crowd_overlaps, axis=1)
    no_crowd_bool = (crowd_iou_max < 0.001)
    # 选出iou最大的
    roi_iou_max = tf.reduce_max(overlaps, axis=1)
    # 选择大于0.5的候选框
    positive_roi_bool = (roi_iou_max >= 0.5)
    positive_indices = tf.where(positive_roi_bool)[:, 0]
    negative_indices = tf.where(tf.logical_and(roi_iou_max < 0.5, no_crowd_bool))[:, 0]
    #得到正负样本的个数
    positive_count = int(config.TRAIN_ROIS_PER_IMAGE *
                         config.ROI_POSITIVE_RATIO)
    positive_indices = tf.random_shuffle(positive_indices)[:positive_count]
    positive_count = tf.shape(positive_indices)[0]
    #r是3,代表负的样本个数是真的3倍。
    r = 1.0 / config.ROI_POSITIVE_RATIO
    negative_count = tf.cast(r * tf.cast(positive_count, tf.float32), tf.int32) - positive_count
    negative_indices = tf.random_shuffle(negative_indices)[:negative_count]
    #收集选定的ROI
    positive_rois = tf.gather(proposals, positive_indices)
    negative_rois = tf.gather(proposals, negative_indices)
    # 将正ROI分配给GT框。
    positive_overlaps = tf.gather(overlaps, positive_indices)
    roi_gt_box_assignment = tf.argmax(positive_overlaps, axis=1)
    roi_gt_boxes = tf.gather(gt_boxes, roi_gt_box_assignment)
    roi_gt_class_ids = tf.gather(gt_class_ids, roi_gt_box_assignment)
    # 实际训练的不是坐标,而是偏移量,把实际的偏移量计算出来
    deltas = utils.box_refinement_graph(positive_rois, roi_gt_boxes)
    deltas /= config.BBOX_STD_DEV
    # 将正ROI分配给GT框。
    transposed_masks = tf.expand_dims(tf.transpose(gt_masks, [2, 0, 1]), -1)
    # 为每个ROI选择正确的mask
    roi_masks = tf.gather(transposed_masks, roi_gt_box_assignment)
    # 计算遮罩目标
    boxes = positive_rois
    if config.USE_MINI_MASK:
        # 从标准化图像空间变换ROI
        y1, x1, y2, x2 = tf.split(positive_rois, 4, axis=1)
        gt_y1, gt_x1, gt_y2, gt_x2 = tf.split(roi_gt_boxes, 4, axis=1)
        gt_h = gt_y2 - gt_y1
        gt_w = gt_x2 - gt_x1
        y1 = (y1 - gt_y1) / gt_h
        x1 = (x1 - gt_x1) / gt_w
        y2 = (y2 - gt_y1) / gt_h
        x2 = (x2 - gt_x1) / gt_w
        boxes = tf.concat([y1, x1, y2, x2], 1)
    box_ids = tf.range(0, tf.shape(roi_masks)[0])
    masks = tf.image.crop_and_resize(tf.cast(roi_masks, tf.float32), boxes,
                                     box_ids,
                                     config.MASK_SHAPE)
    # 从mask中删除额外的尺寸标注。
    masks = tf.squeeze(masks, axis=3)

    # 阈值遮罩像素为0.5,GT遮罩为0或1,用于二元交叉熵损失。
    #重新计算mask
    masks = tf.round(masks)
    # 附加负ROI和pad bbox增量以及不用于带零的负ROI的掩码。
    rois = tf.concat([positive_rois, negative_rois], axis=0)
    N = tf.shape(negative_rois)[0]
    P = tf.maximum(config.TRAIN_ROIS_PER_IMAGE - tf.shape(rois)[0], 0)
    rois = tf.pad(rois, [(0, P), (0, 0)])
    roi_gt_boxes = tf.pad(roi_gt_boxes, [(0, N + P), (0, 0)])
    roi_gt_class_ids = tf.pad(roi_gt_class_ids, [(0, N + P)])
    deltas = tf.pad(deltas, [(0, N + P), (0, 0)])
    masks = tf.pad(masks, [[0, N + P], (0, 0), (0, 0)])
    #返回结果roi,每个roi对应的类别,偏移量,掩码
    return rois, roi_gt_class_ids, deltas, masks
mask rcnn pytorch 的源码是一个用于目标检测和实例分割的深度学习模型的实现。其中包含了一些关键的文件和类来构建网络和实现相关功能。 在 Faster R-CNN 中,首次提出了 RPN 网络,该网络用于生成目标检测任务所需的候选区域框。在 MaskrcnnBenchmark 中,关于 RPN 网络的定义位于 ./maskrcnn_benchmark/modeling/rpn/ 文件夹中。这个文件夹包含以下四个文件:rpn.py、anchor_generator.py、inference.py、loss.py。在 class GeneralizedRCNN(nn.Module) 类中,通过 self.rpn = build_rpn(cfg) 函数来创建 RPN 网络,该函数位于 ./maskrcnn_benchmark/modeling/rpn/rpn.py 文件中。 在 rpn.py 文件中,有 build_fpn(cfg) 函数返回一个 RPNModule 的实例。make_anchor_generator() 函数是用来定义 RPN 网络默认的 anchor 的面积大小、高宽比和 feature map 采用的 stride,还有剪枝功能的设置。 需要注意的是,在最新版本的实现中,存在一些错误和不足,不适合用作理解 Mask R-CNN 架构的资源。因此,对于深入研究该源码,建议参考更可靠的资源或最新版本的实现。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [【pytorch】Mask-RCNN官方源码剖析(Ⅲ)](https://blog.csdn.net/qq_43348528/article/details/107556259)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [Mask RCNN架构的PyTorch实现-Python开发](https://download.csdn.net/download/weixin_42098830/19060631)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

血狼傲骨

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值