Mask R-CNN PyTorch 源码分析(二)

# 关于 DataGenerator `__getitem__` 部分的分析

class DataGenerator(Dataset):
    """PyTorch ``Dataset`` that yields one Mask R-CNN training sample per index.

    Each item holds the mean-subtracted image, its metadata, the RPN anchor
    targets, and the ground-truth class ids, boxes and masks. The instance
    dimension of the ground truth is fixed to ``config.MAX_GT_INSTANCES`` so
    that items of different images can be stacked into one batch.
    """

    def __init__(self, dataset_handler, config, augmentation=None, anchors=None):
        """Create the dataset.

        Args:
            dataset_handler: The dataset object to pick data from. Its
                ``image_ids`` attribute lists the images this dataset serves.
            config: The model config object. ``USE_MINI_MASK`` and
                ``MAX_GT_INSTANCES`` are read from it. It is also passed
                through to ``load_image_gt``, ``build_rpn_targets`` and
                ``utils.subtract_mean``.
            augmentation: Optional augmentation, forwarded unchanged to
                ``load_image_gt``.
            anchors: Anchor boxes, described upstream as
                ``[anchor_count, (y1, x1, y2, x2)]``. Must support ``.cpu()``
                (e.g. a torch tensor); a CPU copy is stored. Required despite
                the ``None`` default.

        Raises:
            ValueError: If ``anchors`` is None.
        """
        if anchors is None:
            raise ValueError("DataGenerator requires `anchors`; got None.")
        # The next three attributes are not used by this class; they are kept
        # so that external code referring to them keeps working.
        self.b = 0  # batch item index
        self.image_index = -1
        self.error_count = 0
        self.image_ids = np.copy(dataset_handler.image_ids)
        self.config = config
        self.dataset_handler = dataset_handler
        self.augmentation = augmentation

        # Anchors
        # [anchor_count, (y1, x1, y2, x2)]
        self.anchors = anchors.cpu()

    @profilable
    def __getitem__(self, image_index):
        """Return the training sample for ``image_index``.

        If the requested image has no ground-truth instance with a positive
        class id, the following images (wrapping around) are tried instead.
        The returned sample may therefore come from a different image.

        Returns:
            Tuple ``(image, image_metas, rpn_match, rpn_bbox, gt_class_ids,
            gt_boxes, gt_masks)`` where:

            - image: float tensor, mean-subtracted, axes permuted with
              ``transpose(2, 0, 1)`` (i.e. (H, W, C) -> (C, H, W)).
            - image_metas: ``image_metas.to_numpy()`` of the loaded metadata.
            - rpn_match: tensor from ``build_rpn_targets`` with a trailing
              axis added (presumably [N] -> [N, 1]).
            - rpn_bbox: float tensor of RPN box deltas from
              ``build_rpn_targets``.
            - gt_class_ids: [MAX_GT_INSTANCES] tensor, zero-padded.
            - gt_boxes: [MAX_GT_INSTANCES, 4] float tensor, zero-padded.
            - gt_masks: [MAX_GT_INSTANCES, H_mask, W_mask] float tensor,
              zero-padded (instance axis moved to the front).

        Raises:
            IndexError: If ``image_index`` is out of range.
            RuntimeError: If no image yields a positive-class instance.
        """
        image, image_metas, gt_class_ids, gt_boxes, gt_masks = \
            self._load_gt_with_instances(image_index)

        # RPN targets are built from all instances, before any subsampling.
        rpn_match, rpn_bbox = build_rpn_targets(
            self.anchors, gt_class_ids, gt_boxes, self.config)

        gt_class_ids, gt_boxes, gt_masks = self._fit_to_max_instances(
            gt_class_ids, gt_boxes, gt_masks)

        rpn_match = rpn_match[:, np.newaxis]
        image = utils.subtract_mean(image, config=self.config)

        # Convert to tensors
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
        rpn_match = torch.from_numpy(rpn_match)
        rpn_bbox = torch.from_numpy(rpn_bbox).float()
        gt_class_ids = torch.from_numpy(gt_class_ids)
        gt_boxes = torch.from_numpy(gt_boxes).float()
        gt_masks = torch.from_numpy(gt_masks.astype(int).transpose(2, 0, 1)).float()
        return (image, image_metas.to_numpy(), rpn_match, rpn_bbox,
                gt_class_ids, gt_boxes, gt_masks)

    def _load_gt_with_instances(self, image_index):
        """Load GT for ``image_index``, or the next image with a positive class id.

        Reloading the same image forever (the previous behaviour) hangs when
        an image has no positive-class instance. Instead, each image is tried
        at most once, starting at ``image_index`` and wrapping around.
        """
        num_images = len(self.image_ids)
        # Keep the usual IndexError for out-of-range indices; sequence-style
        # iteration over the dataset relies on it to stop.
        if not -num_images <= image_index < num_images:
            raise IndexError("image_index %d out of range for %d images"
                             % (image_index, num_images))
        start = image_index % num_images
        for offset in range(num_images):
            image_id = self.image_ids[(start + offset) % num_images]
            image, image_metas, gt_class_ids, gt_boxes, gt_masks = \
                load_image_gt(self.dataset_handler, image_id, config=self.config,
                              augmentation=self.augmentation,
                              use_mini_mask=self.config.USE_MINI_MASK)
            if np.any(gt_class_ids > 0):
                return image, image_metas, gt_class_ids, gt_boxes, gt_masks
        raise RuntimeError("No image in the dataset has a ground-truth "
                           "instance with a positive class id.")

    def _fit_to_max_instances(self, gt_class_ids, gt_boxes, gt_masks):
        """Randomly subsample or zero-pad the instances to MAX_GT_INSTANCES.

        Instances are indexed along axis 0 of ``gt_class_ids``/``gt_boxes``
        and along the last axis of ``gt_masks``. Padding keeps each array's
        own dtype, so the output dtype does not depend on how many instances
        the image had.
        """
        max_instances = self.config.MAX_GT_INSTANCES
        count = gt_boxes.shape[0]

        if count > max_instances:
            keep = np.random.choice(np.arange(count), max_instances,
                                    replace=False)
            return gt_class_ids[keep], gt_boxes[keep], gt_masks[:, :, keep]

        if count < max_instances:
            padded_ids = np.zeros((max_instances,), dtype=gt_class_ids.dtype)
            padded_ids[:gt_class_ids.shape[0]] = gt_class_ids

            padded_boxes = np.zeros((max_instances, 4), dtype=gt_boxes.dtype)
            padded_boxes[:count] = gt_boxes

            padded_masks = np.zeros(
                (gt_masks.shape[0], gt_masks.shape[1], max_instances),
                dtype=gt_masks.dtype)
            padded_masks[:, :, :gt_masks.shape[-1]] = gt_masks
            return padded_ids, padded_boxes, padded_masks

        return gt_class_ids, gt_boxes, gt_masks

    def __len__(self):
        """Number of images listed by the dataset handler."""
        return self.image_ids.shape[0]
  1. 调用 load_image_gt 函数获得 image, image_metas, gt_class_ids, gt_boxes, gt_masks,并要求其中至少有一个正类别(gt_class_ids > 0)的实例。
  2. 调用 build_rpn_targets 获得 rpn_match, rpn_bbox,该过程在分析一里已经说明过。
  3. 将 gt_boxes, gt_class_ids, gt_masks 的实例数调整至 self.config.MAX_GT_INSTANCES(这里是 100):如果实例数超过该值,则随机抽取 MAX_GT_INSTANCES 个实例;如果不够,则补零。比如 gt_class_ids 原来是 [2, 8, 14, 44],补零后变为 [2, 8, 14, 44, 0, 0, …, 0, 0]。
  4. 对 image 减去均值并把通道维移到最前面,给 rpn_match 增加一个维度,然后将 image, rpn_match, rpn_bbox, gt_class_ids, gt_boxes, gt_masks 转换为 torch 张量(image_metas 转换为 numpy 数组)后返回。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值