Detectron研读和实践四:用Mask R-CNN进行服饰关键点定位

前段时间参加了阿里天池的FashionAI服饰关键点定位比赛,为了做比赛,博主尝试用Detectron里面的Mask R-CNN去做关键点定位,取得了一定效果,也算是对Detectron的一些实践,特此做一些记录,希望对需要的朋友有所帮助。


1 Mask R-CNN简介

论文链接Mask R-CNN
这里写图片描述

Mask R-CNN是何凯明等人在Faster R-CNN基础上提出的一个优秀的目标实例分割模型。该模型能够有效地检测图像中的目标并为每个实例生成高质量的分割掩码。如下图所示,该模型通过在Faster R-CNN已存在的bbox识别分支旁并行地添加一个用于预测目标掩码的分支。掩码分支是一个应用到每个RoI上的小型FCN(全卷积网络),能够预测RoI中每个像素所属的类别,从而实现准确的实例分割。

这里写图片描述

Mask R-CNN的技术要点主要有三点:

  • 在原有的Faster R-CNN基础上采用FCN结构添加了一个并行的mask分支用于生成目标的掩模。
  • 由于实例分割对目标掩模的位置精度要求高,为了解决RoIPooling存在的misalignment问题(坐标映射和划分RoI bins时存在量化误差,导致实际的RoI不能严格对齐真实的RoI),作者引入了RoIAlign操作,使用双线性差值(在两个方向分别进行一次线性插值)计算出在每个RoI bin的四个规则采样位置的输入特征的精确浮点值,然后使用max或average整合结果。详细解释可参考详解 ROI Align 的基本原理和实现细节
  • mask分支的损失函数使用的是基于sigmoid的平均二值交叉熵损失函数,而不是FCN中常用的基于softmax的多类交叉熵损失,这样允许网络为每一类输出对应的mask,避免类间竞争,使得mask与类别预测解耦分离。

关于Mask R-CNN的详细解读可以参见这篇博客

Mask R-CNN训练简单,速度与Faster R-CNN相当,而且可以很容易的推广到其它与实例水平的识别相关的任务中,如目标检测和人体姿态估计。比如Mask R-CNN用于人体姿态估计时,主要是通过定位人体关键点来实现的,人体关键点在Mask R-CNN中可以被视作单个像素的mask。在第二部分,我将简单地介绍如何用Detectron的Mask R-CNN模型进行服饰关键点的定位。

2 用Mask R-CNN进行服饰关键点定位

2.1 准备数据集

该系列上一篇博客中介绍的准备数据集的方法类似,不同之处是这次要自己做COCO json格式的关键点annotation文件了,COCO json数据格式详见官网

COCO用于目标检测的annotation文件格式大致是下面这样的:

annotation{
"id" : int, 
"image_id" : int, 
"category_id" : int, 
"segmentation" : RLE or [polygon], 
"area" : float, 
"bbox" : [x,y,width,height], 
"iscrowd" : 0 or 1,
}

categories[{
"id" : int, "name" : str, "supercategory" : str,
}]

而关键点检测的annotation格式则是在上面格式的annotation字段和categories字段分别增加了一下内容:

annotation{
"keypoints" : [x1,y1,v1,...], #每个关键点以三元组x,y,v表示,v标识了关键点的存在性和可见性,为0表示不存在,为1表示不可见,为2表示可见
"num_keypoints" : int, 
"[cloned]" : ...,
}

categories[{
"keypoints" : [str], 
"skeleton" : [edge],  #由关键点按一定的顺序连接而成
"[cloned]" : ...,
}]

"[cloned]": denotes fields copied from object detection annotations defined above.

再来看一看COCO数据集annotation文件到底是个啥样子

# person_keypoint_val2014.json内容
# 字典的5个字段
info
licenses
images
annotations
categories
#各类字段的长度和对应的第一个元素
#info
6 {u'description': u'COCO 2014 Dataset', u'url': u'http://cocodataset.org', u'version': u'1.0', u'year': 2014, u'contributor': u'COCO Consortium', u'date_created': u'2017/09/01'}
#license
8 {u'url': u'http://creativecommons.org/licenses/by-nc-sa/2.0/', u'id': 1, u'name': u'Attribution-NonCommercial-ShareAlike License'}
#images
40504 {u'license': 3, u'file_name': u'COCO_val2014_000000391895.jpg', u'coco_url': u'http://images.cocodataset.org/val2014/COCO_val2014_000000391895.jpg', u'height': 360, u'width': 640, u'date_captured': u'2013-11-14 11:18:45', u'flickr_url': u'http://farm9.staticflickr.com/8186/8119368305_4e622c8349_z.jpg', u'id': 391895}
#categories
1 {u'supercategory': u'person', u'id': 1, u'name': u'person', u'keypoints': [u'nose', u'left_eye', u'right_eye', u'left_ear', u'right_ear', u'left_shoulder', u'right_shoulder', u'left_elbow', u'right_elbow', u'left_wrist', u'right_wrist', u'left_hip', u'right_hip', u'left_knee', u'right_knee', u'left_ankle', u'right_ankle'], u'skeleton': [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]}
#annotations
88153 {u'segmentation': [[267.03, 243.78, 314.59, 154.05, 357.84, 136.76, 374.05, 104.32, 410.81, 110.81, 429.19, 131.35, 420.54, 165.95, 451.89, 209.19, 464.86, 240.54, 480, 253.51, 484.32, 263.24, 496.22, 271.89, 484.32, 278.38, 438.92, 257.84, 401.08, 216.76, 370.81, 247.03, 414.05, 277.3, 433.51, 304.32, 443.24, 323.78, 400, 362.7, 376.22, 375.68, 400, 418.92, 394.59, 424.32, 337.3, 382.16, 337.3, 371.35, 388.11, 327.03, 341.62, 301.08, 311.35, 276.22, 304.86, 263.24, 294.05, 249.19]], u'num_keypoints': 8, u'area': 28292.08625, u'iscrowd': 0, u'keypoints': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 325, 160, 2, 398, 177, 2, 0, 0, 0, 437, 238, 2, 0, 0, 0, 477, 270, 2, 287, 255, 1, 339, 267, 2, 0, 0, 0, 423, 314, 2, 0, 0, 0, 355, 367, 2], u'image_id': 537548, u'bbox': [267.03, 104.32, 229.19, 320], u'category_id': 1, u'id': 183020}

知道了关键点的annotation文件格式和具体内容就可以参照着制作自己的annotation文件了,其中关键的三个字段是images,annotations和categories,另外两个字段并不是很要紧。不过要注意的一点是由于Detectron中目前只支持人体关键点检测,因此我们用自己的数据集做的annotation文件中的categories字段的supercategory和name的值都要是person,否则会报错。

下面是我做好的服饰关键点annotation文件的大致样子

info
licenses
images
annotations
categories
2 {u'url': u'https://tianchi.aliyun.com/competition/introduction.htm?spm=5176.100067.5678.1.510044a9T425c6&raceId=231648', u'description': u'FashionAI Dataset'}
1 {u'url': u'http://creativecommons.org/licenses/by-nc-nd/2.0/', u'id': 3, u'name': u'Attribution-NonCommercial-NoDerivs License'}
2292 {u'file_name': u'Images/skirt/bd969a01fe65b95f2736e22d76b214d7.jpg', u'height': 512, u'id': 2997, u'license': 3, u'width': 512}
1 {u'supercategory': u'person', u'id': 1, u'keypoints': [u'waistband_left', u'waistband_right', u'hemline_left', u'hemline_right'], u'name': u'person'}
2292 {u'segmentation': [[41, 137, 41, 288, 499, 288, 499, 137]], u'num_keypoints': 4, u'area': 69158, u'iscrowd': 0, u'keypoints': [205, 155, 2, 324, 142, 2, 46, 283, 2, 494, 266, 2], u'ignore': 0, u'image_id': 2997, u'bbox': [41, 137, 458, 151], u'category_id': 1, u'id': 2997}

2.2 修改训练模型的配置文件

可以参照Detectron/configs/12_2017_baselines中的e2e_keypoint_rcnn_R-50-FPN_1x.yaml配置文件进行相应的修改,主要是增加KRCNN的配置

KRCNN:
  ROI_KEYPOINTS_HEAD: keypoint_rcnn_heads.add_roi_pose_head_v1convX
  NUM_STACKED_CONVS: 8
  NUM_KEYPOINTS: 4 # 把此处改成自己数据集中需要检测的关键点个数
  USE_DECONV_OUTPUT: True
  CONV_INIT: MSRAFill
  CONV_HEAD_DIM: 512
  UP_SCALE: 2
  HEATMAP_SIZE: 56  # ROI_XFORM_RESOLUTION (14) * UP_SCALE (2) * USE_DECONV_OUTPUT (2)
  ROI_XFORM_METHOD: RoIAlign
  ROI_XFORM_RESOLUTION: 14
  ROI_XFORM_SAMPLING_RATIO: 2
  KEYPOINT_CONFIDENCE: bbox

2.3 修改keypoint.py和json_dataset.py文件

  • 把Detectron/lib/datasets/json_dataset.py文件中303行到311行的self.keypoint_flip_map字典改成自己的关键点映射对
self.keypoint_flip_map={u'waistband_left':u'waistband_right', u'hemline_left':u'hemline_right'}
  • 把Detectron/lib/utils/keypoints.py文件中34行到62行的keypoints和keypoint_flip_map也相应改过来
keypoints = [u'waistband_left', u'waistband_right', u'hemline_left', u'hemline_right']
keypoint_flip_map={u'waistband_left':u'waistband_right', u'hemline_left':u'hemline_right'}

2.4 训练模型

执行train_net.py文件训练模型(相关参数自己按需提供,比如–cfg,OUTPUT_DIR等)

2.5 基于训练的模型执行推断

这里因为我的关键点annotation文件没有提供skeleton,没办法直接用Detectron的infer_simple.py进行推断,就在infer_simple.py文件中仿照main函数自己写了个函数,将推断结果(关键点坐标)输出为csv格式的文件,主要就是利用main函数里面的infer_engine.im_detect_all()函数拿到推断的结果。

def write_infer_kpts(args, file_name):
    """write infer result of FashionAI keypoints to csv file"""
    logger = logging.getLogger(__name__)
    merge_cfg_from_file(args.cfg)
    cfg.TEST.WEIGHTS = args.weights
    cfg.NUM_GPUS = 1
    assert_and_infer_cfg()
    model = infer_engine.initialize_model_from_cfg()
    bbox_infer = []
    csv_head = ['image_id', 'image_category', 'xmin','ymin','xmax','ymax',               'waistband_left','waistband_right','hemline_left','hemline_right'
]
    input_data = pd.read_csv(args.input_data).values 
    im_list = map(lambda x: args.im_or_folder+x, list(input_data[:,0]))

    for i, im_name in enumerate(im_list):
        logger.info('Processing {}'.format(im_name))
        im = cv2.imread(im_name)
        timers = defaultdict(Timer)
        t = time.time()
        with c2_utils.NamedCudaScope(0):
            cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
                model, im, None, timers=timers
            )
        logger.info('Inference time: {:.3f}s'.format(time.time() - t))
        for k, v in timers.items():
            logger.info(' | {}: {:.3f}s'.format(k, v.average_time))
        if i == 0:
            logger.info(
                ' \ Note: inference on the first image will be slower than the '
                'rest (caches and auto-tuning need to warm up)'
            )
        im_id = im_name.split('/')[-3]+'/'+im_name.split('/')[-2]+'/'+im_name.split('/')[-1]
        im_label = im_name.split('/')[-2]
        cls_keyps = np.array(cls_keyps[1])
        cls_boxes = np.array(cls_boxes[1])# cls_boxes[0] represents background
        print(i, cls_keyps.shape, cls_boxes.shape)
#        print('cls_keyps:',cls_keyps)
        kpts_num = len(csv_head[6:])
#        item_list = [im_id, im_label] 

        if len(cls_boxes[0])>0:
                idx = np.argsort(cls_boxes[:,4])[-1] 
                xmin, ymin, xmax, ymax = map(lambda x: int(round(x)), cls_boxes[idx][:4])      
        item_list = [im_id, im_label, xmin, ymin, xmax, ymax]
        if cls_keyps.shape[0]>0:
            for i in range(kpts_num):
                idx = np.argsort(cls_keyps[:,3,i])[-1] 
                kpt_x, kpt_y = map(lambda x: int(round(x)), list(cls_keyps[idx,:2,i]))
                item_list.append(str(kpt_x)+'_'+str(kpt_y)+'_1')       
        bbox_infer.append(item_list)
    df = pd.DataFrame(bbox_infer, columns = csv_head)
    df.to_csv(args.output_dir+'/'+file_name, mode='w')

2.6 服饰关键点定位结果
这里写图片描述
同时比较准确的预测出了skirt的bounding box和四角的关键点。

用Mask R-CNN训练自己的数据集进行关键点定位大致过程就是这样,希望对需要的朋友有所帮助。

  • 5
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 26
    评论
评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值