Mask RCNN使用及实现详解(5)

模型的输入与数据加载

1、 模型输入
从模型的训练与预测中可以看到实际创建的模型是一个名为MaskRCNN的类相关代码片段如下:

在这里插入图片描述

训练创建模型的代码

在这里插入图片描述

预测创建模型的代码

这个类的初始化方法如下:

def __init__(self, mode, config, model_dir):
    """
    mode: Either "training" or "inference"
    config: A Sub-class of the Config class
    model_dir: Directory to save training logs and trained weights
    """
    assert mode in ['training', 'inference']
    self.mode = mode
    self.config = config
    self.model_dir = model_dir
    self.set_log_dir()
    self.keras_model = self.build(mode=mode, config=config)

其中mode参数是用来标识模型是用于训练还是预测,不同的用途模型返回的结构有少许不同。而build方法是用来创建模型的,其创建的是一个keras模型。该方法很长,与输入相关的部分如下:

def build(self, mode, config):
    """Build Mask R-CNN architecture.
        input_shape: The shape of the input image.
        mode: Either "training" or "inference". The inputs and
            outputs of the model differ accordingly.
    """
    assert mode in ['training', 'inference']

    # Image size must be dividable by 2 multiple times
    h, w = config.IMAGE_SHAPE[:2]
    if h / 2**6 != int(h / 2**6) or w / 2**6 != int(w / 2**6):
        raise Exception("Image size must be dividable by 2 at least 6 times "
                        "to avoid fractions when downscaling and upscaling."
                        "For example, use 256, 320, 384, 448, 512, ... etc. ")

    # Inputs
    input_image = KL.Input(
        shape=[None, None, config.IMAGE_SHAPE[2]], name="input_image")
    input_image_meta = KL.Input(shape=[config.IMAGE_META_SIZE],
                                name="input_image_meta")
    if mode == "training":
        # RPN GT
        input_rpn_match = KL.Input(
            shape=[None, 1], name="input_rpn_match", dtype=tf.int32)
        input_rpn_bbox = KL.Input(
            shape=[None, 4], name="input_rpn_bbox", dtype=tf.float32)

        # Detection GT (class IDs, bounding boxes, and masks)
        # 1. GT Class IDs (zero padded)
        input_gt_class_ids = KL.Input(
            shape=[None], name="input_gt_class_ids", dtype=tf.int32)
        # 2. GT Boxes in pixels (zero padded)
        # [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)] in image coordinates
        input_gt_boxes = KL.Input(
            shape=[None, 4], name="input_gt_boxes", dtype=tf.float32)
        # Normalize coordinates
        gt_boxes = KL.Lambda(lambda x: norm_boxes_graph(
            x, K.shape(input_image)[1:3]))(input_gt_boxes)
        # 3. GT Masks (zero padded)
        # [batch, height, width, MAX_GT_INSTANCES]
        if config.USE_MINI_MASK:
            input_gt_masks = KL.Input(
                shape=[config.MINI_MASK_SHAPE[0],
                       config.MINI_MASK_SHAPE[1], None],
                name="input_gt_masks", dtype=bool)
        else:
            input_gt_masks = KL.Input(
                shape=[config.IMAGE_SHAPE[0], config.IMAGE_SHAPE[1], None],
                name="input_gt_masks", dtype=bool)
    elif mode == "inference":
        # Anchors in normalized coordinates
        input_anchors = KL.Input(shape=[None, 4], name="input_anchors")

首先细看代码,在第10行到第14行,这里首先拿出了配置的输入图片的形状,拿到图片的宽与高后判断其宽和高是否能被2的6次方整除。这一步是因为在后面的FPN中会对图片进行6次步长为2的卷积,所有需要其能被2的6次方整除。

在这里插入图片描述

然后是第17行和第19行,这里定义了两个输入:input_image和input_image_meta。其中image_input是真实的图片输入,input_image_meta是图片的元数据。虽然input_image对图像的宽高输入的定义是None,看似是对宽高不作限制,但实际在训练和预测的时候是对图片进行resize操作的,它实际是将图片缩放到配置的大小范围内。而input_image_meta主要就是记录上述信息的,其定义的形状如下:1 + 3 + 3 + 4 + 1 + self.NUM_CLASSES。

在这里插入图片描述

然后是第21行和第50行的if语句,这里是对训练和预测的情况进行了不同的处理。

在这里插入图片描述

首先是训练模式下,定义了5个输入:input_rpn_match,input_rpn_bbox,input_gt_class_ids,input_gt_boxes,input_gt_mask。

这5个输入实际上并不能算输入,他们实际是根据标注结果来生成的模型输出对应的真实值(即y_true)。这里将他们作为输入主要是因为其将损失计算也定义在了模型中。这5个输入主要是用在计算损失层中。

这5个输入代表的含义如下:

input_rpn_match:对应RPN网络的输出,这个输出用来判断输出的建议框中是否真实的包含目标。

input_rpn_bbox: 对应RPN网络的输出,这个输出用来表示先验框应该如何调整。

input_gt_class_ids: 对应的分类网络的输出,即图像真实包含的目标的id

input_gt_boxes: 对应分类网络的输出,即图像真实包含的目标的框

input_gt_masks: 对应mask网络的输出,即图像标注的mask

在这里插入图片描述

最后是第50行对预测模式下定义的输入,这里就一个输入:input_anchors。这个输入即先验框。

综上所述,Mask rcnn在训练的时候需要7个输入:input_image,input_image_meta,input_rpn_match,input_rpn_bbox,input_gt_class_ids,input_gt_boxes,input_gt_masks。预测时需要三个参数:input_image,input_image_meta,input_anchors。

这里我们以训练模式为例,讲解Mask Rcnn如何将labelme标注的文件加载成训练需要的7个输入。

在文档数据准备中介绍的labelme的使用方法,labelme转换过后的文件目录如下:

在这里插入图片描述
labelme标注文件目录结构

上述文件中主要会使用三个文件:img.png,info.yaml,label.png。其中img.png是原图片;info.yaml是图片的标注的类别;label.png是图片标注的掩码(mask)。

其中label.png文件使用PIL打开后实际存储的数据内容如下:
在这里插入图片描述
label.png存储内容

这里存储的“0,1,2”实际是代表着info.yaml文件中的标签名称。

在文档模型的训练与预测中,运行的训练脚本中定义的一个用来处理数据的DrugDataset类,这个类在训练脚本中调用的片段如下:

在这里插入图片描述
调用DrugDataset

这里在创建了DrugDataset后调用了load_shapes和prepare方法。创建DrugDataset的init方法如下:

def __init__(self, class_map=None):
    self._image_ids = []
    self.image_info = []
    # Background is always the first class
    self.class_info = [{"source": "", "id": 0, "name": "BG"}]
    self.source_class_ids = {}

这里的init方法是其父类的方法,这里主要是初始化了几个参数。其中_image_ids是用来记录图片id,image_info是用来记录图片信息,class_info用来记录分类信息,source_class_ids是用来记录分类对应的id信息。

然后是load_shapes方法:

def load_shapes(self, count,floder, imglist):
    # Add classes,可通过这种方式扩展多个物体
    self.add_class("shapes", 1, "box")


    for i in range(count):
        # 获取图片宽和高
        
        filestr = imglist[i]
        mask_path = floder + "/" + filestr + "/label.png"
        yaml_path = floder + "/" + filestr + "/info.yaml"

        cv_img = cv2.imread( floder + "/" + filestr + "/img.png")

        self.add_image("shapes", image_id=i, path=floder + "/" + filestr + "/img.png",
                       width=cv_img.shape[1], height=cv_img.shape[0], mask_path=mask_path, yaml_path=yaml_path)

首先分析输入的参数:count是文件夹中文件的数量,floder是文件夹的路径,imglist是通过listdir方法拿到的子文件夹(即用labelme生成的文件夹,一个文件夹代表一条数据)。然后是第3行的add_class方法,这个方法主要用来添加分类信息。最后是第6行的for循环将训练数据的具体信息添加到dataset中。

细看这个for循环,它首先会拿到imglist中对应的文件夹名,然后拼接出mask文件的路径和yaml文件的路径,然后用cv2读取原始图片,最后调用add_image方法存储数据信息。注意这里用cv2读取图片并不是用来存储图片,而是用来获取图片的宽高信息。

这个方法主要调用了两个方法:add_class和add_image方法。首先是add_class方法,其内容如下:

def add_class(self, source, class_id, class_name):
    assert "." not in source, "Source name cannot contain a dot"
    # Does the class exist already?
    for info in self.class_info:
        if info['source'] == source and info["id"] == class_id:
            # source.class_id combination already available, skip
            return
    # Add the class
    self.class_info.append({
        "source": source,
        "id": class_id,
        "name": class_name,
    })

这个方法很简单,就是将传入的信息添加到初始化时创建的class_info中。

add_image方法内容如下:

def add_image(self, source, image_id, path, **kwargs):
    image_info = {
        "id": image_id,
        "source": source,
        "path": path,
    }
    image_info.update(kwargs)
    self.image_info.append(image_info)

这个方法也很简单就是将传入的信息添加到初始化时创建的image_info中。这里需要注意的是这里有两个image_info:方法内的局部变量image_info 和初始化时创建的self.image_info。

然后是prepare方法,该方法内容如下:

def prepare(self, class_map=None):

    def clean_name(name):
        """Returns a shorter version of object names for cleaner display."""
        return ",".join(name.split(",")[:1])

    # Build (or rebuild) everything else from the info dicts.
    self.num_classes = len(self.class_info)
    self.class_ids = np.arange(self.num_classes)
    self.class_names = [clean_name(c["name"]) for c in self.class_info]
    self.num_images = len(self.image_info)
    self._image_ids = np.arange(self.num_images)

    # Mapping from source class and image IDs to internal IDs
    self.class_from_source_map = {"{}.{}".format(info['source'], info['id']): id
                                  for info, id in zip(self.class_info, self.class_ids)}
    self.image_from_source_map = {"{}.{}".format(info['source'], info['id']): id
                                  for info, id in zip(self.image_info, self.image_ids)}

    # Map sources to class_ids they support
    self.sources = list(set([i['source'] for i in self.class_info]))
    self.source_class_ids = {}
    # Loop over datasets
    for source in self.sources:
        self.source_class_ids[source] = []
        # Find classes that belong to this dataset
        for i, info in enumerate(self.class_info):
            # Include BG class in all datasets
            if i == 0 or source == info['source']:
                self.source_class_ids[source].append(i)

这个方法主要是用来处理load_shapes方法加载的数据。首先是第9行根据class_info的长度来生成分类id,然后是第10行从class_info中取出分类的名称。然后是第12行根据image_info的长度来生成图片的id。然后是第15行和第17行创建分类或图片到对应id的映射。最后是处理不同数据源的分类id的映射关系(简单来说就是将标识数据来源的source作为key,同source下的分类id都存储到一个list中,并将这个list作为value存储到字典中)。

自此,文档解析了dataset在创建时主要的操作。从上述文档中可以看出dataset在创建后并没有实际的去加载图片数据,它主要加载的是图片路径,分类信息等元数据。真正加载数据是在真正进行模型训练的时候。

接下来,继续分析模型在训练的时候如何通过元数据加载图片信息。在文档模型的训练与预测中,训练脚本实际使用的train方法来进行训练的,调用片段如下:

在这里插入图片描述

模型训练代码

调用的train方法内容如下:

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
          augmentation=None, custom_callbacks=None, no_augmentation_sources=None):
    """Train the model.
    train_dataset, val_dataset: Training and validation Dataset objects.
    learning_rate: The learning rate to train with
    epochs: Number of training epochs. Note that previous training epochs
            are considered to be done alreay, so this actually determines
            the epochs to train in total rather than in this particaular
            call.
    layers: Allows selecting wich layers to train. It can be:
        - A regular expression to match layer names to train
        - One of these predefined values:
          heads: The RPN, classifier and mask heads of the network
          all: All the layers
          3+: Train Resnet stage 3 and up
          4+: Train Resnet stage 4 and up
          5+: Train Resnet stage 5 and up
    augmentation: Optional. An imgaug (https://github.com/aleju/imgaug)
        augmentation. For example, passing imgaug.augmenters.Fliplr(0.5)
        flips images right/left 50% of the time. You can pass complex
        augmentations as well. This augmentation applies 50% of the
        time, and when it does it flips images right/left half the time
        and adds a Gaussian blur with a random sigma in range 0 to 5.

            augmentation = imgaug.augmenters.Sometimes(0.5, [
                imgaug.augmenters.Fliplr(0.5),
                imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
            ])
 custom_callbacks: Optional. Add custom callbacks to be called
     with the keras fit_generator method. Must be list of type keras.callbacks.
    no_augmentation_sources: Optional. List of sources to exclude for
        augmentation. A source is string that identifies a dataset and is
        defined in the Dataset class.
    """
    assert self.mode == "training", "Create model in training mode."

    # Pre-defined layer regular expressions
    layer_regex = {
        # all layers but the backbone
        "heads": r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        # From a specific Resnet stage and up
        "3+": r"(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        "4+": r"(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        "5+": r"(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        # All layers
        "all": ".*",
    }
    if layers in layer_regex.keys():
        layers = layer_regex[layers]

    # Data generators
    train_generator = data_generator(train_dataset, self.config, shuffle=True,
                                     augmentation=augmentation,
                                     batch_size=self.config.BATCH_SIZE,
                                     no_augmentation_sources=no_augmentation_sources)
    val_generator = data_generator(val_dataset, self.config, shuffle=True,
                                   batch_size=self.config.BATCH_SIZE)

    # Create log_dir if it does not exist
    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)

    # Callbacks
    callbacks = [
        keras.callbacks.TensorBoard(log_dir=self.log_dir,
                                    histogram_freq=0, write_graph=True, write_images=False),
        keras.callbacks.ModelCheckpoint(self.checkpoint_path,
                                        verbose=0, save_weights_only=True),
    ]

    # Add custom callbacks to the list
    if custom_callbacks:
        callbacks += custom_callbacks

    # Train
    log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
    log("Checkpoint Path: {}".format(self.checkpoint_path))
    self.set_trainable(layers)
    self.compile(learning_rate, self.config.LEARNING_MOMENTUM)



    # Work-around for Windows: Keras fails on Windows when using
    # multiprocessing workers. See discussion here:
    # https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009
    if os.name is 'nt':
        workers = 0
    else:
        workers = multiprocessing.cpu_count()

    self.keras_model.fit_generator(
        train_generator,
        initial_epoch=self.epoch,
        epochs=epochs,
        steps_per_epoch=self.config.STEPS_PER_EPOCH,
        callbacks=callbacks,
        validation_data=val_generator,
        validation_steps=self.config.VALIDATION_STEPS,
        max_queue_size=20,
        workers=workers,
        use_multiprocessing=True,
    )
    self.epoch = max(self.epoch, epochs)

跳过注释,首先看第38行,这里是对layers参数进行处理。根据迁移学习的理念,对于这种模型主要结构相同只是输出的分类结果不同的情况,没有必要完全重新训练模型的所有参数,可以固定模型主干网络的参数,只训练模型分类网络的参数。这里的layers便是用做以上处理。然后是第52和56行通过data_generator方法来处理输入数据,这个稍后在详细分析。然后是第64行创建模型训练时的回调函数。然后是第78,79行编译模型,最后是第91行调用fit_generator方法来进行训练。

从上述分析中可知,数据处理主要是通过data_generator方法来进行的,该方法的主要内容如下:

def data_generator(dataset, config, shuffle=True, augment=False, augmentation=None,
                   random_rois=0, batch_size=1, detection_targets=False,
                   no_augmentation_sources=None):
                   
    b = 0  # batch item index
    image_index = -1
    image_ids = np.copy(dataset.image_ids)
    error_count = 0
    no_augmentation_sources = no_augmentation_sources or []

    # Anchors
    # [anchor_count, (y1, x1, y2, x2)]
    backbone_shapes = compute_backbone_shapes(config, config.IMAGE_SHAPE)
    anchors = utils.generate_pyramid_anchors(config.RPN_ANCHOR_SCALES,
                                             config.RPN_ANCHOR_RATIOS,
                                             backbone_shapes,
                                             config.BACKBONE_STRIDES,
                                             config.RPN_ANCHOR_STRIDE)

    # Keras requires a generator to run indefinitely.
    while True:
        try:
            # Increment index to pick next image. Shuffle if at the start of an epoch.
            image_index = (image_index + 1) % len(image_ids)
            if shuffle and image_index == 0:
                np.random.shuffle(image_ids)

            # Get GT bounding boxes and masks for image.
            image_id = image_ids[image_index]

            # If the image source is not to be augmented pass None as augmentation
            if dataset.image_info[image_id]['source'] in no_augmentation_sources:
                image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
                load_image_gt(dataset, config, image_id, augment=augment,
                              augmentation=None,
                              use_mini_mask=config.USE_MINI_MASK)
            else:
                image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
                    load_image_gt(dataset, config, image_id, augment=augment,
                                augmentation=augmentation,
                                use_mini_mask=config.USE_MINI_MASK)

            # Skip images that have no instances. This can happen in cases
            # where we train on a subset of classes and the image doesn't
            # have any of the classes we care about.
            if not np.any(gt_class_ids > 0):
                continue

            # RPN Targets
            rpn_match, rpn_bbox = build_rpn_targets(image.shape, anchors,
                                                    gt_class_ids, gt_boxes, config)

            # Mask R-CNN Targets
            if random_rois:
                rpn_rois = generate_random_rois(
                    image.shape, random_rois, gt_class_ids, gt_boxes)
                if detection_targets:
                    rois, mrcnn_class_ids, mrcnn_bbox, mrcnn_mask =\
                        build_detection_targets(
                            rpn_rois, gt_class_ids, gt_boxes, gt_masks, config)

            # Init batch arrays
            if b == 0:
                batch_image_meta = np.zeros(
                    (batch_size,) + image_meta.shape, dtype=image_meta.dtype)
                batch_rpn_match = np.zeros(
                    [batch_size, anchors.shape[0], 1], dtype=rpn_match.dtype)
                batch_rpn_bbox = np.zeros(
                    [batch_size, config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4], dtype=rpn_bbox.dtype)
                batch_images = np.zeros(
                    (batch_size,) + image.shape, dtype=np.float32)
                batch_gt_class_ids = np.zeros(
                    (batch_size, config.MAX_GT_INSTANCES), dtype=np.int32)
                batch_gt_boxes = np.zeros(
                    (batch_size, config.MAX_GT_INSTANCES, 4), dtype=np.int32)
                batch_gt_masks = np.zeros(
                    (batch_size, gt_masks.shape[0], gt_masks.shape[1],
                     config.MAX_GT_INSTANCES), dtype=gt_masks.dtype)
                if random_rois:
                    batch_rpn_rois = np.zeros(
                        (batch_size, rpn_rois.shape[0], 4), dtype=rpn_rois.dtype)
                    if detection_targets:
                        batch_rois = np.zeros(
                            (batch_size,) + rois.shape, dtype=rois.dtype)
                        batch_mrcnn_class_ids = np.zeros(
                            (batch_size,) + mrcnn_class_ids.shape, dtype=mrcnn_class_ids.dtype)
                        batch_mrcnn_bbox = np.zeros(
                            (batch_size,) + mrcnn_bbox.shape, dtype=mrcnn_bbox.dtype)
                        batch_mrcnn_mask = np.zeros(
                            (batch_size,) + mrcnn_mask.shape, dtype=mrcnn_mask.dtype)

            # If more instances than fits in the array, sub-sample from them.
            if gt_boxes.shape[0] > config.MAX_GT_INSTANCES:
                ids = np.random.choice(
                    np.arange(gt_boxes.shape[0]), config.MAX_GT_INSTANCES, replace=False)
                gt_class_ids = gt_class_ids[ids]
                gt_boxes = gt_boxes[ids]
                gt_masks = gt_masks[:, :, ids]

            # Add to batch
            batch_image_meta[b] = image_meta
            batch_rpn_match[b] = rpn_match[:, np.newaxis]
            batch_rpn_bbox[b] = rpn_bbox
            batch_images[b] = mold_image(image.astype(np.float32), config)
            batch_gt_class_ids[b, :gt_class_ids.shape[0]] = gt_class_ids
            batch_gt_boxes[b, :gt_boxes.shape[0]] = gt_boxes
            batch_gt_masks[b, :, :, :gt_masks.shape[-1]] = gt_masks
            if random_rois:
                batch_rpn_rois[b] = rpn_rois
                if detection_targets:
                    batch_rois[b] = rois
                    batch_mrcnn_class_ids[b] = mrcnn_class_ids
                    batch_mrcnn_bbox[b] = mrcnn_bbox
                    batch_mrcnn_mask[b] = mrcnn_mask
            b += 1

            # Batch full?
            if b >= batch_size:
                inputs = [batch_images, batch_image_meta, batch_rpn_match, batch_rpn_bbox,
                          batch_gt_class_ids, batch_gt_boxes, batch_gt_masks]
                outputs = []

                if random_rois:
                    inputs.extend([batch_rpn_rois])
                    if detection_targets:
                        inputs.extend([batch_rois])
                        # Keras requires that output and targets have the same number of dimensions
                        batch_mrcnn_class_ids = np.expand_dims(
                            batch_mrcnn_class_ids, -1)
                        outputs.extend(
                            [batch_mrcnn_class_ids, batch_mrcnn_bbox, batch_mrcnn_mask])

                yield inputs, outputs

                # start a new batch
                b = 0
        except (GeneratorExit, KeyboardInterrupt):
            raise
        except:
            # Log it and skip the image
            logging.exception("Error processing image {}".format(
                dataset.image_info[image_id]))
            error_count += 1
            if error_count > 5:
                raise

这个方法较长,主要作用是实现数据的加载。它主要是通过python的yield关键字来实现generator。

首先是第4行到第9行,这里主要是初始化一些参数。

然后是第13行和第14行,这两行代码的主要作用是获取先验框。mask rcnn及其相关的一系列模型的主要思路都是先利用一个固定算法生成指定数量的固定框(及先验框),然后再对这些固定框进行调整,并去掉重合率高的框,得到检测目标所在位置,然后根据框里的数据来预测类别或mask。

获取先验框主要是通过第14行的generate_pyramid_anchors方法,第13行的compute_backbone_shapes方法主要是需要根据图片的形状来计算主干网络输出的特征的形状。主干网络默认使用的resnet101,通过resnet101构建的特征金字塔网络(FPN)会输出5个相关的特征层,这5个特征层相对于原图的倍率是固定的,这里的compute_backbone_shapes方法便是通过用原图的形状除以对应的倍率来获取形状的。获取到特征层的输出形状后,第14行的generate_pyramid_anchors方法会对所有特征层的所有点做相同的操作,即以该该点为中心点生成三个矩形框。

然后便是第21行的while循环,这里还有一个细节,上面生成先验框的操作放在while循环外,这里是因为所有的图像的先验框都相同。先验框只和图像的形状有关,在实际加载图片的时候mask rcnn会将不同形状的图片resize到相同形状。

while循环内部主要作用就是按批次读取数据,读取完一批次数据后便通过yield输出。

首先是第24行到第29行,这里主要是用来选取需要加载的图片,通过图片id来选取。对于需要shuffle的情况,他是将存储图片id 的list(image_ids)进行shuffle。

然后是第32行到第42行,这里在确认需要加载的图片后,便调用load_image_gt方法来加载数据。然后是第50行调用build_rpn_targets方法来创建rpn网络的输出。在mask rcnn中实际是创建了两个模型:一个是rpn网络模型,一个是实例分割模型。rpn网络主要负责处理先验框。这里的build_rpn_targets方法就是来生成rpn的实际标签。

随后的代码主要是处理批次问题,每一个新的批次首先初始化一个np的数组,然后将数据都加载到数组中,最后再输出。

加载数据的重点在于load_image_gt方法,该方法内容如下:

def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
                  use_mini_mask=False):

    # Load image and mask
    image = dataset.load_image(image_id)
    mask, class_ids = dataset.load_mask(image_id)
    original_shape = image.shape
    image, window, scale, padding, crop = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        min_scale=config.IMAGE_MIN_SCALE,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE)
    mask = utils.resize_mask(mask, scale, padding, crop)

    # Random horizontal flips.
    # TODO: will be removed in a future update in favor of augmentation
    if augment:
        logging.warning("'augment' is deprecated. Use 'augmentation' instead.")
        if random.randint(0, 1):
            image = np.fliplr(image)
            mask = np.fliplr(mask)

    # Augmentation
    # This requires the imgaug lib (https://github.com/aleju/imgaug)
    if augmentation:
        import imgaug

        # Augmenters that are safe to apply to masks
        # Some, such as Affine, have settings that make them unsafe, so always
        # test your augmentation on masks
        MASK_AUGMENTERS = ["Sequential", "SomeOf", "OneOf", "Sometimes",
                           "Fliplr", "Flipud", "CropAndPad",
                           "Affine", "PiecewiseAffine"]

        def hook(images, augmenter, parents, default):
            """Determines which augmenters to apply to masks."""
            return augmenter.__class__.__name__ in MASK_AUGMENTERS

        # Store shapes before augmentation to compare
        image_shape = image.shape
        mask_shape = mask.shape
        # Make augmenters deterministic to apply similarly to images and masks
        det = augmentation.to_deterministic()
        image = det.augment_image(image)
        # Change mask to np.uint8 because imgaug doesn't support np.bool
        mask = det.augment_image(mask.astype(np.uint8),
                                 hooks=imgaug.HooksImages(activator=hook))
        # Verify that shapes didn't change
        assert image.shape == image_shape, "Augmentation shouldn't change image size"
        assert mask.shape == mask_shape, "Augmentation shouldn't change mask size"
        # Change mask back to bool
        mask = mask.astype(np.bool)

    # Note that some boxes might be all zeros if the corresponding mask got cropped out.
    # and here is to filter them out
    _idx = np.sum(mask, axis=(0, 1)) > 0
    mask = mask[:, :, _idx]
    print(_idx)
    class_ids = class_ids[_idx]
    # Bounding boxes. Note that some boxes might be all zeros
    # if the corresponding mask got cropped out.
    # bbox: [num_instances, (y1, x1, y2, x2)]
    bbox = utils.extract_bboxes(mask)

    # Active classes
    # Different datasets have different classes, so track the
    # classes supported in the dataset of this image.
    active_class_ids = np.zeros([dataset.num_classes], dtype=np.int32)
    source_class_ids = dataset.source_class_ids[dataset.image_info[image_id]["source"]]
    active_class_ids[source_class_ids] = 1

    # Resize masks to smaller size to reduce memory usage
    if use_mini_mask:
        mask = utils.minimize_mask(bbox, mask, config.MINI_MASK_SHAPE)

    # Image meta data
    image_meta = compose_image_meta(image_id, original_shape, image.shape,
                                    window, scale, active_class_ids)

    return image, image_meta, class_ids, bbox, mask

这里还是先看这个方法的主要内容,然后再细看是怎么实现的。首先是第5行,加载图片片。然后是第6行加载mask和分类id。然后是第8行和第14行resize图片和mask。然后第18行到第53行主要是数据增强的相关功能。然后是第60行获取分类的对应的class_id。然后是第64行通过mask生产目标所在的真实框。最后是第78行生成图片的元数据(image_meta)。

在上文中提到了在训练时模型需要7个输入参数,而这里提供了5个参数,剩下的两个在build_rpn_targets方法中。这里先细看这5个参数是如何产生的。

首先是image,这个参数是在第5行调用dataset的load_image方法产生的。该方法内容如下:

def load_image(self,image_id):
    """
        加载图片
    :param image_id:
    :return:
    """

    image= skimage.io.imread(self.image_info[image_id]["path"])
    return image

这个方法很简单,就是加载图片。

然后是mask 和class_ids,这是在第6行调用dataset的load_mask方法生成的。该方法内容如下:

def load_mask(self, image_id):
    """Generate instance masks for shapes of the given image ID.
    """
    global iter_num
    print("image_id",image_id)
    info = self.image_info[image_id]
    count = 1  # number of object
    img = Image.open(info['mask_path'])
    num_obj = self.get_obj_index(img)
    mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
    mask = self.draw_mask(num_obj, mask, img,image_id)
    occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
    for i in range(count - 2, -1, -1):
        mask[:, :, i] = mask[:, :, i] * occlusion

        occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
    labels = []
    labels = self.from_yaml_get_class(image_id)
    labels_form = []
    for i in range(len(labels)):
        if labels[i].find("box") != -1:
            # print "box"
            labels_form.append("box")


    class_ids = np.array([self.class_names.index(s) for s in labels_form])
    return mask, class_ids.astype(np.int32)

首先是第6行根据图片id获取到相关的信息,即在load_shapes方法中存储到image_info中的图片路径等信息。然后是第8行打开mask图片(即labelme生成的label.png)。然后调用get_obj_index方法获取图片中有多少个实例对象。在讲解labelme工具的时候提到过label.png中存储的是:0,1,2。 而这些数字是对应的yaml文件中分类的索引。这里的get_obj_index方法是直接获取最大的索引值来作为实例的个数(这里在标注是需要规定同一分类下的不同个体在标注的时候需要添加不同的后缀来区分。例如:有一个分类为box,而同一张图片中有两个box,这时这两个box需要分别标注成box1和box2)。

然后是第10行根据图像的宽高和实例个数来初始化mask。这里的mask不能直接使用label.png加载出来是图像,这是因为mask rcnn要求的mask格式和label.png生成的图片格式不一样。mask rcnn要求mask的宽高和图片一样,而通道数为实例的个数。通道的每一层在其对应的分类目标所在位置的数值设置为1,其余的设置为0。格式的转换是通过第11行的draw_mask方法来进行转换的。

然后是第17行到第27行,这里是对分类标签的处理。首先是第18行从yaml文件中加载标签信息,然后是第20行的for循环这里是在处理上文提到的多实例的问题,这里实际就是去掉标签的后缀。然后是第26行从将分类标签转换为分类id。

以上便是加载mask和class_ids的过程,接下来便是生成bbox的方式。其代码如下:

def extract_bboxes(mask):
    """Compute bounding boxes from masks.
    mask: [height, width, num_instances]. Mask pixels are either 1 or 0.

    Returns: bbox array [num_instances, (y1, x1, y2, x2)].
    """
    boxes = np.zeros([mask.shape[-1], 4], dtype=np.int32)
    for i in range(mask.shape[-1]):
        m = mask[:, :, i]
        # Bounding box.
        a = np.any(m, axis=0)
        b = np.where(np.any(m, axis=0))

        horizontal_indicies = np.where(np.any(m, axis=0))[0]
        vertical_indicies = np.where(np.any(m, axis=1))[0]
        if horizontal_indicies.shape[0]:
            x1, x2 = horizontal_indicies[[0, -1]]
            y1, y2 = vertical_indicies[[0, -1]]
            # x2 and y2 should not be part of the box. Increment by 1.
            x2 += 1
            y2 += 1
        else:
            # No mask for this instance. Might happen due to
            # resizing or cropping. Set bbox to zeros
            x1, x2, y1, y2 = 0, 0, 0, 0
        boxes[i] = np.array([y1, x1, y2, x2])
    return boxes.astype(np.int32)

上述便是通过mask生成bbox(真实框)的代码,其主要思路如下:mask的每一层代表的一个实例,每个实例都应有一个对应bbox。而mask的每一层都是由0和1组成的,有1的代表有目标。所有取1最边缘的位置作为bbox的坐标便可。如下图所示:
在这里插入图片描述
bbox取值示意图

如上图所示,可以获取到目标上下左右的极限位置,然后用这些位置来作为bbox 的坐标。

最后是 image_meta,他是在第78行通过调用compose_image_meta方法生成的,其内容如下:

def compose_image_meta(image_id, original_image_shape, image_shape,
                       window, scale, active_class_ids):
    """Takes attributes of an image and puts them in one 1D array.

    image_id: An int ID of the image. Useful for debugging.
    original_image_shape: [H, W, C] before resizing or padding.
    image_shape: [H, W, C] after resizing and padding
    window: (y1, x1, y2, x2) in pixels. The area of the image where the real
            image is (excluding the padding)
    scale: The scaling factor applied to the original image (float32)
    active_class_ids: List of class_ids available in the dataset from which
        the image came. Useful if training on images from multiple datasets
        where not all classes are present in all datasets.
    """
    meta = np.array(
        [image_id] +                  # size=1
        list(original_image_shape) +  # size=3
        list(image_shape) +           # size=3
        list(window) +                # size=4 (y1, x1, y2, x2) in image cooredinates
        [scale] +                     # size=1
        list(active_class_ids)        # size=num_classes
    )
    return meta

这个方法很简单,如同上文所述,只是将图片的元数据整合生成一个numpy格式的数据而已。

自此,load_image_gt方法创建的5个参数都解析完了,接下来继续解析build_rpn_targets方法生成剩下两个参数的方式。该方法的内容如下:

def build_rpn_targets(image_shape, anchors, gt_class_ids, gt_boxes, config):
    """Given the anchors and GT boxes, compute overlaps and identify positive
    anchors and deltas to refine them to match their corresponding GT boxes.

    anchors: [num_anchors, (y1, x1, y2, x2)]
    gt_class_ids: [num_gt_boxes] Integer class IDs.
    gt_boxes: [num_gt_boxes, (y1, x1, y2, x2)]

    Returns:
    rpn_match: [N] (int32) matches between anchors and GT boxes.
               1 = positive anchor, -1 = negative anchor, 0 = neutral
    rpn_bbox: [N, (dy, dx, log(dh), log(dw))] Anchor bbox deltas.
    """
    # RPN Match: 1 = positive anchor, -1 = negative anchor, 0 = neutral
    rpn_match = np.zeros([anchors.shape[0]], dtype=np.int32)
    # RPN bounding boxes: [max anchors per image, (dy, dx, log(dh), log(dw))]
    rpn_bbox = np.zeros((config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4))

    # Handle COCO crowds
    # A crowd box in COCO is a bounding box around several instances. Exclude
    # them from training. A crowd box is given a negative class ID.
    crowd_ix = np.where(gt_class_ids < 0)[0]
    if crowd_ix.shape[0] > 0:
        # Filter out crowds from ground truth class IDs and boxes
        non_crowd_ix = np.where(gt_class_ids > 0)[0]
        crowd_boxes = gt_boxes[crowd_ix]
        gt_class_ids = gt_class_ids[non_crowd_ix]
        gt_boxes = gt_boxes[non_crowd_ix]
        # Compute overlaps with crowd boxes [anchors, crowds]
        crowd_overlaps = utils.compute_overlaps(anchors, crowd_boxes)
        crowd_iou_max = np.amax(crowd_overlaps, axis=1)
        no_crowd_bool = (crowd_iou_max < 0.001)
    else:
        # All anchors don't intersect a crowd
        no_crowd_bool = np.ones([anchors.shape[0]], dtype=bool)

    # Compute overlaps [num_anchors, num_gt_boxes]
    overlaps = utils.compute_overlaps(anchors, gt_boxes)

    # Match anchors to GT Boxes
    # If an anchor overlaps a GT box with IoU >= 0.7 then it's positive.
    # If an anchor overlaps a GT box with IoU < 0.3 then it's negative.
    # Neutral anchors are those that don't match the conditions above,
    # and they don't influence the loss function.
    # However, don't keep any GT box unmatched (rare, but happens). Instead,
    # match it to the closest anchor (even if its max IoU is < 0.3).
    #
    # 1. Set negative anchors first. They get overwritten below if a GT box is
    # matched to them. Skip boxes in crowd areas.
    anchor_iou_argmax = np.argmax(overlaps, axis=1)
    anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
    rpn_match[(anchor_iou_max < 0.3) & (no_crowd_bool)] = -1
    # 2. Set an anchor for each GT box (regardless of IoU value).
    # If multiple anchors have the same IoU match all of them
    gt_iou_argmax = np.argwhere(overlaps == np.max(overlaps, axis=0))[:,0]
    rpn_match[gt_iou_argmax] = 1
    # 3. Set anchors with high overlap as positive.
    rpn_match[anchor_iou_max >= 0.7] = 1

    # Subsample to balance positive and negative anchors
    # Don't let positives be more than half the anchors
    ids = np.where(rpn_match == 1)[0]
    extra = len(ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE // 2)
    if extra > 0:
        # Reset the extra ones to neutral
        ids = np.random.choice(ids, extra, replace=False)
        rpn_match[ids] = 0
    # Same for negative proposals
    ids = np.where(rpn_match == -1)[0]
    extra = len(ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE -
                        np.sum(rpn_match == 1))
    if extra > 0:
        # Rest the extra ones to neutral
        ids = np.random.choice(ids, extra, replace=False)
        rpn_match[ids] = 0

    # For positive anchors, compute shift and scale needed to transform them
    # to match the corresponding GT boxes.
    ids = np.where(rpn_match == 1)[0]
    ix = 0  # index into rpn_bbox
    # TODO: use box_refinement() rather than duplicating the code here
    for i, a in zip(ids, anchors[ids]):
        # Closest gt box (it might have IoU < 0.7)
        gt = gt_boxes[anchor_iou_argmax[i]]

        # Convert coordinates to center plus width/height.
        # GT Box
        gt_h = gt[2] - gt[0]
        gt_w = gt[3] - gt[1]
        gt_center_y = gt[0] + 0.5 * gt_h
        gt_center_x = gt[1] + 0.5 * gt_w
        # Anchor
        a_h = a[2] - a[0]
        a_w = a[3] - a[1]
        a_center_y = a[0] + 0.5 * a_h
        a_center_x = a[1] + 0.5 * a_w

        # Compute the bbox refinement that the RPN should predict.
        rpn_bbox[ix] = [
            (gt_center_y - a_center_y) / a_h,
            (gt_center_x - a_center_x) / a_w,
            np.log(gt_h / a_h),
            np.log(gt_w / a_w),
        ]
        # Normalize
        rpn_bbox[ix] /= config.RPN_BBOX_STD_DEV
        ix += 1

    return rpn_match, rpn_bbox

这个方法返回的两个参数rpn_match和rpn_bbox是rpn网络模型的标签。在具体分析这个方法之前,先简单的介绍一下rpn模型。

在之前提到过mask rcnn 是基于一种区域搜索的思路来实现的。他会分成两个阶段:首先在图片中搜索可能出现目标的区域,然后再对搜索出来的可能性高的区域进行分类等任务。这里搜索目标可能出现的区域的工作便是由先验框和rpn网络来完成的。在之前提到过先验框是通过一种固定的算法生成的固定位置的框,它主要和图片的形状和模型的主干网络有关。

而rpn网络的作用有两个:第一个是判断先验框包含目标的概率有多少,第二个是如果包含目标那么这个框应该如何调整才能包含整个目标(这是因为先验框的位置是固定的,它很有可能只包含了目标的一部分,所以需要一个调整参数来调整先验框的位置)。

为了完成上述目标,rpn的输出有两个:一个用来表示对应先验框是否包含目标的概率值(与上文中rpn_match对应),一个用来表示先验框应该如何调整(与上文的rpn_bbox对应)。

了解了rpn网络的作用后,再来分析这个方法便很简单了。先验框在上文分析的其生成方式,先验框是否包含目标,主要是同通过计算先验框与目标所在的真实框的重合度(iou)来确定。真实框即上文分析了bbox参数的值。这里rpn_match的取值有三个:-1,0,1。 1代表正样本,-1代表负样本,0代表忽略的样本。其中先验框与真实框的iou的最大值和大于0.7的值都会被设置为正标签,小于0.3的值会被设置为负标签,其余的都为忽略值。同时为了平衡正负样本,在实际使用的时候只取了部分样本。

对于调整参数,这里只对正样本计算了其调整参数。其主要是通过调整框的中心点位置与框的宽高来调整框的位置。具体方式在代码的82行到107行。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值