【MaskRCNN】Source Code Series 1: train Data Processing, Part 1


Contents

COCO data

Data loading

Line-by-line code walkthrough

data_generator (the most important part)


COCO data

First, I work with the COCO dataset. A couple of notes on the COCO folder layout and paths:

If your folder contains no annotations directory but only annotations_trainval2014, simply rename annotations_trainval2014 to annotations.

If the annotations folder only contains instances_val2014.json and is missing instances_minival2014.json and instances_valminusminival2014.json, just make two copies of instances_val2014.json and name them instances_minival2014.json and instances_valminusminival2014.json, as sketched below.
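
A minimal sketch of that copy step, assuming the dataset root E:\data\coco2014 used in this post (the layout in the comments is what the load_coco code further down expects):

import shutil
from pathlib import Path

# Expected layout (inferred from load_coco):
#   E:\data\coco2014\annotations\instances_train2014.json
#   E:\data\coco2014\annotations\instances_val2014.json
#   E:\data\coco2014\train2014\...
#   E:\data\coco2014\val2014\...
ann_dir = Path(r"E:\data\coco2014") / "annotations"
src = ann_dir / "instances_val2014.json"

# Duplicate the val annotations under the two names the training script expects.
for name in ("instances_minival2014.json", "instances_valminusminival2014.json"):
    dst = ann_dir / name
    if not dst.exists():
        shutil.copy(src, dst)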

If you want to know what the JSON file contains, see https://blog.csdn.net/u013066730/article/details/100578941.

Data loading

The data-loading code appears in the main function of maskrcnn/samples/coco/coco.py:

    if args.command == "train":
        # Training dataset. Use the training set and 35K from the
        # validation set, as in the Mask RCNN paper.
        dataset_train = CocoDataset()
        dataset_train.load_coco(args.dataset, "train", year=args.year, auto_download=args.download)
        if args.year in '2014':
            dataset_train.load_coco(args.dataset, "valminusminival", year=args.year, auto_download=args.download)
        dataset_train.prepare()

        # Validation dataset
        dataset_val = CocoDataset()
        val_type = "val" if args.year in '2017' else "minival"
        dataset_val.load_coco(args.dataset, val_type, year=args.year, auto_download=args.download)
        dataset_val.prepare()

        # Image Augmentation
        # Right/Left flip 50% of the time
        augmentation = imgaug.augmenters.Fliplr(0.5)

        # *** This training schedule is an example. Update to your needs ***

        # Training - Stage 1
        print("Training network heads")
        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE,
                    epochs=40,
                    layers='heads',
                    augmentation=augmentation)

Line-by-line code walkthrough

dataset_train = CocoDataset()

This instantiates CocoDataset, which inherits from utils.Dataset. The utils.Dataset class contains the most basic handling of the data; its functions will be introduced in detail later when they are actually used.
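
For reference, the constructor of utils.Dataset (as in the matterport repository, shown slightly abridged) just sets up the containers that load_coco and prepare will fill; note that the background class 'BG' is registered up front:

class Dataset(object):
    def __init__(self, class_map=None):
        self._image_ids = []
        self.image_info = []
        # Background is always the first class
        self.class_info = [{"source": "", "id": 0, "name": "BG"}]
        self.source_class_ids = {}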


dataset_train.load_coco(args.dataset, "train", year=args.year, auto_download=args.download)

The dataset_train object calls the load_coco function with the following arguments:

dataset_dir = args.dataset = "E:\data\coco2014"

subset = "train"

year = args.year = 2014

class_ids = None

class_map = None

return_coco = False

auto_download = args.download = False

    def load_coco(self, dataset_dir, subset, year=DEFAULT_DATASET_YEAR, class_ids=None,
                  class_map=None, return_coco=False, auto_download=False):

Now let's go through the code line by line:

        if auto_download is True:
            self.auto_download(dataset_dir, subset, year)

Since auto_download is False, nothing needs to be downloaded and this branch is skipped.

        coco = COCO("{}/annotations/instances_{}{}.json".format(dataset_dir, subset, year))
        if subset == "minival" or subset == "valminusminival":
            subset = "val"
        image_dir = "{}/{}{}".format(dataset_dir, subset, year)

coco = COCO("E:/data/coco2014/annotations/instances_train2014.json")

image_dir = "E:/data/coco2014/train2014"

        if not class_ids:
            # All classes
            class_ids = sorted(coco.getCatIds())

The resulting class_ids is a list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]

These are the COCO categories we use: exactly 80 classes (background is not counted here). Note that the IDs are not contiguous (12, 26, etc. are missing), which is why prepare() later maps them to contiguous internal class IDs.

        # All images or a subset?
        if class_ids:
            image_ids = []
            for id in class_ids:
                image_ids.extend(list(coco.getImgIds(catIds=[id])))
            # Remove duplicates
            image_ids = list(set(image_ids))
        else:
            # All images
            image_ids = list(coco.imgs.keys())

Since class_ids is non-empty, the first branch is taken. For every class it collects the IDs of the images that contain that class into a single list, then removes duplicates (an image containing objects of several classes would otherwise appear more than once), as illustrated below.
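
A toy illustration of why the set() call is needed (the IDs below are made up, not real COCO IDs):

# Image 103 contains both a person (category 1) and a dog (category 18),
# so coco.getImgIds would return it for both categories.
ids_for_person = [101, 102, 103]
ids_for_dog = [103, 104]

image_ids = []
for ids in (ids_for_person, ids_for_dog):
    image_ids.extend(ids)

print(image_ids)             # [101, 102, 103, 103, 104]
print(list(set(image_ids)))  # duplicates removed (order not guaranteed)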

        for i in class_ids:
            self.add_class("coco", i, coco.loadCats(i)[0]["name"])

The snippet above calls:

    def add_class(self, source, class_id, class_name):
        assert "." not in source, "Source name cannot contain a dot"
        # Does the class exist already?
        for info in self.class_info:
            if info['source'] == source and info["id"] == class_id:
                # source.class_id combination already available, skip
                return
        # Add the class
        self.class_info.append({
            "source": source,
            "id": class_id,
            "name": class_name,
        })

This builds a dictionary for each class and appends it to self.class_info: source = the data source ("coco"), id = the category ID, and name = the class name corresponding to that ID.
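
After this loop, self.class_info looks roughly like the following (the 'BG' entry with id 0 was already added by the Dataset constructor shown earlier):

# Contents of self.class_info after the loop (abridged):
[
    {"source": "",     "id": 0,  "name": "BG"},      # added in Dataset.__init__
    {"source": "coco", "id": 1,  "name": "person"},
    {"source": "coco", "id": 2,  "name": "bicycle"},
    # ... remaining COCO classes ...
    {"source": "coco", "id": 90, "name": "toothbrush"},
]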

        # Add images
        for i in image_ids:
            self.add_image(
                "coco", image_id=i,
                path=os.path.join(image_dir, coco.imgs[i]['file_name']),
                width=coco.imgs[i]["width"],
                height=coco.imgs[i]["height"],
                annotations=coco.loadAnns(coco.getAnnIds(
                    imgIds=[i], catIds=class_ids, iscrowd=None)))

The code above calls the add_image function:

    def add_image(self, source, image_id, path, **kwargs):
        image_info = {
            "id": image_id,
            "source": source,
            "path": path,
        }
        image_info.update(kwargs)
        self.image_info.append(image_info)

This small function packs the information about one image into a dictionary and appends it to self.image_info. This is the key part of data loading: id = the image ID (which must be unique), source = the dataset name "coco", path = the full path to the image file, width and height = the image size, and annotations = the annotation entries read from the JSON file (see https://blog.csdn.net/u013066730/article/details/100578941 for details). An example entry is sketched below.
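
One entry of self.image_info then looks roughly like this (the numbers and annotation fields here are illustrative, not taken from a real record):

# Illustrative self.image_info entry (values are made up):
{
    "id": 391895,        # COCO image id; must be unique
    "source": "coco",
    "path": "E:/data/coco2014/train2014/COCO_train2014_000000391895.jpg",
    "width": 640,
    "height": 360,
    "annotations": [     # one dict per object instance, straight from the JSON
        {"segmentation": [...], "area": 12190.4, "iscrowd": 0,
         "image_id": 391895, "bbox": [359.1, 146.1, 112.4, 213.5],
         "category_id": 4, "id": 151091},
    ],
}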


        if args.year in '2014':
            dataset_train.load_coco(args.dataset, "valminusminival", year=args.year, auto_download=args.download)

This call in the main function goes through the same load_coco routine as above, only with subset="valminusminival", so those images are appended to the same dataset_train object.

        dataset_train.prepare()

This calls the prepare function of the parent class:

    def prepare(self, class_map=None):
        """Prepares the Dataset class for use.

        TODO: class map is not supported yet. When done, it should handle mapping
              classes from different datasets to the same class ID.
        """

        def clean_name(name):
            """Returns a shorter version of object names for cleaner display."""
            return ",".join(name.split(",")[:1])

        # Build (or rebuild) everything else from the info dicts.
        self.num_classes = len(self.class_info)
        self.class_ids = np.arange(self.num_classes)
        self.class_names = [clean_name(c["name"]) for c in self.class_info]
        self.num_images = len(self.image_info)
        self._image_ids = np.arange(self.num_images)

        # Mapping from source class and image IDs to internal IDs
        self.class_from_source_map = {"{}.{}".format(info['source'], info['id']): id
                                      for info, id in zip(self.class_info, self.class_ids)}
        self.image_from_source_map = {"{}.{}".format(info['source'], info['id']): id
                                      for info, id in zip(self.image_info, self.image_ids)}

        # Map sources to class_ids they support
        self.sources = list(set([i['source'] for i in self.class_info]))
        self.source_class_ids = {}
        # Loop over datasets
        for source in self.sources:
            self.source_class_ids[source] = []
            # Find classes that belong to this dataset
            for i, info in enumerate(self.class_info):
                # Include BG class in all datasets
                if i == 0 or source == info['source']:
                    self.source_class_ids[source].append(i)

Input argument: class_map=None

Resulting values:

self.num_classes: 81, i.e. 81 classes in total (the 80 COCO classes plus background).

self.class_ids: [0, 1, 2, 3, ..., 80];

self.class_names: ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'], the concrete class names.

self.num_images: 122218, the total number of images;

self._image_ids: [0, 1, 2, 3, 4, ..., 122217];

self.class_from_source_map:{'coco.61': 56, 'coco.86': 76, 'coco.16': 15, 'coco.84': 74, 'coco.3': 3, 'coco.88': 78, 'coco.77': 68, 'coco.81': 72, 'coco.73': 64, 'coco.11': 11, 'coco.38': 34, 'coco.57': 52, 'coco.54': 49, 'coco.25': 24, 'coco.80': 71, 'coco.51': 46, 'coco.56': 51, 'coco.13': 12, 'coco.15': 14, 'coco.14': 13, 'coco.67': 61, 'coco.49': 44, 'coco.46': 41, 'coco.79': 70, 'coco.20': 19, 'coco.17': 16, 'coco.32': 28, 'coco.52': 47, 'coco.48': 43, 'coco.4': 4, 'coco.65': 60, 'coco.34': 30, 'coco.27': 25, 'coco.22': 21, 'coco.50': 45, 'coco.75': 66, 'coco.82': 73, 'coco.47': 42, 'coco.70': 62, 'coco.43': 39, 'coco.31': 27, 'coco.74': 65, 'coco.19': 18, 'coco.21': 20, 'coco.72': 63, 'coco.33': 29, 'coco.2': 2, 'coco.9': 9, 'coco.59': 54, 'coco.63': 58, 'coco.1': 1, 'coco.10': 10, 'coco.62': 57, 'coco.53': 48, 'coco.6': 6, 'coco.37': 33, 'coco.36': 32, 'coco.90': 80, 'coco.89': 79, '.0': 0, 'coco.87': 77, 'coco.60': 55, 'coco.76': 67, 'coco.35': 31, 'coco.85': 75, 'coco.18': 17, 'coco.44': 40, 'coco.8': 8, 'coco.28': 26, 'coco.23': 22, 'coco.24': 23, 'coco.7': 7, 'coco.39': 35, 'coco.5': 5, 'coco.41': 37, 'coco.64': 59, 'coco.78': 69, 'coco.40': 36, 'coco.55': 50, 'coco.42': 38, 'coco.58': 53}

self.image_from_source_map: <class 'dict'>: <Too big to print. Len: 122218>; its entries look like {'coco.103720': 66675, 'coco.10039': 85616, ...}

self.sources: ['', 'coco']

self.source_class_ids: looping over self.sources above, the '' source keeps only the BG class (internal ID 0), while the 'coco' source keeps all 81 internal class IDs:

{'': [0], 'coco': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]}
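
These maps are what later code uses to convert between COCO category IDs and the contiguous internal class IDs the network works with. A small usage sketch (the 'coco.90' entry comes from the dictionary printed above):

# "source.id" string -> contiguous internal class id
internal_id = dataset_train.class_from_source_map["coco.90"]    # 80
print(dataset_train.class_names[internal_id])                   # 'toothbrush'

# utils.Dataset wraps the same lookup in a helper:
internal_id = dataset_train.map_source_class_id("coco.90")      # 80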


augmentation = imgaug.augmenters.Fliplr(0.5)

This sets up data augmentation: flip images left/right with 50% probability. I won't introduce the imgaug library here; see https://github.com/aleju/imgaug. A slightly richer pipeline is sketched below.
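
A richer pipeline, mirroring the example given in the train docstring further down (the exact augmenters are up to you):

import imgaug.augmenters as iaa

# Apply the whole group 50% of the time; inside it, flip left/right half the
# time and blur with a random sigma in [0, 5].
augmentation = iaa.Sometimes(0.5, [
    iaa.Fliplr(0.5),
    iaa.GaussianBlur(sigma=(0.0, 5.0)),
])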


        model.train(dataset_train, dataset_val,
                    learning_rate=config.LEARNING_RATE,
                    epochs=40,
                    layers='heads',
                    augmentation=augmentation)

This is where the main function enters the train interface and passes in the dataset_train dataset object.

I won't cover the training procedure itself here, only the data-handling part.

This model.train call goes to the train function of the MaskRCNN class in mrcnn/model.py.

    def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
              augmentation=None, custom_callbacks=None, no_augmentation_sources=None):
        """Train the model.
        train_dataset, val_dataset: Training and validation Dataset objects.
        learning_rate: The learning rate to train with
        epochs: Number of training epochs. Note that previous training epochs
                are considered to be done already, so this actually determines
                the epochs to train in total rather than in this particular
                call.
        layers: Allows selecting which layers to train. It can be:
            - A regular expression to match layer names to train
            - One of these predefined values:
              heads: The RPN, classifier and mask heads of the network
              all: All the layers
              3+: Train Resnet stage 3 and up
              4+: Train Resnet stage 4 and up
              5+: Train Resnet stage 5 and up
        augmentation: Optional. An imgaug (https://github.com/aleju/imgaug)
            augmentation. For example, passing imgaug.augmenters.Fliplr(0.5)
            flips images right/left 50% of the time. You can pass complex
            augmentations as well. This augmentation applies 50% of the
            time, and when it does it flips images right/left half the time
            and adds a Gaussian blur with a random sigma in range 0 to 5.

                augmentation = imgaug.augmenters.Sometimes(0.5, [
                    imgaug.augmenters.Fliplr(0.5),
                    imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
                ])
	    custom_callbacks: Optional. Add custom callbacks to be called
	        with the keras fit_generator method. Must be list of type keras.callbacks.
        no_augmentation_sources: Optional. List of sources to exclude for
            augmentation. A source is string that identifies a dataset and is
            defined in the Dataset class.
        """
        assert self.mode == "training", "Create model in training mode."

        # Pre-defined layer regular expressions
        layer_regex = {
            # all layers but the backbone
            "heads": r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
            # From a specific Resnet stage and up
            "3+": r"(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
            "4+": r"(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
            "5+": r"(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
            # All layers
            "all": ".*",
        }
        if layers in layer_regex.keys():
            layers = layer_regex[layers]

        # Data generators
        train_generator = data_generator(train_dataset, self.config, shuffle=True,
                                         augmentation=augmentation,
                                         batch_size=self.config.BATCH_SIZE,
                                         no_augmentation_sources=no_augmentation_sources)
        val_generator = data_generator(val_dataset, self.config, shuffle=True,
                                       batch_size=self.config.BATCH_SIZE)

        # Create log_dir if it does not exist
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # Callbacks
        callbacks = [
            keras.callbacks.TensorBoard(log_dir=self.log_dir,
                                        histogram_freq=0, write_graph=True, write_images=False),
            keras.callbacks.ModelCheckpoint(self.checkpoint_path,
                                            verbose=0, save_weights_only=True),
        ]

        # Add custom callbacks to the list
        if custom_callbacks:
            callbacks += custom_callbacks

        # Train
        log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
        log("Checkpoint Path: {}".format(self.checkpoint_path))
        self.set_trainable(layers)
        self.compile(learning_rate, self.config.LEARNING_MOMENTUM)

        # Work-around for Windows: Keras fails on Windows when using
        # multiprocessing workers. See discussion here:
        # https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009
        if os.name == 'nt':
            workers = 0
        else:
            workers = multiprocessing.cpu_count()

        self.keras_model.fit_generator(
            train_generator,
            initial_epoch=self.epoch,
            epochs=epochs,
            steps_per_epoch=self.config.STEPS_PER_EPOCH,
            callbacks=callbacks,
            validation_data=val_generator,
            validation_steps=self.config.VALIDATION_STEPS,
            max_queue_size=100,
            workers=workers,
            use_multiprocessing=True,
        )
        self.epoch = max(self.epoch, epochs)
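
For intuition on how layers='heads' takes effect: the string is first mapped to the regular expression in layer_regex above, and set_trainable then matches that regex against every layer name with re.fullmatch. A minimal sketch of that matching (the layer names are just examples):

import re

heads_regex = r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)"

# Example layer names, for illustration only.
for name in ["rpn_class_raw", "fpn_p5", "mrcnn_mask", "res4a_branch2a"]:
    print(name, "->", bool(re.fullmatch(heads_regex, name)))
# res4a_branch2a does not match, so backbone layers stay frozen for 'heads'.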

data_generator (the most important part)

See 【MaskRCNN】Source Code Series 1: Data Processing, Part 2.
