【深度学习】【mmdetection】PolarMask

【mmdetection】PolarMask

官方代码:https://github.com/xieenze/PolarMask
官方论文解读:https://zhuanlan.zhihu.com/p/84890413

代码运行

创建环境

conda create -n open-mmlab python=3.7 -y
conda activate open-mmlab
conda install cython

Install PyTorch stable 
eg:conda install pytorch torchvision cudatoolkit=9.2 -c pytorch

git clone https://github.com/xieenze/PolarMask.git
cd PolarMask

python setup.py develop
# or "pip install -v -e ."

修改配置文件data_root为数据集所在位置。

train

# 单gpu
python tools/train.py configs/polarmask/4gpu/polar_768_1x_r50.py

#多gpu
./tools/dist_train.sh configs/polarmask/4gpu/polar_768_1x_r50.py 4 --launcher pytorch --work_dir ./work_dirs/polar_768_1x_r50_4gpu

出错No module named 'Polygon polygon3

pip install Polygon3

出错cannot import name 'get_dist_info' from 'mmcv.runner.utils'
修改/PolarMask/mmdet/datasets/loader/sampler.py中from mmcv.runner.utils import get_dist_info变为from mmcv.runner import get_dist_info

在这里插入图片描述

test

# 单gpu
python tools/test.py configs/polarmask/4gpu/polar_768_1x_r50.py [YOUR_CHECKPOINT_DIR] --out [OUT_DIR]
# eg:python tools/test.py configs/polarmask/4gpu/polar_768_1x_r101.py /home/wh/weights/polarmask_r101_1x.pth --out work_dirs/polar101.pkl

数据读取

代码在mmdetection中重新定义了Coco_Seg_Dataset类,在mmdet/datasets/coco_seg.py里,主要有五个函数,这五个函数都是原本coco.py里面的。

  • load_annotations():加载标注文件中的annotation字典,返回图片名字。
  • get_ann_info():实际调用_parse_ann_info()。
  • _filter_imgs():过滤没有标注文件的图片和尺寸小于min-size的图片。
  • _parse_ann_info():返回key为bboxes,bboxes_ignore, labels, masks, mask_polys, poly_lens的字典。
  • prepare_train_img() :重载了custom.py中CustomDataset的prepare_train_img() 方法。

在这一部分,还进行了标签的分配。

@DATASETS.register_module
class Coco_Seg_Dataset(CustomDataset):
    CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
               'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
               'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
               'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
               'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
               'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
               'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
               'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')


    def load_annotations(self, ann_file):
        self.coco = COCO(ann_file)
        self.cat_ids = self.coco.getCatIds()
        self.cat2label = {
   
            cat_id: i + 1
            for i, cat_id in enumerate(self.cat_ids)
        }
        self.img_ids = self.coco.getImgIds()
        img_infos = []
        for i in self.img_ids:
            info = self.coco.loadImgs([i])[0]
            info['filename'] = info['file_name']
            img_infos.append(info)
        return img_infos

    def get_ann_info(self, idx):
        img_id = self.img_infos[idx]['id']
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        ann_info = self.coco.loadAnns(ann_ids)
        return self._parse_ann_info(ann_info, self.with_mask)

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without ground truths."""
        valid_inds = []
        ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
        for i, img_info in enumerate(self.img_infos):
            if self.img_ids[i] not in ids_with_ann:
                continue
            if min(img_info['width'], img_info['height']) >= min_size:
                valid_inds.append(i)
        return valid_inds

    def _parse_ann_info(self, ann_info, with_mask=True):
        """Parse bbox and mask annotation.

        Args:
            ann_info (list[dict]): Annotation info of an image.
            with_mask (bool): Whether to parse mask annotations.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, mask_polys, poly_lens.
        """
        gt_bboxes = []
        gt_labels = []
        gt_bboxes_ignore = []
        # Two formats are provided.
        # 1. mask: a binary map of the same size of the image.
        # 2. polys: each mask consists of one or several polys, each poly is a
        # list of float.


        self.debug = False

        if with_mask:
            gt_masks = []
            gt_mask_polys = []
            gt_poly_lens = []

        if self.debug:
            count = 0
            total = 0
        for i, ann in enumerate(ann_info):
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            #filter bbox < 10
            if self.debug:
                total+=1

            if ann['area'] <= 15 or (w < 10 and h < 10) or self.coco.annToMask(ann).sum() < 15:
                # print('filter, area:{},w:{},h:{}'.format(ann['area'],w,h))
                if self.debug:
                    count+=1
                continue

            bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
            if ann['iscrowd']:
                gt_bboxes_ignore.append(bbox)
            else:
                gt_bboxes.append(bbox)
                gt_labels.append(self.cat2label[ann['category_id']])
            if with_mask:
                gt_masks.append(self.coco.annToMask(ann))
                mask_polys = [
                    p for p in ann['segmentation'] if len(p) >= 6
                ]  # valid polygons have >= 3 points (6 coordinates)
                poly_lens = [len(p) for p in mask_polys]
                gt_mask_polys.append(mask_polys)
                gt_poly_lens.extend(poly_lens)

        if self.debug:
            print('filter:',count/total)
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        ann = dict(
            bboxes=gt_bboxes, labels=gt_labels, bboxes_ignore=gt_bboxes_ignore)

        if with_mask:
            ann['masks'] = gt_masks
            # poly format is not used in the current implementation
            ann['mask_polys'] = gt_mask_polys
            ann['poly_lens'] = g
  • 3
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 22
    评论
### 回答1: MMDetection 是一套开源的目标检测框架,你可以在官方网站上查看它的文档,详细了解它的安装方法、使用方法及相关技术背景。此外,你还可以通过观看视频教程,或者参加在线课程或线下培训来学习MMDetection。 ### 回答2: 学习MMDetection可以遵循以下步骤: 1. 学习基础知识:在学习MMDetection之前,确保你已经掌握了深度学习、计算机视觉、Python编程和相关的数学知识(如线性代数和概率统计)。 2. 理解MMDetection框架:详细阅读MMDetection的官方文档,了解其框架和模块之间的关系。MMDetection是一个基于PyTorch的开源目标检测框架,具有丰富的预训练模型和数据集支持。 3. 了解数据集:学习如何准备和处理目标检测所需的数据集。掌握不同数据集的格式、标注工具以及数据增强技术等,对于后续的模型训练和评估至关重要。 4. 实践编码:尝试使用MMDetection框架进行实际的目标检测任务。可以从官方提供的教程和示例代码开始,逐步修改和调整以满足特定需求。 5. 调试和优化:在实践过程中,可能会遇到许多问题。学会调试代码,理解模型训练过程中的性能瓶颈,并尝试使用不同的优化技术和策略来提升模型的准确性和速度。 6. 深入阅读和研究:阅读相关的论文和博客,了解最新的目标检测算法和技术。从MMDetection的源代码中获取更深入的了解,探索其内部实现和特性。 7. 加入社区和讨论:可以加入MMDetection的官方社区和论坛,与其他开发者交流经验和问题。参与讨论和分享,从中获取更多的帮助和学习机会。 通过以上步骤,你可以逐步掌握MMDetection的基础知识和技巧,并在实践中不断提高自己的目标检测能力。 ### 回答3: 学习 MMDetection 的方法有很多,以下是我推荐的步骤: 1. 基础知识学习:首先,你需要了解目标检测的基本概念、算法和技术。阅读有关目标检测的教材、论文及相关博客文章,掌握物体检测的背景知识。 2. 学习MMDetection框架:阅读 MMDetection 的官方文档,了解其整体结构、主要模块和功能。学习如何使用配置文件、数据读取器等工具。 3. 数据准备:收集和准备适合目标检测任务的数据集,确保数据集的标注准确、完整。学习数据增强技术,提高模型的泛化能力。 4. 模型训练:掌握如何使用 MMDetection 进行模型训练。了解不同的训练策略、损失函数和优化器的选择。通过调整超参数,优化模型的性能。 5. 模型评估与调优:学习如何使用 MMDetection 进行模型评估和性能分析。掌握评估指标的含义和计算方法。通过调整模型结构和超参数,提升模型的性能。 6. 进一步探索:利用 MMDetection 的预训练模型,在其他数据集上进行迁移学习。尝试对模型进行改进,如增加新的网络层、引入注意力机制等。 7. 社区交流:加入 MMDetection 的官方论坛或社群,与其他学习者和开发者交流。通过共享经验和讨论,加深对 MMDetection 的理解。 8. 运用实践:实践是学习的最好方式。尝试在真实项目中应用 MMDetection,解决实际的目标检测问题,提升自己的技能和经验。 记住,学习是一个渐进的过程,需要不断实践和积累。不要急于求成,保持耐心和持续的努力,你一定能够掌握 MMDetection 框架。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值