本文将介绍发表在ECCV 2020的一篇基于图模型的多人姿态估计方法,作者来自香港大学、商汤科技、南京大学和悉尼大学。
论文链接: https://arxiv.org/abs/2007.11864v1
代码链接: 尚未公开
主要思想:
现有的多人姿态估计模型一般分为一般分为两大类:top-down和bottom-up方法。Top-down的模型先对输入的图像进行目标检测,检测出图像中每个人的bounding box之后,通过单人姿态估计模型对每个人的姿态进行检测。而bottom-up的方法则是先通过关键点检测模型检测出图像中所有人的所有关键点,然后对其进行聚类分组等操作,将检测出的关键点与每个人对应起来。本文提出了一个基于图模型的bottom-up方法,即通过一个可训练的层级图结构(HGG,Hierarchical Graph Grouping)来对前一阶段检测到的关键点进行分组。在训练阶段,HGG部分与关键点检测部分是端到端一起训练的。
模型结构
如上图所示,本文的模型分为关键点提取(Keypoint Candidate Proposal)和基于图的关键点聚类(Hierarchical Graph Grouping
)两部分。其中关键点提取部分采用的是四组堆叠的沙漏网络结构,输出为关键点的Heatmap和点对关系特征图。而用于关键点分组的Hierarchical Graph Grouping部分就是本文的主要创新点,其主要结构如下图所示:
HGG部分主要由三部分组成:
- GNN用来学习Heatmaps中每个节点之间的边(Edge)连接;
- Edge Discriminator用来判断每个边的两个节点是否应该连接(输出0或1),两个点来自同一个人时输出1,来自不同人时输出0;
- Macro-Node Discriminator用来监督每次迭代的Group结果,当每个Group中所有节点都属于同一个人时输出1,否则输出0.
如下面伪代码所示,对关键点的Group是通过迭代进行的,每次迭代由四步操作构成:Graph feature aggregation, Edge proximity update, Node clustering, Graph pruning.
模型训练
损失函数:
模型的损失函数主要由三部分组成:
- Keypoint detection loss:监督训练Heatmap的生成;
- Pairwise pull/push loss:监督训练点对关系特征图的生成,将来自不同人的两个nodes关系尽可能远,同一个人的两个nodes关系尽可能近;
- BCE loss:监督训练两个判别器——Edge Discriminator和Macro-node Discriminator。
训练细节:
优化算法:Adam, lr=2e-4;
Input size:512512;
Output Size:128128