1. 基础环境:
显卡:Nvidia titan xp
驱动:460.73.01
cuda: 10.1
2. detectron2安装:
follow installation tutorials: https://detectron2.readthedocs.io/en/latest/tutorials/install.html
NOTE:
(1) create envs:
conda create -n detectron2 python=3.6.2
source activate detectron2
(2) 一些库:
- pytorch: torchvision版本需要和pytorch的版本相适应,可参考 pytorch.org 安装.
cuda10.1+pytorch1.7:
conda install pytorch==1.7 cudatoolkit=10.1 torchvision==0.8.1 torchaudio -c pytorch
- opencv: pip install opencv-python
- fvcore:
pip install -U 'git+https://github.com/facebookresearch/fvcore'
- pycocotools: conda install pycocotools=2.0.2
- cython: pip install cython
清华源:
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
conda config --set show_channel_urls yes
(3)detectron2
git clone https://github.com/facebookresearch/detectron2.git
python -m pip install -e detectron2
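rebuild (重新编译从本地源码安装的 detectron2 之前, 先清理旧的编译产物, 再重新执行上面的 pip install -e):
rm -rf build/ **/*.so
或者不从源码编译, 直接安装预编译版本 (cuda10.1 + pytorch1.7):
python -m pip install detectron2 -f \
https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.7/index.html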
3. 用detectron2里的Mask R-CNN 训练自己的数据
3.1 数据集注册, 指定类别信息, 指定数据集子集的名字(默认为coco_2017_train, coco_2017_val)
- 复制官方的 tools/train_net.py 文件到根目录下 (下文训练命令中使用的文件名为 train.py), 加上注册数据集部分.
- 补充数据集类别信息, 改成直接从 json 文件中读取.
完整代码如下:
#!/usr/bin/env python
"""Train or evaluate a detectron2 model on a custom COCO-format dataset.

Adapted from detectron2's ``tools/train_net.py``; the addition is the
registration of the iSAID train/val splits, whose category metadata is read
directly from the training annotation file.
"""
import colorsys
import json
import logging
import os
from collections import OrderedDict

import numpy as np
import torch

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets.coco import load_coco_json
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
)
from detectron2.modeling import GeneralizedRCNNWithTTA

# Dataset locations.
DATASET_ROOT = 'datasets/iSAID-reduce100/'
ANN_ROOT = os.path.join(DATASET_ROOT, 'annotations')
TRAIN_PATH = os.path.join(DATASET_ROOT, 'train')
VAL_PATH = os.path.join(DATASET_ROOT, 'val')
TRAIN_JSON = os.path.join(ANN_ROOT, 'instances_train.json')
VAL_JSON = os.path.join(ANN_ROOT, 'instances_val.json')

# Category metadata is read straight from the training annotation file.
# The context manager closes the file handle instead of leaking it.
with open(TRAIN_JSON, 'r') as _ann_file:
    DATASET_CATEGORIES = json.load(_ann_file)['categories']

N = len(DATASET_CATEGORIES)  # number of classes; background is not counted
bright = True
brightness = 1.0 if bright else 0.7
# One evenly spaced hue per class, converted to 0-255 RGB values.
hsv = [(i / N, 1, brightness) for i in range(N)]
colors = [colorsys.hsv_to_rgb(*c) for c in hsv]
colors = (np.array(colors) * 255).astype('uint8')
for category_info, color in zip(DATASET_CATEGORIES, colors):
    # Every category is marked as a "thing". The colour is stored as a plain
    # list of ints rather than a numpy array so the metadata holds only
    # built-in types.
    category_info.update({"isthing": 1, "color": color.tolist()})

# Dataset splits: name -> (image directory, annotation file).
# NOTE: these dataset names must also be set under DATASETS in your config
# file (Base-RCNN-FPN.yaml).
PREDEFINED_SPLITS_DATASET = {
    "iSAID_train": (TRAIN_PATH, TRAIN_JSON),
    "iSAID_val": (VAL_PATH, VAL_JSON),
}


def register_dataset():
    """Register every split listed in PREDEFINED_SPLITS_DATASET."""
    for key, (image_root, json_file) in PREDEFINED_SPLITS_DATASET.items():
        print('key:', key, 'image_root:', image_root, 'json_file:', json_file)
        register_dataset_instances(
            name=key,
            metadate=get_dataset_instances_meta(),
            json_file=json_file,
            image_root=image_root,
        )


def get_dataset_instances_meta():
    """Build the metadata dict for the dataset from DATASET_CATEGORIES.

    Returns:
        dict with three keys: ``thing_dataset_id_to_contiguous_id`` (category
        id from the json -> index in 0..len-1, in file order),
        ``thing_classes`` (category names) and ``thing_colors`` (one colour
        per category).
    """
    things = [k for k in DATASET_CATEGORIES if k["isthing"] == 1]
    thing_ids = [k["id"] for k in things]
    return {
        "thing_dataset_id_to_contiguous_id": {k: i for i, k in enumerate(thing_ids)},
        "thing_classes": [k["name"] for k in things],
        "thing_colors": [k["color"] for k in things],
    }


def register_dataset_instances(name, metadate, json_file, image_root):
    """Register one split with DatasetCatalog and attach its metadata.

    Args:
        name: dataset name to register (e.g. "iSAID_train").
        metadate: metadata dict, passed as extra attributes to
            ``MetadataCatalog.get(name).set``. The parameter name is kept
            as-is so existing keyword callers keep working.
        json_file: path to the COCO-format annotation file.
        image_root: directory containing the images.
    """
    # The loader is registered as a callable; the lambda defers the
    # load_coco_json call until the catalog invokes it.
    DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))
    MetadataCatalog.get(name).set(
        json_file=json_file, image_root=image_root, evaluator_type="coco", **metadate
    )


class Trainer(DefaultTrainer):
    """DefaultTrainer subclass that selects evaluators by dataset type."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Create the evaluator(s) for ``dataset_name``.

        The evaluator is chosen from the ``evaluator_type`` stored in the
        dataset's metadata. ``output_folder`` defaults to
        ``<cfg.OUTPUT_DIR>/inference``.

        Raises:
            NotImplementedError: if no evaluator matches the evaluator type.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name,
                    distributed=True,
                    num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
                    ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
                    output_dir=output_folder,
                )
            )
        if evaluator_type in ["coco", "coco_panoptic_seg"]:
            evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder))
        if evaluator_type == "coco_panoptic_seg":
            evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
        if evaluator_type == "cityscapes_instance":
            # Ranks are 0-based, so a single-machine job requires
            # device_count > rank; ">=" would wrongly accept rank == count.
            assert (
                torch.cuda.device_count() > comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesInstanceEvaluator(dataset_name)
        if evaluator_type == "cityscapes_sem_seg":
            assert (
                torch.cuda.device_count() > comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesSemSegEvaluator(dataset_name)
        elif evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        elif evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, cfg, True, output_folder)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        elif len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)

    @classmethod
    def test_with_TTA(cls, cfg, model):
        """Evaluate ``model`` with test-time augmentation.

        Returns:
            OrderedDict of the test results with "_TTA" appended to each key.
        """
        logger = logging.getLogger("detectron2.trainer")
        # In the end of training, run an evaluation with TTA.
        # Only support some R-CNN models.
        logger.info("Running inference with test-time augmentation ...")
        model = GeneralizedRCNNWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res


def setup(args):
    """Create configs and perform basic setups."""
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    """Register the dataset, then either evaluate only or train.

    With ``args.eval_only`` the weights in ``cfg.MODEL.WEIGHTS`` are loaded
    and the test results are returned; otherwise the return value of
    ``trainer.train()`` is returned.
    """
    cfg = setup(args)

    # Register the custom dataset before any model/data loader is built.
    register_dataset()

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    # If you'd like to do anything fancier than the standard training logic,
    # consider writing your own training loop (see plain_train_net.py) or
    # subclassing the trainer.
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
对于 Windows 下制作的数据集, 图片和 json 文件路径中会有 '\\', 在 Ubuntu 下读取会存在问题:
给 detectron2/detectron2/data/dataset_mapper.py 的 line 125 加上 replace('\\','/'):
image = utils.read_image(dataset_dict["file_name"].replace('\\','/'), format=self.image_format)
3.2 训练
python train.py --config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml --num-gpus 1 SOLVER.IMS_PER_BATCH 4 SOLVER.BASE_LR 0.0001
- 若报错 "AssertionError: Requires pyyaml>=5.1", 执行: conda install pyyaml==5.1
- 若报错 "No module named 'cloudpickle'", 执行: conda install cloudpickle