姿态估计Bottom-up系列中的Grouping方式一:Associative Embedding

原文:Associative Embedding: End-to-End Learning for Joint Detection and Grouping
代码:princeton-vl/pose-ae-train


Abstract

文章提出了一种全新的用于detection和grouping的监督方法,Associate Embedding。意外的多人姿态估计、实例分割、多目标跟踪问题都采用了two-stages,先detect后group的方式。文章提出的associate embedding 同时输出detection 和 grouping的结果,并在多人姿态估计任务中的MPII和COCO数据集上达到了SOTA。


Introduction

Associate embedding的基本思想是为每一个detection都分配一个vector作为tag来指示聚类的分配。所以拥有同一个tag的detections聚类成一个集合,构建出一个人体实例。对于每个实例有m个关键点的多人姿态估计任务,网络共输出m张关键点检测的heatmap,和m个tag图,每个tag图分别来指示对于某一关节来说,每个像素分别应该属于哪个人体实例。

具体的解码方式是:先从heatmap上提取出关键点所在的位置,去对于的tag图中找到tag值,把tag值相近的关键点聚类成一个人体实例。

网络训练时,Loss函数只需要鼓励同一个人体实例的不同关节点的对应的tag值相近,而不同的人体实例的关节点tag值相异即可。只需要学习tag应该相同或不同,网络可以自由学习具体的绝对的标签是什么。(用0和1可以区别出来,2和3也可以区别出来,效果都是一样的,网络自己决定用哪组数字来区别,所以对于tag也没有具体的GT。)


Related Work


Approach

Network Architecture

用了Hourglass的网络结构做了一个改进,为了多人姿态,提高了hourglass各层的分辨率,每个网络中的residual modules换成了 3 × 3 3\times 3 3×3 卷积。

Detection and Grouping

Detection的部分和单人姿态类似,为m类关键点分别生成m张heatmap,为不同人的同类关节点在用一张heatmap上无差别的表示,所以一张理想的heatmap是为每个人的这类关节点都有一个不同的peak。训练时GT用2D高斯来构建peak,loss函数用了MSE。

Grouping的部分,也是本文的关键创新点。假设已经从detection的部分获得了共m类关键点的位置,每类关键点都有好几个,需要把这些关键点进行grouping,构建出完整的人体实例。网络在detection的同时额外生成了一个embedding来指示每个像素应该属于哪个人体实例。(对于不同类的实例来说都有一个不同的embedding)。文章提出embedding的纬度并不重要,高维如果能实现,那低维一定也可以,比较只是用来区别人体实例,所以文章里用了1D的embedding。

tag指示了每个检测到的关节点属于哪一个人体实例。每个detection heatmap都各自对应的tag map。因此对于m各人体关节,用m 个 detection heatmap用来检测 和 m 个 tag map用来grouping。为了把每个检测到关节匹配到不同的人体实例,文章首先检索出每个检测出的关键点的峰值处像素所对应的tag map中的tag值,通过比较各个关节点的值来聚类 组合人体。

Grouping 的 loss 函数评估预测出的tag能是否个GT的聚类结果相同。具体地,我们检索每个人体节点的GT位置的对应的tag map中预测出的tag值(并不是预测出的关节点位置对应的tag值)

每个人体各关节点的GT位置的tag求均值作为改人体的reference embedding:
在这里插入图片描述
之后根据同意人体的tag相近,不同人体的reference embedding相异的原则,计算loss:
在这里插入图片描述

Parsing Network Output

为了生成最后的一系列检测结果,我们一个一个关节点遍历。顺序从头和躯干逐渐移动至关节。假设我们首先从脖子开始,组成我们的初始的人体实例pool。对于下一个关节点,
我们挑选出与人体pool最匹配的关节点。每个关节点由它的分数与tag组成,每个人体的reference embedding 由当前关节点的tag均值决定。

我们比较这些embedding之间的距离,我们贪婪的分配 在embedding距离之内的 响应最高的 关节点。如果新的joint没有任何匹配的人体(和任何tag都不相近),把这个joint作为一个新的人体实例。知道所有的joint都分配完成。

代码实现部分:
首先在k个heatmap上得到最多m个detection,在提取出对应的tag值,得到:

ans = {
    
	tag_k:
	loc_k:
	val_k:
}

生成权重矩阵,使用KM算法找到二分图最佳匹配。(按关节点顺序遍历)

对于Missing Joints 我们在评估阶段需要确保每个人体实例都有完整的所有关节点,我们挑选出所有tag相近的位置,在其中挑选出响应最高的点作为其关节点。这个位置在之前的detection阶段可能并没有达到被detect出的阀值。

对于Multiscale Evaluation我们对多尺度的heatmaps在resize之后逐元素取均值,对于tag maps,我们resize之后对m个scale的tag maps进行逐元素的concat,tag作为m维的vector。

HigherHRNet 代码中的associate embedding部分。

# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (leoxiaobin@gmail.com)
# Modified by Bowen Cheng (bcheng9@illinois.edu)
# ------------------------------------------------------------------------------


import torch
import torch.nn as nn

class AELoss(nn.Module):
    def __init__(self, loss_type):
        super().__init__()
        self.loss_type = loss_type

    def singleTagLoss(self, pred_tag, joints):
        """
        associative embedding loss for one image
        """
        tags = []
        pull = 0
        for joints_per_person in joints:
            tmp = []
            for joint in joints_per_person:
                if joint[1] > 0:
                    tmp.append(pred_tag[joint[0]])
            if len(tmp) == 0:
                continue
            tmp = torch.stack(tmp)
            tags.append(torch.mean(tmp, dim=0))
            pull = pull + torch.mean((tmp - tags[-1].expand_as(tmp))**2)

        num_tags = len(tags)
        if num_tags == 0:
            return make_input(torch.zeros(1).float()), \
                make_input(torch.zeros(1).float())
        elif num_tags == 1:
            return make_input(torch.zeros(1).float()), \
                pull/(num_tags)

        tags = torch.stack(tags)

        size = (num_tags, num_tags)
        A = tags.expand(*size)
        B = A.permute(1, 0)

        diff = A - B

        if self.loss_type == 'exp':
            diff = torch.pow(diff, 2)
            push = torch.exp(-diff)
            push = torch.sum(push) - num_tags
        elif self.loss_type == 'max':
            diff = 1 - torch.abs(diff)
            push = torch.clamp(diff, min=0).sum() - num_tags
        else:
            raise ValueError('Unkown ae loss type')

        return push/((num_tags - 1) * num_tags) * 0.5, \
            pull/(num_tags)

    def forward(self, tags, joints):
        """
        accumulate the tag loss for each image in the batch
        """
        pushes, pulls = [], 
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值