Mask R-CNN代码分析(二)

二、Mask R-CNN代码解读

FAIR在发布detectron的同时,也发布了一系列的tutorial文件,接下来将根据Detectron/GETTING_STARTED.md文件来解读代码。

先来看一下detectron的文件结构。

  • config中是训练和测试的配置文件,官方的baseline的参数配置都以.yaml文件的形式存放在其中。
  • demo一些图像示例还有分割好的结果。
  • detectron核心程序,包括参数配置、数据集准备、模型、训练和测试的一些工具,都存放在其中。
  • tools运行模型时调用的工具,包括推断、训练、测试等。

使用预训练好的模型进行推断

1.目录中的图片

推断目录中的图片使用tools/infer_simple.py工具,命令如下:

python2 tools/infer_simple.py \
    --cfg configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml \
    --output-dir /tmp/detectron-visualizations \
    --image-ext jpg \
    --wts https://s3-us-west-2.amazonaws.com/detectron/35861858/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml.02_32_51.SgT4y1cO/output/train/coco_2014_train:coco_2014_valminusminival/generalized_rcnn/model_final.pkl \
    demo

--cfg是之前提到的配置文件,detectron在运行程序时首先导入存放在core/config.py的所有参数的默认值,然后在调用函数merge_cfg_from_file(args.cfg),将--cfg参数引用的配置文件中存放的参数将默认值替换。举个例子,在config.py中关于数据集中的类别数有默认的定义:

# Number of classes in the dataset; must be set
# E.g., 81 for COCO (80 foreground + 1 background)
__C.MODEL.NUM_CLASSES = -1

 这显然是一个默认值,需要我们在--cfg的配置文件中重新设置。故在configs/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_2x.yaml中有:

MODEL:
  TYPE: generalized_rcnn
  CONV_BODY: FPN.add_fpn_ResNet101_conv5_body
  NUM_CLASSES: 81
  FASTER_RCNN: True
  MASK_ON: True

merge_cfg_from_file()函数的功能就是将cfg文件中的MODEL.NUM_CLASSES的值(=81)替换config.py中的__C.MODEL.NUM_CLASSES(=-1)。

--image-ext是输出图像的后缀。

--wts是模型的参数文件,其实也就意味着是训练好可以拿来直接使用的模型。这里给的是一个地址,是官方训练好上传到亚马逊云上的模型,因为这样下载会很慢,所以也可以提前下载好(用迅雷)存放在本地,将--wts参数替换为本地的地址。在运行程序中,会检测--wts后的参数是网址还是地址,自动调取模型文件。

接下来就来看infer_simple.py的代码。

if __name__ == '__main__':
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0']) # 对工作区的全局初始化
    setup_logging(__name__) # 日志设置
    args = parse_args() # 参数读取
    main(args)

对于日志的设置,要事先导入logging模块。

from detectron.utils.logging import setup_logging

 setup_logging定义如下:

def setup_logging(name):
    FORMAT = '%(levelname)s %(filename)s:%(lineno)4d: %(message)s'
    # %(levelname)s: 打印日志级别名称
    # %(filename)s: 打印当前执行程序名
    # %(lineno)d: 打印日志的当前行号
    # %(message)s: 打印日志信息

    # Manually clear root loggers to prevent any module that may have called
    # logging.basicConfig() from blocking our logging setup
    logging.root.handlers = []
    logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
    logger = logging.getLogger(name)
    return logger

读取参数:

# 命令行解析模块
def parse_args():
    # 创建解析器
    # 创建一个ArgumentParser实例,ArgumentParser的参数都为关键字参数
    # description :help信息前显示的信息
    parser = argparse.ArgumentParser(description='End-to-end inference')
    # 添加参数选项 :add_argument
    # name or flags :参数有两种,可选参数和位置参数
    # dest :参数名
    # default :默认值
    # type :参数类型,默认为str
    parser.add_argument(
        '--cfg',
        dest='cfg',
        help='cfg model file (/path/to/model_config.yaml)',
        default=None,
        type=str
    )
    parser.add_argument(
        '--wts',
        dest='weights',
        help='weights model file (/path/to/model_weights.pkl)',
        default=None,
        type=str
    )
    parser.add_argument(
        '--output-dir',
        dest='output_dir',
        help='directory for visualization pdfs (default: /tmp/infer_simple)',
        default='/tmp/infer_simple',
        type=str
    )
    parser.add_argument(
        '--image-ext',
        dest='image_ext',
        help='image file name extension (default: jpg)',
        default='jpg',
        type=str
    )
    parser.add_argument(
        '--always-out',
        dest='out_when_no_box',
        help='output image even when no object is found',
        # action
  • 2
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值