matterport_MaskRCNN(5)——代码解读model.py

本文详细解析了`model.py`中`load_image_gt()`函数及Mask R-CNN模型的实现,包括数据集处理、图像resize、mask处理等步骤。同时介绍了ResNet网络结构和配置文件`config.py`的作用。辅助函数如`resize_image`、`resize_mask`、`extract_bboxes`等在处理图像和mask时的关键作用。
摘要由CSDN通过智能技术生成

model.py——def load_image_gt()

用于载入图片的相关标注信息(),即根据给定的image_id,从dataset中载入:原图,mask图,boxes信息
输入:
数据集实例,配置类实例,image_id,其他可选参数
输出:
原图([h,w,3]),图片原始形状,class_ids(图片中各目标的类别id),bbox(各目标的坐标信息),mask(各个目标的掩膜信息)

def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
                  use_mini_mask=False):
    """Load and return ground truth data for an image (image, mask, bounding boxes).
    载入并返回一张图片的真实数据(图片,掩膜,边界框)

    augment: 不建议使用。(deprecated. Use augmentation instead). If true, apply random
        image augmentation. Currently, only horizontal flipping is offered.
        
    augmentation: 可选项,用于图片增强,Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
        For example, passing imgaug.augmenters.Fliplr(0.5) flips images
        right/left 50% of the time.
    use_mini_mask: 
        如果不使用,返回的mask与原图一致,可能会非常大
        If False, returns full-size masks that are the same height
        and width as the original image. These can be big, for example
        1024x1024x100 (for 100 instances). Mini masks are smaller, typically,
        224x224 and are generated by extracting the bounding box of the
        object and resizing it to MINI_MASK_SHAPE.

    Returns:返回值
    image: 原图,[height, width, 3]
    shape: 图片的原始形状(resize且crop之前)
           the original shape of the image before resizing and cropping.
    class_ids: 类别id,也就是一维数组,长度=图中实例个数
               [instance_count] Integer class IDs
    bbox: 边界框,[instance_count, (y1, x1, y2, x2)]
    mask: 掩膜数组(图片)[height, width, instance_count]. The height and width are those
        of the image unless use_mini_mask is True, in which case they are
        defined in MINI_MASK_SHAPE.
    """
    # 载入图片和mask,load_image和load_mask见utils.py中dataset类的内部方法
    # 或者在重写class dataset时重新定义的
    image = dataset.load_image(image_id)          #根据id载入图片,如果是灰度图会自动转换成三通道[h,w,3]形式
    mask, class_ids = dataset.load_mask(image_id) # 获取mask图片[h,w,n]以及对应class_ids一维数组[n,]
    original_shape = image.shape                  # 图片原始shape(h,w,c)
    
    # 根据相关参数对image和mask进行resize
    image, window, scale, padding, crop = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        min_scale=config.IMAGE_MIN_SCALE,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE)
    mask = utils.resize_mask(mask, scale, padding, crop)

    # 数据增强部分,暂时略过
    # Random horizontal flips.
    # TODO: will be removed in a future update in favor of augmentation
    if augment:
        logging.warning("'augment' is deprecated. Use 'augmentation' instead.")
        if random.randint(0, 1):
            image = np.fliplr(image)
            mask = np.fliplr(mask)

    # Augmentation
    # This requires the imgaug lib (https://github.com/aleju/imgaug)
    if augmentation:
        import imgaug

        # Augmenters that are safe to apply to masks
        # Some, such as Affine, have settings that make them unsafe, so always
        # test your augmentation on masks
        MASK_AUGMENTERS = ["Sequential", "SomeOf", "OneOf", "Sometimes",
                           "Fliplr", "Flipud", "CropAndPad",
                           "Affine", "PiecewiseAffine"]

        def hook(images, augmenter, parents, default):
            """Determines which augmenters to apply to masks."""
            return augmenter.__class__.__name__ in MASK_AUGMENTERS

        # Store shapes before augmentation to compare
        image_shape = image.shape
        mask_shape = mask.shape
        # Make augmenters deterministic to apply similarly to images and masks
        det = augmentation.to_deterministic()
        image = det.augment_image(image)
        # Change mask to np.uint8 because imgaug doesn't support np.bool
        mask = det.augment_image(mask.astype(np.uint8),
                                 hooks=imgaug.HooksImages(activator=hook))
        # Verify that shapes didn't change
        assert image.shape == image_shape, "Augmentation shouldn't change image size"
        assert mask.shape == mask_shape, "Augmentation shouldn't change mask size"
        # Change mask back to bool
        mask = mask.astype(np.bool)

    # 如果resize使用了crop,可能将mask裁减掉,这里筛选掉全0的mask
    # 将mask的第0,1维度所有像素值相加,并判断是否>0
    # _idx是一个bool值列表,如果是False说明该index对应的mask[:,:,index]像素全0,要被筛选掉
    _idx = np.sum(mask, axis=(0, 1)) > 0     
    mask = mask[:, :, _idx]                 # _idx中应False的维度将被删除
    class_ids = class_ids[_idx]             #class_ids中对应的也删除
    # 根据mask中的像素值计算边界框的角点值
    bbox = utils.extract_bboxes(mask)

    # Active classes  活动的类别,应该是指用到的类别吧,在单数据集训练时似乎没有用处
    # Different datasets have different classes, so track the
    # classes supported in the dataset of this image.
    active_class_ids = np.zeros([dataset.num_classes], dtype=np.int32)
    source_class_ids = dataset.source_class_ids[dataset.image_info[image_id]["source"]]
    active_class_ids[source_class_ids] = 1

    # 如果使用nimi_mask则,再次对mask进行resize,可以节约内存
    if use_mini_mask:
        mask = utils.minimize_mask(bbox, mask, config.MINI_MASK_SHAPE)

    # 将原始图片及相关信息(原图尺寸,缩放因子等等)存入一维数组中
    image_meta = compose_image_meta(image_id, original_shape, image.shape,
                                    window, scale, active_class_ids)

    return image, image_meta, class_ids, bbox, mask

untils.py——class Dataset(object)

所有数据集的父类,当使用模型时,我们需要根据自己的数据集,创建子类并重写其中的一些方法。
在该类中,数据的组织形式是image_info=[],该list中包含了多个键值对,对应了图片的各种信息,在载入图片时,主要是从image_info中取需要的信息。
数据集的构建过程:先自定义数据集类,然后手动导入类别信息和图片信息(调用add_class以及add_image),之后的操作全在dataset.prepare()中

class Dataset(object):
	# 初始化
    def __init__(self, class_map=None):
        self._image_ids = [] #图片类别id号的list
        self.image_info = [] #图片信息(最常用的list)
        self.class_info = [{
   "source": "", "id": 0, "name": "BG"}] 
        # class_info包含了整个数据集中所有的类别name和id,而source指的应该是‘COCO’,'VOC'这种数据集名称
        self.source_class_ids = {
   }  #源类别id
	
	# 添加新类别
    def add_class(self, source, class_id, class_name):
        assert "." not in source, "Source name cannot contain a dot"
        # 如果该类已经添加过,直接返回
        for info in self.class_info:
            if info['source'] == source and info["id"] == class_id:
                return
        # 否则,添加该类(以字典形式添加到class_info中)
        self.class_info.append({
   
            "source": source,
            "id": class_id,
            "name": class_name,
        })
        
    # 添加图片,输入(source,图片id,图片路径,其他参数)
    def add_image(self, source, image_id, path, **kwargs):
        image_info = {
   
            "id": image_id,
            "source": source,
            "path": path,
        }
        image_info.update(kwargs)           #如果有其他参数传入,则更新image_info
        self.image_info.append(image_info)  #将图片信息以字典形式加入到image_info的list中
        
    # 用于查找图像用
    def image_reference(self, image_id):
        """可以根据图片id给图片一个链接信息
        方便图片查找(源码中没有写,故跳过)
        """
        return ""
    
    # 数据集准备工作(包含了类别和图片的从name->id的映射,以及当数据来源多余一个数据集时,对数据集进行整理)
    # 数据集来源source指的是‘COCO’,‘VOC’这种,当然也可以自定义,比如‘18年’,‘19年’等
    def prepare(self, class_map=None):
        """
        数据集的准备工作

        TODO: class map is not supported yet. When done, it should handle mapping
              classes from different datasets to the same class ID.
              尚不支持类映射。 完成后,它应处理从不同数据集到相同类ID的映射类。
        """
        # 返回一个简短的对象名用于简洁显示(没看懂意义)
        # 可能是作者所使用的数据名称有特定格式,作用就是取name中第一个','之前的字符串
        def clean_name(name):
            return ",".join(name.split(",")[:1])
        
        # 从info字典中创建(或重建)其他信息
        self.num_classes = len(self.class_info)       #获取类别数目(类别信息list的长度)
        self.class_ids = np.arange(self.num_classes)  #生成类别id(用于将class_name映射成class_id)
        self.class_names = [clean_name(c["name"]) for c in self.class_info]   #类别名称list
        self.num_images = len(self.image_info)        #图片数量
        self._image_ids = np.arange(self.num_images)  # 图片id(也是用于映射的)
        
        # 映射操作,即对names和ids两个迭代器进行同时迭代,然后生成字典
        # 字典中的键值对就是映射关系
        # 先进行类别映射
        # 从class_info和class_ids同时取值,然后生成如 'COCO.x':'y' 形式的键值对
        self.class_from_source_map = {
   "{}.{}".format(info['source'], info['id']): id
                                      for info, id in zip(self.class_info, self.class_ids)}
        # 进行图片映射,同上
        self.image_from_source_map = {
   "{}.{}".format(info['source'], info['id']): id
                                      for info, id in zip(self.image_info, self.image_ids)}
        
        # 获取source名称的list,使用set避免重复
        self.sources = list(set([i['source'] for i in self.class_info]))
        
        self.source_class_ids = {
   }
        # Loop over datasets
        # 遍历数据集
        for source in self.sources:
            # 对于sources中的每个中数据集,都在source_class_ids中创建一个键值对,键是数据集名称,值是空列表
            self.source_class_ids[source] = []
            # 寻找属于该classes的数据集
            # 遍历class_info,找到属于数据集source中的类别,添加到对应的列表中
            # 注:背景类属于任何一个数据集,所以都要添加
            for i, info in enumerate(self.class_info):
                # Include BG class in all datasets所有数据集都包含背景类
                if i == 0 or source == info['source']:
                    self.source_class_ids[source].append(i)
                    
    # 根据source名称获取对应的分配其的整型ID,数据集来源多余一个时用到,此处略过
    def map_source_class_id(self, source_class_id):
        """Takes a source class ID and returns the int class ID assigned to it.
        For example:
        dataset.map_source_class_id("coco.12") -> 23
        """
        return self.class_from_source_map[source_class_id]
    # 根据source和class_id查到该类在calss_innfo中的id,只有一类数据集时也用不到
    def get_source_class_id(self, class_id, source):
        """Map an internal class ID to the corresponding class ID in the source dataset.
        将内部类ID映射到源数据集中的相应类ID
        输入分配给该类的id,看该id对应的source与输入source"""
        info = self.class_info[class_id]
        assert info['source'] == source
        return info['id']
        
    @property  # 一个装饰器,用于把函数方法变成属性,方便调用,即可以直接写list1=dataset.image_ids
    # 返回图片内部id的list,用于查看有多找张图像以及他们在数据集中的id
    def image_ids(self):
        return self._image_ids
    # 根据image_id返回image的path
    def source_image_link(self, image_id):
        
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值