# 关于 DataGenerator.__getitem__ 部分的分析
class DataGenerator(Dataset):
    """A Dataset that yields one training sample per index for Mask R-CNN:
    the image plus its RPN targets and ground-truth class ids, boxes and masks.

    Args:
        dataset_handler: The dataset object to pick data from (must expose
            `image_ids`).
        config: The model config object (uses USE_MINI_MASK and
            MAX_GT_INSTANCES here).
        augmentation: Optional augmentation passed through to load_image_gt.
        anchors: [anchor_count, (y1, x1, y2, x2)] anchor boxes as a torch
            tensor, or None.

    __getitem__ returns a tuple:
        - image: [C, H, W] float tensor, mean-subtracted.
        - image_metas: image meta info as a numpy array.
        - rpn_match: [N, 1] int tensor (1=positive anchor, -1=negative,
          0=neutral).
        - rpn_bbox: [N, (dy, dx, log(dh), log(dw))] float tensor of anchor
          bbox deltas.
        - gt_class_ids: [MAX_GT_INSTANCES] int tensor of class IDs,
          zero-padded.
        - gt_boxes: [MAX_GT_INSTANCES, (y1, x1, y2, x2)] float tensor,
          zero-padded.
        - gt_masks: [MAX_GT_INSTANCES, height, width] float tensor. Height
          and width are those of the image unless USE_MINI_MASK is True, in
          which case they come from MINI_MASK_SHAPE.
    """

    def __init__(self, dataset_handler, config, augmentation=None, anchors=None):
        self.b = 0  # batch item index
        self.image_index = -1
        self.image_ids = np.copy(dataset_handler.image_ids)
        self.error_count = 0
        self.config = config
        self.dataset_handler = dataset_handler
        self.augmentation = augmentation
        # Anchors: [anchor_count, (y1, x1, y2, x2)] kept on CPU.
        # Guard against the documented default of None, which previously
        # raised AttributeError on .cpu().
        self.anchors = anchors.cpu() if anchors is not None else None

    @profilable
    def __getitem__(self, image_index):
        image_id = self.image_ids[image_index]
        # Get GT bounding boxes and masks for the image. If the image has no
        # positive class IDs, resample a different image id — the previous
        # code retried the SAME image forever, causing an infinite loop.
        while True:
            image, image_metas, gt_class_ids, gt_boxes, gt_masks = \
                load_image_gt(self.dataset_handler, image_id,
                              config=self.config,
                              augmentation=self.augmentation,
                              use_mini_mask=self.config.USE_MINI_MASK)
            if np.any(gt_class_ids > 0):
                break
            image_id = np.random.choice(self.image_ids)

        # RPN targets: per-anchor match labels and bbox deltas.
        rpn_match, rpn_bbox = build_rpn_targets(
            self.anchors, gt_class_ids, gt_boxes, self.config)

        # If more instances than fit in the arrays, sub-sample from them.
        if gt_boxes.shape[0] > self.config.MAX_GT_INSTANCES:
            ids = np.random.choice(np.arange(gt_boxes.shape[0]),
                                   self.config.MAX_GT_INSTANCES,
                                   replace=False)
            gt_class_ids = gt_class_ids[ids]
            gt_boxes = gt_boxes[ids]
            gt_masks = gt_masks[:, :, ids]
        elif gt_boxes.shape[0] < self.config.MAX_GT_INSTANCES:
            # Zero-pad class ids, boxes, and masks up to MAX_GT_INSTANCES so
            # every sample has fixed-size targets.
            gt_class_ids_ = np.zeros((self.config.MAX_GT_INSTANCES),
                                     dtype=np.int32)
            gt_class_ids_[:gt_class_ids.shape[0]] = gt_class_ids
            gt_class_ids = gt_class_ids_
            gt_boxes_ = np.zeros((self.config.MAX_GT_INSTANCES, 4),
                                 dtype=np.int32)
            gt_boxes_[:gt_boxes.shape[0]] = gt_boxes
            gt_boxes = gt_boxes_
            gt_masks_ = np.zeros((gt_masks.shape[0], gt_masks.shape[1],
                                  self.config.MAX_GT_INSTANCES),
                                 dtype=np.int32)
            gt_masks_[:, :, :gt_masks.shape[-1]] = gt_masks
            gt_masks = gt_masks_

        # [N] -> [N, 1] so rpn_match stacks alongside rpn_bbox.
        rpn_match = rpn_match[:, np.newaxis]
        image = utils.subtract_mean(image, config=self.config)

        # Convert to tensors; image and masks go HWC -> CHW.
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
        rpn_match = torch.from_numpy(rpn_match)
        rpn_bbox = torch.from_numpy(rpn_bbox).float()
        gt_class_ids = torch.from_numpy(gt_class_ids)
        gt_boxes = torch.from_numpy(gt_boxes).float()
        gt_masks = torch.from_numpy(
            gt_masks.astype(int).transpose(2, 0, 1)).float()
        return (image, image_metas.to_numpy(), rpn_match, rpn_bbox,
                gt_class_ids, gt_boxes, gt_masks)

    def __len__(self):
        return self.image_ids.shape[0]
- 从load_image_gt函数里获得image, image_metas, gt_class_ids, gt_boxes, gt_masks
- 从build_rpn_target获得rpn_match, rpn_bbox过程在分析一里说明过。
- 将gt_boxes, gt_class_ids, gt_masks 的size 补充至self.config.MAX_GT_INSTANCES的长度,这里是100. 如果不够,补零,比如gt_class_ids原来是[2,8,14,44] 现在补充为[2,8,14,44,0,0,0…0,0]
- 将image, rpn_match, rpn_bbox, gt_class_ids, gt_boxes, gt_masks 转换为torch tensor(其中image_metas 转换为numpy)后输出。