Mask R-CNN: Training and Testing on a Custom Dataset

Project Overview

This project performs oral (mouth) image segmentation with seven object classes. This post focuses on how to train your own model on a self-made dataset.

Training Environment

OS: Windows 10
GPU: GTX 1080 Ti
CPU: Intel Core i7-8700
RAM: 32 GB

Project Repository

Repository: https://github.com/a2824256/Mask_RCNN_4_Oral_Segmentation. The Mask R-CNN implementation is based on matterport/Mask_RCNN.

Tutorial: Building a COCO-Format Dataset

First convert your dataset into the COCO dataset format; see the tutorial here:

https://blog.csdn.net/a2824256/article/details/105818290
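The loader code later in this post only touches a few fields of that file. For orientation, here is a minimal sketch of the expected annotations.json structure, written as a Python dict; all values are illustrative, not taken from the real dataset:

# Minimal sketch of the COCO-style annotations.json fields used by the loader below
# (illustrative values only)
annotations_example = {
    "images": [
        {"id": 0, "file_name": "0001.jpg", "width": 640, "height": 480},
    ],
    "annotations": [
        {"id": 1, "image_id": 0, "category_id": 4,                    # 4 = "tongue" below
         "segmentation": [[120.0, 80.0, 200.0, 85.0, 180.0, 160.0]],  # polygon x,y pairs
         "bbox": [120.0, 80.0, 80.0, 80.0], "area": 3200.0, "iscrowd": 0},
    ],
    "categories": [
        {"id": 1, "name": "teeth_top"}, {"id": 2, "name": "teeth_bottom"},
        {"id": 3, "name": "uvula"}, {"id": 4, "name": "tongue"},
        {"id": 5, "name": "pp_wall"}, {"id": 6, "name": "tonsil_right"},
        {"id": 7, "name": "tonsil_left"},
    ],
}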

Writing the Training/Testing Classes

The code below runs in a Jupyter Notebook; the complete notebook file can be downloaded directly from the repository.
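The snippets assume a setup cell roughly like the one below. Module paths follow the matterport package layout and may differ between repo versions; ROOT_DIR and the file names are assumptions you should adjust to your own checkout:

import os
import json
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image as PILImage

from pycocotools.coco import COCO
from pycocotools import mask as maskUtils

# matterport/Mask_RCNN modules (package layout may differ between versions)
from mrcnn.config import Config
from mrcnn import utils, visualize
import mrcnn.model as modellib
from mrcnn.model import log

ROOT_DIR = os.path.abspath(".")                         # project root (assumption)
MODEL_DIR = os.path.join(ROOT_DIR, "logs")              # training logs and checkpoints
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

# Download the COCO pre-trained weights if they are not already present
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)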

Core Code Walkthrough
Training configuration
class OralConfig(Config):
    # Give the configuration a recognizable name for your dataset
    NAME = "Oral"

    # Train on 1 GPU with 3 images per GPU. We can put multiple images on each
    # GPU because the images are small. Batch size is 3 (GPUs * images/GPU).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 3  # lower this on weaker GPUs; 3 images/GPU needs roughly 9 GB+ of VRAM

    # Number of classes (including background)
    # This dataset has seven classes, plus the background: 1 + 7
    NUM_CLASSES = 1 + 7

    # Use small images for faster training. Set the limits of the small side
    # the large side, and that determines the image shape.
    IMAGE_MIN_DIM = 480
    IMAGE_MAX_DIM = 640

    # Use smaller anchors because our image and objects are small
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)  # anchor side in pixels

    # Reduce training ROIs per image because the images are small and have
    # few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
    TRAIN_ROIS_PER_IMAGE = 32

    # Use a small epoch since the data is simple
    STEPS_PER_EPOCH = 100

    # use small validation steps since the epoch is small
    VALIDATION_STEPS = 5
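As in the matterport examples, instantiate the configuration (it is referenced later as config) and print it to double-check derived values such as the effective batch size:

config = OralConfig()
config.display()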
Dataset loading class
class OralDataset(utils.Dataset):
    # Register the dataset's images and metadata in the model library's
    # internal image_info list.
    def load_shapes(self):
        # Add the seven classes used by this dataset
        self.add_class("oral", 1, "teeth_top")
        self.add_class("oral", 2, "teeth_bottom")
        self.add_class("oral", 3, "uvula")
        self.add_class("oral", 4, "tongue")
        self.add_class("oral", 5, "pp_wall")
        self.add_class("oral", 6, "tonsil_right")
        self.add_class("oral", 7, "tonsil_left")
        anns_json_path = 'path to your COCO-format annotations.json'
        with open(anns_json_path, 'r', encoding='utf8') as load_f:
            config_json = json.load(load_f)
            img_path = 'path to the folder containing your COCO-format images'
            coco = COCO(anns_json_path)
            class_ids = sorted(coco.getCatIds())
            # Only the first 6 images are registered here; replace 6 with
            # len(config_json['images']) to load the whole set.
            for i in range(6):
                self.add_image("oral", image_id=i,
                               path=img_path + config_json['images'][i]['file_name'],
                               width=640, height=480,
                               annotations=coco.loadAnns(coco.getAnnIds(
                                   imgIds=[i], catIds=class_ids, iscrowd=None)))
    # Look up image_info by image_id and load the corresponding image from disk.
    def load_image(self, image_id):
        info = self.image_info[image_id]
        image = PILImage.open(info['path'])
        image = np.array(image).astype(np.uint8)
        return image

    # Load the instance masks for the image identified by image_id.
    def load_mask(self, image_id):
        image_info = self.image_info[image_id]
        if image_info["source"] != "oral":
            return super(OralDataset, self).load_mask(image_id)

        instance_masks = []
        class_ids = []
        annotations = self.image_info[image_id]["annotations"]
        # Build mask of shape [height, width, instance_count] and list
        # of class IDs that correspond to each channel of the mask.
        for annotation in annotations:
            class_id = self.map_source_class_id("oral.{}".format(annotation['category_id']))
            if class_id:
                m = self.annToMask(annotation, image_info["height"],
                                   image_info["width"])
                # Some objects are so small that they're less than 1 pixel area
                # and end up rounded out. Skip those objects.
                if m.max() < 1:
                    continue
                # Is it a crowd? If so, use a negative class ID.
                if annotation['iscrowd']:
                    # Use negative class ID for crowds
                    class_id *= -1
                    # For crowd masks, annToMask() sometimes returns a mask
                    # smaller than the given dimensions. If so, resize it.
                    if m.shape[0] != image_info["height"] or m.shape[1] != image_info["width"]:
                        m = np.ones([image_info["height"], image_info["width"]], dtype=bool)
                instance_masks.append(m)
                class_ids.append(class_id)

        # Pack instance masks into an array
        if class_ids:
            mask = np.stack(instance_masks, axis=2).astype(bool)  # np.bool is removed in newer NumPy
            class_ids = np.array(class_ids, dtype=np.int32)
            return mask, class_ids
        else:
            # Call super class to return an empty mask
            return super(OralDataset, self).load_mask(image_id)

    # Convert an annotation to RLE
    def annToRLE(self, ann, height, width):
        """
        Convert an annotation that can be a polygon or uncompressed RLE to compressed RLE.
        :return: RLE (run-length encoded segmentation)
        """
        segm = ann['segmentation']
        if isinstance(segm, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = maskUtils.frPyObjects(segm, height, width)
            rle = maskUtils.merge(rles)
        elif isinstance(segm['counts'], list):
            # uncompressed RLE
            rle = maskUtils.frPyObjects(segm, height, width)
        else:
            # rle
            rle = ann['segmentation']
        return rle

    # Convert an annotation to a binary mask
    def annToMask(self, ann, height, width):
        """
        Convert an annotation that can be a polygon, uncompressed RLE, or RLE to a binary mask.
        :return: binary mask (numpy 2D array)
        """
        rle = self.annToRLE(ann, height, width)
        m = maskUtils.decode(rle)
        return m

Set up the training and validation datasets
# Training dataset
dataset_train = OralDataset()
dataset_train.load_shapes()
dataset_train.prepare()

# Validation dataset
dataset_val = OralDataset()
# load_shapes() can be modified to take a parameter that selects the training or validation split (see the sketch below)
dataset_val.load_shapes()
dataset_val.prepare()
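As noted in the comment above, load_shapes() can be parameterized so one class serves both splits. A minimal sketch of the split logic, using a hypothetical helper that is not part of the original code:

def split_image_ids(num_images, val_fraction=0.2, seed=0):
    # Shuffle image indices and carve off a validation share (illustrative helper)
    ids = list(range(num_images))
    random.Random(seed).shuffle(ids)
    n_val = max(1, int(num_images * val_fraction))
    return ids[n_val:], ids[:n_val]

train_ids, val_ids = split_image_ids(6)
# load_shapes() could accept such a list and iterate `for i in train_ids:`
# (or val_ids) instead of the hard-coded `for i in range(6):`.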
Verify that images and masks load correctly
image = dataset_train.load_image(0)
mask, class_ids = dataset_train.load_mask(0)
visualize.display_top_masks(image, mask, class_ids, dataset_train.class_names)

Test result: (figure: a sample image and its per-class masks shown by display_top_masks)

Load the model and train
model = modellib.MaskRCNN(mode="training", config=config,
                          model_dir=MODEL_DIR)
init_with = "coco"  # imagenet, coco, or last
model.load_weights(COCO_MODEL_PATH, by_name=True,
                       exclude=["mrcnn_class_logits", "mrcnn_bbox_fc", 
                                "mrcnn_bbox", "mrcnn_mask"])
model.train(dataset_train, dataset_val, 
            learning_rate=config.LEARNING_RATE, 
            epochs=40,
            layers='heads')
model.train(dataset_train, dataset_val, 
            learning_rate=config.LEARNING_RATE / 10,
            epochs=160,
            layers="all")
Load the trained weights and run detection
class InferenceConfig(OralConfig):
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

inference_config = InferenceConfig()

# Recreate the model in inference mode
model = modellib.MaskRCNN(mode="inference", 
                          config=inference_config,
                          model_dir=MODEL_DIR)

# Get path to saved weights
# Either set a specific path or find last trained weights
# model_path = os.path.join(ROOT_DIR, ".h5 file name here")
model_path = model.find_last()

# Load trained weights
print("Loading weights from ", model_path)
model.load_weights(model_path, by_name=True)
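# The get_ax() helper used further below is not defined in this excerpt; the
# matterport notebooks define it roughly as follows (assumes matplotlib.pyplot as plt):
def get_ax(rows=1, cols=1, size=8):
    """Return a Matplotlib Axes array to be used in the visualizations."""
    _, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    return ax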
# Test on a random image
image_id = random.choice(dataset_val.image_ids)
original_image, image_meta, gt_class_id, gt_bbox, gt_mask =\
    modellib.load_image_gt(dataset_val, inference_config, 
                           image_id, use_mini_mask=False)

log("original_image", original_image)
log("image_meta", image_meta)
log("gt_class_id", gt_class_id)
log("gt_bbox", gt_bbox)
log("gt_mask", gt_mask)

visualize.display_instances(original_image, gt_bbox, gt_mask, gt_class_id, 
                            dataset_train.class_names, figsize=(8, 8))
results = model.detect([original_image], verbose=1)

r = results[0]
print("----------r['class_ids']----------")
print(r['class_ids'])
print("----------r['class_ids']----------")
visualize.display_instances(original_image, r['rois'], r['masks'], r['class_ids'], 
                            dataset_val.class_names, r['scores'], ax=get_ax())

Test result: (figure: detection results on a random validation image)
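Beyond inspecting single images, the matterport notebooks also report a mean average precision over the validation set. A sketch of that evaluation using utils.compute_ap, reusing the inference model and dataset_val defined above:

# Compute mAP @ IoU=0.5 over the validation images
APs = []
for image_id in dataset_val.image_ids:
    image, image_meta, gt_class_id, gt_bbox, gt_mask = \
        modellib.load_image_gt(dataset_val, inference_config, image_id, use_mini_mask=False)
    r = model.detect([image], verbose=0)[0]
    AP, precisions, recalls, overlaps = utils.compute_ap(
        gt_bbox, gt_class_id, gt_mask,
        r["rois"], r["class_ids"], r["scores"], r["masks"])
    APs.append(AP)
print("mAP:", np.mean(APs))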
Anything beyond this you will need to explore on your own.
