【mmdetection】PolarMask
官方代码:https://github.com/xieenze/PolarMask
官方论文解读:https://zhuanlan.zhihu.com/p/84890413
代码运行
创建环境
conda create -n open-mmlab python=3.7 -y
conda activate open-mmlab
conda install cython
Install PyTorch stable
eg:conda install pytorch torchvision cudatoolkit=9.2 -c pytorch
git clone https://github.com/xieenze/PolarMask.git
cd PolarMask
python setup.py develop
# or "pip install -v -e ."
修改配置文件中的data_root,将其设置为数据集所在位置。
train
# 单gpu
python tools/train.py configs/polarmask/4gpu/polar_768_1x_r50.py
#多gpu
./tools/dist_train.sh configs/polarmask/4gpu/polar_768_1x_r50.py 4 --launcher pytorch --work_dir ./work_dirs/polar_768_1x_r50_4gpu
出错:No module named 'Polygon'
解决方法:安装Polygon3
pip install Polygon3
出错cannot import name 'get_dist_info' from 'mmcv.runner.utils'
修改/PolarMask/mmdet/datasets/loader/sampler.py中from mmcv.runner.utils import get_dist_info变为from mmcv.runner import get_dist_info
test
# 单gpu
python tools/test.py configs/polarmask/4gpu/polar_768_1x_r50.py [YOUR_CHECKPOINT_DIR] --out [OUT_DIR]
# eg:python tools/test.py configs/polarmask/4gpu/polar_768_1x_r101.py /home/wh/weights/polarmask_r101_1x.pth --out work_dirs/polar101.pkl
数据读取
代码在mmdetection中重新定义了Coco_Seg_Dataset类(位于mmdet/datasets/coco_seg.py),主要有五个函数,这五个函数都是原本coco.py里面的。
- load_annotations():加载标注文件,建立类别id到标签的映射(cat2label),返回图片信息列表img_infos(每一项包含filename等字段)。
- get_ann_info():实际调用_parse_ann_info()。
- _filter_imgs():过滤没有标注的图片和短边小于min_size的图片,返回保留下来的图片索引。
- _parse_ann_info():返回key为bboxes,bboxes_ignore, labels, masks, mask_polys, poly_lens的字典。
- prepare_train_img() :重载了custom.py中CustomDataset的prepare_train_img() 方法。
在这一部分,还进行了标签的分配。
@DATASETS.register_module
class Coco_Seg_Dataset(CustomDataset):
CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')
def load_annotations(self, ann_file):
    """Load a COCO annotation file and return one info dict per image.

    Besides building the return value, this stores on the instance:
    ``coco`` (the ``COCO`` handle), ``cat_ids``, ``cat2label`` (category
    id -> label, labels numbered from 1 in ``cat_ids`` order) and
    ``img_ids``.

    Args:
        ann_file: Passed unchanged to ``COCO(...)``.

    Returns:
        list: The image info dicts in ``getImgIds()`` order; each one gets
        an extra ``'filename'`` key that mirrors its ``'file_name'`` value.
    """
    self.coco = COCO(ann_file)
    self.cat_ids = self.coco.getCatIds()
    # Labels are 1-based (presumably 0 is kept for background -- TODO confirm).
    self.cat2label = {
        cat_id: label
        for label, cat_id in enumerate(self.cat_ids, start=1)
    }
    self.img_ids = self.coco.getImgIds()
    img_infos = []
    for img_id in self.img_ids:
        info = self.coco.loadImgs([img_id])[0]
        # Downstream code reads 'filename', COCO provides 'file_name'.
        info['filename'] = info['file_name']
        img_infos.append(info)
    return img_infos
def get_ann_info(self, idx):
    """Return the parsed annotations of the image at position ``idx``.

    Args:
        idx (int): Index into ``self.img_infos``.

    Returns:
        Whatever ``self._parse_ann_info`` returns for the annotations of
        that image, called with ``self.with_mask`` as the mask flag.
    """
    img_id = self.img_infos[idx]['id']
    ann_ids = self.coco.getAnnIds(imgIds=[img_id])
    ann_info = self.coco.loadAnns(ann_ids)
    return self._parse_ann_info(ann_info, self.with_mask)
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
valid_inds = []
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
for i, img_info in enumerate(self.img_infos):
if self.img_ids[i] not in ids_with_ann:
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
return valid_inds
def _parse_ann_info(self, ann_info, with_mask=True):
"""Parse bbox and mask annotation.
Args:
ann_info (list[dict]): Annotation info of an image.
with_mask (bool): Whether to parse mask annotations.
Returns:
dict: A dict containing the following keys: bboxes, bboxes_ignore,
labels, masks, mask_polys, poly_lens.
"""
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
# Two formats are provided.
# 1. mask: a binary map of the same size of the image.
# 2. polys: each mask consists of one or several polys, each poly is a
# list of float.
self.debug = False
if with_mask:
gt_masks = []
gt_mask_polys = []
gt_poly_lens = []
if self.debug:
count = 0
total = 0
for i, ann in enumerate(ann_info):
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
#filter bbox < 10
if self.debug:
total+=1
if ann['area'] <= 15 or (w < 10 and h < 10) or self.coco.annToMask(ann).sum() < 15:
# print('filter, area:{},w:{},h:{}'.format(ann['area'],w,h))
if self.debug:
count+=1
continue
bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
if ann['iscrowd']:
gt_bboxes_ignore.append(bbox)
else:
gt_bboxes.append(bbox)
gt_labels.append(self.cat2label[ann['category_id']])
if with_mask:
gt_masks.append(self.coco.annToMask(ann))
mask_polys = [
p for p in ann['segmentation'] if len(p) >= 6
] # valid polygons have >= 3 points (6 coordinates)
poly_lens = [len(p) for p in mask_polys]
gt_mask_polys.append(mask_polys)
gt_poly_lens.extend(poly_lens)
if self.debug:
print('filter:',count/total)
if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
gt_labels = np.array(gt_labels, dtype=np.int64)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)
if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
ann = dict(
bboxes=gt_bboxes, labels=gt_labels, bboxes_ignore=gt_bboxes_ignore)
if with_mask:
ann['masks'] = gt_masks
# poly format is not used in the current implementation
ann['mask_polys'] = gt_mask_polys
ann['poly_lens'] = g