二、Mask R-CNN代码解读
FAIR在发布detectron的同时,也发布了一系列的tutorial文件,接下来将根据Detectron/GETTING_STARTED.md文件来解读代码。
先来看一下detectron的文件结构。
- config中是训练和测试的配置文件,官方的baseline的参数配置都以.yaml文件的形式存放在其中。
- demo一些图像示例还有分割好的结果。
- detectron核心程序,包括参数配置、数据集准备、模型、训练和测试的一些工具,都存放在其中。
- tools运行模型时调用的工具,包括推断、训练、测试等。
使用预训练好的模型进行推断
1.目录中的图片
推断目录中的图片使用tools/infer_simple.py工具,命令如下:
python2 tools/infer_simple.py \
--cfg configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml \
--output-dir /tmp/detectron-visualizations \
--image-ext jpg \
--wts https://s3-us-west-2.amazonaws.com/detectron/35861858/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml.02_32_51.SgT4y1cO/output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl \
demo
--cfg是之前提到的配置文件,detectron在运行程序时首先导入存放在core/config.py的所有参数的默认值,然后在调用函数merge_cfg_from_file(args.cfg),将--cfg参数引用的配置文件中存放的参数将默认值替换。举个例子,在config.py中关于数据集中的类别数有默认的定义:
# Number of classes in the dataset; must be set
# E.g., 81 for COCO (80 foreground + 1 background)
__C.MODEL.NUM_CLASSES = -1
这显然是一个默认值,需要我们在--cfg的配置文件中重新设置。故在configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml中有:
MODEL:
TYPE: generalized_rcnn
CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
NUM_CLASSES: 81
FASTER_RCNN: True
MASK_ON: True
merge_cfg_from_file()函数的功能就是将cfg文件中的MODEL.NUM_CLASSES的值(=81)替换config.py中的__C.MODEL.NUM_CLASSES(=-1)。
--image-ext是输出图像的后缀。
--wts是模型的参数文件,其实也就意味着是训练好可以拿来直接使用的模型。这里给的是一个地址,是官方训练好上传到亚马逊云上的模型,因为这样下载会很慢,所以也可以提前下载好(用迅雷)存放在本地,将--wts参数替换为本地的地址。在运行程序中,会检测--wts后的参数是网址还是地址,自动调取模型文件。
接下来就来看infer_simple.py的代码。
if __name__ == '__main__':
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0']) # 对工作区的全局初始化
setup_logging(__name__) # 日志设置
args = parse_args() # 参数读取
main(args)
对于日志的设置,要事先导入logging模块。
from detectron.utils.logging import setup_logging
setup_logging定义如下:
def setup_logging(name):
FORMAT = '%(levelname)s %(filename)s:%(lineno)4d: %(message)s'
# %(levelname)s: 打印日志级别名称
# %(filename)s: 打印当前执行程序名
# %(lineno)d: 打印日志的当前行号
# %(message)s: 打印日志信息
# Manually clear root loggers to prevent any module that may have called
# logging.basicConfig() from blocking our logging setup
logging.root.handlers = []
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(name)
return logger
读取参数:
# 命令行解析模块
def parse_args():
# 创建解析器
# 创建一个ArgumentParser实例,ArgumentParser的参数都为关键字参数
# description :help信息前显示的信息
parser = argparse.ArgumentParser(description='End-to-end inference')
# 添加参数选项 :add_argument
# name or flags :参数有两种,可选参数和位置参数
# dest :参数名
# default :默认值
# type :参数类型,默认为str
parser.add_argument(
'--cfg',
dest='cfg',
help='cfg model file (/path/to/model_config.yaml)',
default=None,
type=str
)
parser.add_argument(
'--wts',
dest='weights',
help='weights model file (/path/to/model_weights.pkl)',
default=None,
type=str
)
parser.add_argument(
'--output-dir',
dest='output_dir',
help='directory for visualization pdfs (default: /tmp/infer_simple)',
default='/tmp/infer_simple',
type=str
)
parser.add_argument(
'--image-ext',
dest='image_ext',
help='image file name extension (default: jpg)',
default='jpg',
type=str
)
parser.add_argument(
'--always-out',
dest='out_when_no_box',
help='output image even when no object is found',
# action