目录
coco数据
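首先处理的是coco数据集。结合后文load_coco中的路径拼接方式(标注文件为`{dataset_dir}/annotations/instances_{subset}{year}.json`,图片目录为`{dataset_dir}/{subset}{year}`,其中minival和valminusminival的图片都在val目录中),以2014年的数据为例,coco数据集的文件夹结构应为:
coco2014/
  annotations/
    instances_train2014.json
    instances_minival2014.json
    instances_valminusminival2014.json
  train2014/
  val2014/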
如果你的文件夹中没有annotations,只有annotations_trainval2014,那么只需将annotations_trainval2014重命名为annotations即可。
如果这个annotations文件夹中只有instances_val2014.json,没有instances_minival2014.json和instances_valminusminival2014.json,可以将instances_val2014.json复制2份,分别命名为instances_minival2014.json和instances_valminusminival2014.json。注意:这样做之后,训练时加载的valminusminival和验证时加载的minival实际上是同一批val2014图片,验证图片也参与了训练,验证指标会偏高,这种做法只适合跑通流程。
如果想了解json文件中是什么,请查看https://blog.csdn.net/u013066730/article/details/100578941
数据读取
在maskrcnn/samples/coco/coco.py中的main函数中出现了读数据的代码:
if args.command == "train":
    # Training dataset. Use the training set and 35K from the
    # validation set, as in the Mask RCNN paper.
    dataset_train = CocoDataset()
    dataset_train.load_coco(args.dataset, "train", year=args.year, auto_download=args.download)
    # Exact comparison: `args.year in '2014'` is a substring test and would
    # also accept values such as "14" or "201".
    if args.year == '2014':
        dataset_train.load_coco(args.dataset, "valminusminival", year=args.year, auto_download=args.download)
    dataset_train.prepare()

    # Validation dataset
    dataset_val = CocoDataset()
    val_type = "val" if args.year == '2017' else "minival"
    dataset_val.load_coco(args.dataset, val_type, year=args.year, auto_download=args.download)
    dataset_val.prepare()

    # Image Augmentation
    # Right/Left flip 50% of the time
    augmentation = imgaug.augmenters.Fliplr(0.5)

    # *** This training schedule is an example. Update to your needs ***

    # Training - Stage 1
    print("Training network heads")
    model.train(dataset_train, dataset_val,
                learning_rate=config.LEARNING_RATE,
                epochs=40,
                layers='heads',
                augmentation=augmentation)
代码逐句解读
# Create the training CocoDataset; its classes and images are registered afterwards by load_coco().
dataset_train = CocoDataset()
CocoDataset的实例化,CocoDataset继承于utils.Dataset,在utils.Dataset类中,包含了对数据最基本的处理。等后面用到类中的函数时再具体介绍。
# Register the COCO "train" split (class list, image paths and annotations) in dataset_train.
dataset_train.load_coco(args.dataset, "train", year=args.year, auto_download=args.download)
dataset_train对象调用了load_coco函数,输入的参数是:
# Argument values used in this walkthrough. args.year is a string (it is
# compared against '2014' as a str), and forward slashes avoid the invalid
# escape sequences "\d" and "\c" of a backslash path in a normal str literal.
dataset_dir = args.dataset = "E:/data/coco2014"
subset = "train"
year = args.year = "2014"
class_ids = None
class_map = None
return_coco = False
auto_download = args.download = False
def load_coco(self, dataset_dir, subset, year=DEFAULT_DATASET_YEAR, class_ids=None,
class_map=None, return_coco=False, auto_download=False):
下面逐句解读代码:
# Call self.auto_download(...) (presumably fetches the dataset files) only when auto_download is exactly True.
if auto_download is True: self.auto_download(dataset_dir, subset, year)
由于auto_download是False,不需要下载,直接跳过。
# Annotation file of the requested split, e.g. annotations/instances_train2014.json
coco = COCO("{}/annotations/instances_{}{}.json".format(dataset_dir, subset, year))
# minival / valminusminival are annotation splits of val: their images are
# read from the val folder (e.g. val2014).
if subset == "minival" or subset == "valminusminival":
    subset = "val"
image_dir = "{}/{}{}".format(dataset_dir, subset, year)
# With dataset_dir="E:/data/coco2014", subset="train", year="2014" the code above evaluates to:
coco = COCO("E:/data/coco2014/annotations/instances_train2014.json")
image_dir = "E:/data/coco2014/train2014"
if not class_ids:
    # All classes
    class_ids = sorted(coco.getCatIds())
得到的class_ids为<class 'list'>: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
这表明我们使用了coco当中的这些类别,正好80类(不含背景)。注意coco的类别id并不连续(例如缺少12、26、29、30等),后面的prepare函数会把它们映射成连续的内部id(0为背景)。
# All images or a subset?
if class_ids:
    image_ids = []
    for cat_id in class_ids:
        image_ids.extend(coco.getImgIds(catIds=[cat_id]))
    # Remove duplicates (one image can contain several categories)
    image_ids = list(set(image_ids))
else:
    # All images
    image_ids = list(coco.imgs.keys())
由于class_ids有值,所以进入了if分支:对每个类别id调用coco.getImgIds(catIds=[id]),取出包含该类别的所有图片id并汇总到image_ids中,最后用set去重(因为一张图片可能包含多个类别)。
# Register every selected COCO category under the "coco" source, using its name from the annotation file.
for cat_id in class_ids:
    cat_name = coco.loadCats(cat_id)[0]["name"]
    self.add_class("coco", cat_id, cat_name)
上面这段代码调用了add_class函数:
def add_class(self, source, class_id, class_name):
    """Register a class in ``self.class_info`` unless it is already present.

    Args:
        source: Name of the dataset the class comes from (e.g. "coco").
            Must not contain a dot, because "source.class_id" strings are
            later used as lookup keys (see prepare()).
        class_id: Id of the class within that source.
        class_name: Human-readable class name.

    Raises:
        AssertionError: if ``source`` contains a dot.
    """
    assert "." not in source, "Source name cannot contain a dot"
    # Does the class exist already? The (source, class_id) pair must be unique.
    if any(info["source"] == source and info["id"] == class_id
           for info in self.class_info):
        # source.class_id combination already available, skip
        return
    # Add the class
    self.class_info.append({
        "source": source,
        "id": class_id,
        "name": class_name,
    })
上面这一段为每个类别构造一个字典并追加到self.class_info列表中。字典中保存了source=数据来源coco,id=coco中的类别id,name=该类别id对应的类别名称;如果相同source和id的类别已经存在,则直接返回,不会重复添加。
# Add images
for i in image_ids:
    img = coco.imgs[i]
    self.add_image(
        "coco", image_id=i,
        path=os.path.join(image_dir, img['file_name']),
        width=img["width"],
        height=img["height"],
        # Annotations of this image restricted to the selected categories;
        # iscrowd=None applies no filter on the iscrowd flag (pycocotools getAnnIds).
        annotations=coco.loadAnns(coco.getAnnIds(
            imgIds=[i], catIds=class_ids, iscrowd=None)))
上面这段代码调用了add_image函数
def add_image(self, source, image_id, path, **kwargs):
    """Append an image record to ``self.image_info``.

    Args:
        source: Name of the dataset the image comes from (e.g. "coco").
        image_id: Id of the image within that source.
        path: Path of the image file.
        **kwargs: Extra fields merged into the record after "id", "source"
            and "path" (for COCO: width, height and annotations).
    """
    image_info = {
        "id": image_id,
        "source": source,
        "path": path,
    }
    image_info.update(kwargs)
    self.image_info.append(image_info)
这小段代码将图像的各种信息组成一个字典并追加到self.image_info列表中,这也是读取数据的关键部分:id=图像id(必须是唯一的),source=数据集名称coco,path=图像的完整路径,width=图像的宽,height=图像的高,annotations=该图像在json文件中对应的标注列表(先用coco.getAnnIds按图像id和所选类别筛选出标注id,再用coco.loadAnns读出标注内容;json格式具体可以参考https://blog.csdn.net/u013066730/article/details/100578941)。
# Exact comparison: `args.year in '2014'` is a substring test and would also accept e.g. "14".
if args.year == '2014':
    dataset_train.load_coco(args.dataset, "valminusminival", year=args.year, auto_download=args.download)
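这段主函数代码同样调用了上面解读过的load_coco,只是subset换成了"valminusminival":标注从instances_valminusminival2014.json读取,图片目录映射为val2014,这些图片会追加到同一个dataset_train中。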
dataset_train.prepare()
调用了父类utils.Dataset中的prepare函数:
def prepare(self, class_map=None):
    """Prepares the Dataset class for use.

    Builds contiguous internal class/image ids and the lookup tables that
    map "source.id" strings back to those internal ids.

    Args:
        class_map: Currently unused. TODO: class map is not supported yet.
            When done, it should handle mapping classes from different
            datasets to the same class ID.

    Sets num_classes, class_ids, class_names, num_images, _image_ids,
    class_from_source_map, image_from_source_map, sources and
    source_class_ids on the instance.
    """

    def clean_name(name):
        """Returns a shorter version of object names for cleaner display."""
        # Keep only the text before the first comma.
        return ",".join(name.split(",")[:1])

    # Build (or rebuild) everything else from the info dicts.
    self.num_classes = len(self.class_info)
    self.class_ids = np.arange(self.num_classes)
    self.class_names = [clean_name(c["name"]) for c in self.class_info]
    self.num_images = len(self.image_info)
    self._image_ids = np.arange(self.num_images)

    # Mapping from source class and image IDs to internal IDs
    self.class_from_source_map = {
        "{}.{}".format(info['source'], info['id']): internal_id
        for info, internal_id in zip(self.class_info, self.class_ids)}
    # Use the freshly built _image_ids array directly.
    self.image_from_source_map = {
        "{}.{}".format(info['source'], info['id']): internal_id
        for info, internal_id in zip(self.image_info, self._image_ids)}

    # Map sources to class_ids they support. Sorted so the order does not
    # depend on string hash randomization.
    self.sources = sorted({info['source'] for info in self.class_info})
    self.source_class_ids = {}
    # Loop over datasets
    for source in self.sources:
        # Classes that belong to this dataset; index 0 (the BG class) is
        # included in all datasets.
        self.source_class_ids[source] = [
            i for i, info in enumerate(self.class_info)
            if i == 0 or source == info['source']]
输入参数:class_map=None
得到的结果:
self.num_classes:81,表示一共81类(80个coco类别加1个背景类)。
self.class_ids:[0,1,2,3...,80];
self.class_names:['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'],具体的类别名称。
self.num_images:122218,一共有这么多张图像;
self._image_ids:[0,1,2,3,4...122217];
self.class_from_source_map:{'coco.61': 56, 'coco.86': 76, 'coco.16': 15, 'coco.84': 74, 'coco.3': 3, 'coco.88': 78, 'coco.77': 68, 'coco.81': 72, 'coco.73': 64, 'coco.11': 11, 'coco.38': 34, 'coco.57': 52, 'coco.54': 49, 'coco.25': 24, 'coco.80': 71, 'coco.51': 46, 'coco.56': 51, 'coco.13': 12, 'coco.15': 14, 'coco.14': 13, 'coco.67': 61, 'coco.49': 44, 'coco.46': 41, 'coco.79': 70, 'coco.20': 19, 'coco.17': 16, 'coco.32': 28, 'coco.52': 47, 'coco.48': 43, 'coco.4': 4, 'coco.65': 60, 'coco.34': 30, 'coco.27': 25, 'coco.22': 21, 'coco.50': 45, 'coco.75': 66, 'coco.82': 73, 'coco.47': 42, 'coco.70': 62, 'coco.43': 39, 'coco.31': 27, 'coco.74': 65, 'coco.19': 18, 'coco.21': 20, 'coco.72': 63, 'coco.33': 29, 'coco.2': 2, 'coco.9': 9, 'coco.59': 54, 'coco.63': 58, 'coco.1': 1, 'coco.10': 10, 'coco.62': 57, 'coco.53': 48, 'coco.6': 6, 'coco.37': 33, 'coco.36': 32, 'coco.90': 80, 'coco.89': 79, '.0': 0, 'coco.87': 77, 'coco.60': 55, 'coco.76': 67, 'coco.35': 31, 'coco.85': 75, 'coco.18': 17, 'coco.44': 40, 'coco.8': 8, 'coco.28': 26, 'coco.23': 22, 'coco.24': 23, 'coco.7': 7, 'coco.39': 35, 'coco.5': 5, 'coco.41': 37, 'coco.64': 59, 'coco.78': 69, 'coco.40': 36, 'coco.55': 50, 'coco.42': 38, 'coco.58': 53}
self.image_from_source_map:<class 'dict'>: <Too big to print. Len: 122218>,数据具体的样式为{'coco.103720':66675, 'coco.10039':85616 ...}
self.sources: ['', 'coco']
self.source_class_ids:遍历上面的self.sources。第一个source是''(背景类BG的来源),它并不会被跳过,但只有下标为0的背景类满足条件`i == 0 or source == info['source']`,所以得到[0];第二个source是'coco',背景类0和所有coco类别都满足条件。其结果为
{'': [0], 'coco': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]}
# Right/Left flip 50% of the time (imgaug Fliplr with p=0.5).
augmentation = imgaug.augmenters.Fliplr(0.5)
定义数据增强:以50%的概率对图像做左右翻转。imgaug库这里不做介绍,请自行参考https://github.com/aleju/imgaug
# Training - Stage 1: train only the 'heads' layers; epochs=40 is the total epoch count, not the number of additional epochs.
model.train(dataset_train, dataset_val,
learning_rate=config.LEARNING_RATE,
epochs=40,
layers='heads',
augmentation=augmentation)
上面这段代码是主函数调用train接口,传入了dataset_train和dataset_val两个数据集对象。
这里我不对train进行介绍,只介绍数据处理的部分。
这个model.train调用的是mrcnn/model.py中MaskRCNN类的train函数。
def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
          augmentation=None, custom_callbacks=None, no_augmentation_sources=None):
    """Train the model.

    train_dataset, val_dataset: Training and validation Dataset objects.
    learning_rate: The learning rate to train with
    epochs: Number of training epochs. Note that previous training epochs
        are considered to be done already, so this actually determines
        the epochs to train in total rather than in this particular
        call.
    layers: Allows selecting which layers to train. It can be:
        - A regular expression to match layer names to train
        - One of these predefined values:
          heads: The RPN, classifier and mask heads of the network
          all: All the layers
          3+: Train Resnet stage 3 and up
          4+: Train Resnet stage 4 and up
          5+: Train Resnet stage 5 and up
    augmentation: Optional. An imgaug (https://github.com/aleju/imgaug)
        augmentation. For example, passing imgaug.augmenters.Fliplr(0.5)
        flips images right/left 50% of the time. You can pass complex
        augmentations as well. This augmentation applies 50% of the
        time, and when it does it flips images right/left half the time
        and adds a Gaussian blur with a random sigma in range 0 to 5.

            augmentation = imgaug.augmenters.Sometimes(0.5, [
                imgaug.augmenters.Fliplr(0.5),
                imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
            ])
    custom_callbacks: Optional. Add custom callbacks to be called
        with the keras fit_generator method. Must be list of type keras.callbacks.
    no_augmentation_sources: Optional. List of sources to exclude for
        augmentation. A source is string that identifies a dataset and is
        defined in the Dataset class.

    Side effects: creates self.log_dir if missing, changes which layers are
    trainable, recompiles the Keras model, and sets self.epoch to
    max(self.epoch, epochs) once fitting returns.
    """
    assert self.mode == "training", "Create model in training mode."

    # Pre-defined layer regular expressions
    layer_regex = {
        # all layers but the backbone
        "heads": r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        # From a specific Resnet stage and up
        "3+": r"(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        "4+": r"(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        "5+": r"(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
        # All layers
        "all": ".*",
    }
    # A predefined key is translated to its regex; any other value is used
    # as a regex as-is.
    if layers in layer_regex:
        layers = layer_regex[layers]

    # Data generators (no augmentation argument is passed for validation)
    train_generator = data_generator(train_dataset, self.config, shuffle=True,
                                     augmentation=augmentation,
                                     batch_size=self.config.BATCH_SIZE,
                                     no_augmentation_sources=no_augmentation_sources)
    val_generator = data_generator(val_dataset, self.config, shuffle=True,
                                   batch_size=self.config.BATCH_SIZE)

    # Create log_dir if it does not exist
    if not os.path.exists(self.log_dir):
        os.makedirs(self.log_dir)

    # Callbacks
    callbacks = [
        keras.callbacks.TensorBoard(log_dir=self.log_dir,
                                    histogram_freq=0, write_graph=True, write_images=False),
        keras.callbacks.ModelCheckpoint(self.checkpoint_path,
                                        verbose=0, save_weights_only=True),
    ]

    # Add custom callbacks to the list
    if custom_callbacks:
        callbacks += custom_callbacks

    # Train
    log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
    log("Checkpoint Path: {}".format(self.checkpoint_path))
    self.set_trainable(layers)
    self.compile(learning_rate, self.config.LEARNING_MOMENTUM)

    # Work-around for Windows: Keras fails on Windows when using
    # multiprocessing workers. See discussion here:
    # https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009
    # Use ==, not `is`: identity of str objects is an interpreter detail and
    # `is` with a literal emits a SyntaxWarning since Python 3.8.
    if os.name == 'nt':
        workers = 0
    else:
        workers = multiprocessing.cpu_count()

    self.keras_model.fit_generator(
        train_generator,
        initial_epoch=self.epoch,
        epochs=epochs,
        steps_per_epoch=self.config.STEPS_PER_EPOCH,
        callbacks=callbacks,
        validation_data=val_generator,
        validation_steps=self.config.VALIDATION_STEPS,
        max_queue_size=100,
        workers=workers,
        use_multiprocessing=True,
    )
    # `epochs` is a total count, so remember the furthest epoch reached.
    self.epoch = max(self.epoch, epochs)
data_generator(最关键部分)
请看【MaskRCNN】源码系列一:数据处理二