I. Code Preparation
Based on PyTorch.
Mask Scoring R-CNN code reference: [github]
maskrcnn-benchmark: [github]
II. Environment Setup
1. Create a PyTorch environment with conda:
conda create -n pytorch python=3.7.4
# activate first so the packages below go into the new env, not into base
conda activate pytorch
conda install ipython
conda install -c pytorch pytorch-nightly torchvision=0.2.1 cudatoolkit=9.0 # Note: must be 9.0
pip install numpy scipy ninja yacs cython matplotlib tqdm opencv-python
Note: install torchvision with pip install torchvision==0.2.1. Otherwise you will get AttributeError: 'list' object has no attribute 'resize' (issue #45).
See https://github.com/zjhuang22/maskscoring_rcnn/issues/45
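Because the build is sensitive to exact versions, a quick check after installation can save a failed build later. The sketch below reports whether the installed `torchvision` distribution matches the `0.2.1` pin. It assumes Python's `importlib.metadata` (3.8+), or the `importlib_metadata` backport on the Python 3.7 environment created above. `installed_version` and `check_pin` are hypothetical helper names. The sketch only reads package metadata, so it does not confirm that CUDA 9.0 is actually in use. To check that, run `torch.version.cuda` in an interactive session.

```python
try:
    from importlib.metadata import version, PackageNotFoundError  # Python 3.8+
except ImportError:  # Python 3.7: requires `pip install importlib-metadata`
    from importlib_metadata import version, PackageNotFoundError


def installed_version(package):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def check_pin(package, expected):
    """Compare an installed distribution's version against an expected pin."""
    found = installed_version(package)
    if found is None:
        return f"{package}: not installed"
    if found != expected:
        return f"{package}: found {found}, expected {expected}"
    return f"{package}: OK ({found})"


if __name__ == "__main__":
    # The pin recommended above to avoid the 'list' object has no attribute 'resize' error
    print(check_pin("torchvision", "0.2.1"))
```

Run it inside the activated `pytorch` environment. A "found X, expected 0.2.1" line means pip or conda pulled a different torchvision than intended.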
2. Install cocoapi & apex:
export INSTALL_DIR=$PWD
# install pycocotools
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
python setup.py build_ext install
# install apex
cd $INSTALL_DIR
git clone https://github.com/NVIDIA/apex.git
cd apex
python setup.py install --cuda_ext --cpp_ext
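To confirm that both packages ended up in the active environment, the following minimal sketch uses the standard-library `importlib.util.find_spec` to list any top-level modules that are missing. `missing_modules` is a hypothetical helper. Finding `apex` only shows that the package is on the path. It does not prove that the `--cuda_ext`/`--cpp_ext` extensions compiled successfully.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found on sys.path."""
    # find_spec returns None (without importing) when a top-level module is absent
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # pycocotools comes from cocoapi/PythonAPI, apex from NVIDIA/apex
    missing = missing_modules(["pycocotools", "apex"])
    print("all found" if not missing else "missing: " + ", ".join(missing))
```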
3. Install the detection code:
# install PyTorch Detection
cd $INSTALL_DIR
#maskrcnn-benchmark
#git clone https://github.com/facebookresearch/maskrcnn-benchmark.git
git clone https://github.com/zjhuang22/maskscoring_rcnn
cd maskscoring_rcnn
python setup.py build develop
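The `python setup.py build develop` step compiles the project's C++/CUDA operators. A failed or mismatched build often surfaces only at import time, so the sketch below actually imports the modules and prints any error text. It assumes that the compiled extension is exposed as `maskrcnn_benchmark._C`, as in upstream maskrcnn-benchmark. `try_import` is a hypothetical helper. The sketch imports `torch` first because the extension links against PyTorch's shared libraries.

```python
import importlib


def try_import(name):
    """Import a module by dotted name; return (True, "") or (False, error text)."""
    try:
        importlib.import_module(name)
        return True, ""
    except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
        return False, str(exc)


if __name__ == "__main__":
    # Assumption: the compiled ops live in maskrcnn_benchmark._C, as upstream.
    # Import torch first so its shared libraries are loaded before the extension.
    for module in ["torch", "maskrcnn_benchmark._C"]:
        ok, err = try_import(module)
        print(f"{module}: OK" if ok else f"{module}: import failed: {err}")
```

An "undefined symbol" message here usually points to a mismatch between the PyTorch used to build the extension and the one being imported.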
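III. Data Preparation
The training data is annotated with labelme and must be converted to COCO format. The usual tool is labelme2coco.py. Many versions of it are available online and can be used as-is.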
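For orientation, here is a minimal, dependency-free sketch of what the conversion produces. It was written independently of the labelme2coco.py script that follows and is not a substitute for it. It assumes labelme JSON files with `imagePath`, `imageHeight`, `imageWidth` and polygon `shapes`, which are the field names in recent labelme releases. The function names `polygon_area` and `labelme_to_coco` are hypothetical. The sketch fills the three COCO lists `images`, `categories` and `annotations`. It computes `area` with the shoelace formula and `bbox` as `[x, y, width, height]`.

```python
import json


def polygon_area(points):
    """Shoelace formula for a simple polygon given as [[x, y], ...]."""
    n = len(points)
    s = 0.0
    for i in range(n):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % n]
        s += x1 * y2 - x2 * y1
    return abs(s) / 2.0


def labelme_to_coco(labelme_docs):
    """Convert parsed labelme JSON dicts into one COCO-style dict (polygons only)."""
    images, annotations, categories = [], [], []
    cat_ids = {}  # label name -> category id (COCO ids start at 1)
    ann_id = 1
    for img_id, doc in enumerate(labelme_docs, start=1):
        images.append({"id": img_id, "file_name": doc["imagePath"],
                       "height": doc["imageHeight"], "width": doc["imageWidth"]})
        for shape in doc["shapes"]:
            if shape.get("shape_type", "polygon") != "polygon":
                continue  # rectangles, circles, etc. are skipped in this sketch
            label = shape["label"]
            if label not in cat_ids:
                cat_ids[label] = len(cat_ids) + 1
                categories.append({"id": cat_ids[label], "name": label,
                                   "supercategory": label})
            pts = shape["points"]
            xs = [p[0] for p in pts]
            ys = [p[1] for p in pts]
            x0, y0 = min(xs), min(ys)
            annotations.append({
                "id": ann_id,
                "image_id": img_id,
                "category_id": cat_ids[label],
                "segmentation": [[c for p in pts for c in p]],  # x1, y1, x2, y2, ...
                "area": polygon_area(pts),
                "bbox": [x0, y0, max(xs) - x0, max(ys) - y0],  # [x, y, w, h]
                "iscrowd": 0,
            })
            ann_id += 1
    return {"images": images, "categories": categories, "annotations": annotations}


if __name__ == "__main__":
    doc = {"imagePath": "a.jpg", "imageHeight": 100, "imageWidth": 200,
           "shapes": [{"label": "cat", "shape_type": "polygon",
                       "points": [[10, 10], [30, 10], [30, 20], [10, 20]]}]}
    coco = labelme_to_coco([doc])
    print(json.dumps(coco["annotations"][0]["bbox"]))  # [10, 10, 20, 10]
```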
# -*- coding:utf-8 -*-
import os, sys
import argparse
import json
import matplotlib.pyplot as plt
import skimage.io as io
from labelme import utils
import numpy as np
import glob
import PIL.Image


class MyEncoder(json.JSONEncoder):
    # Make NumPy scalars and arrays JSON-serializable
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(MyEncoder, self).default(obj)


class labelme2coco(object):
    def __init__(self, labelme_json=None, save_json_path='./tran.json'):
        '''
        :param labelme_json: list of paths to all labelme json files
        :param save_json_path: where to save the output json
        '''
        # Avoid a shared mutable default argument
        self.labelme_json = labelme_json if labelme_json is not None else []
        self.save_json_path = save_json_path
        self.images = []
        self.categories = []
        self.annotations = []
        # self.data_coco = {}
        self.label = []
        self.annID = 1
        self.height = 0
        self.width = 0
        self.save_json()

    def data_transfer(self):
        for num, json_file in enumerate(self.labelme_json):
            with open(json_