Sparse Local Patch Transformer: Predicting Facial Landmark Coordinates and Their Inherent Relations [CVPR 2022]

Sparse Local Patch Transformer for Robust Face Alignment and Landmarks Inherent Relation Learning [CVPR 2022]


paper: https://openaccess.thecvf.com/content/CVPR2022/papers/Xia_Sparse_Local_Patch_Transformer_for_Robust_Face_Alignment_and_Landmarks_CVPR_2022_paper.pdf
code: https://github.com/Jiahao-UTS/SLPT-master


Motivation

Current face alignment methods already achieve good accuracy, but large poses, heavy occlusion, and illumination changes are still handled poorly. A face has a regular structure, i.e., inherent relations among the facial landmarks, and these relations play an important role in face alignment. Although heatmap regression has dominated face alignment in recent years, heatmap-based methods have two drawbacks (from the paper, plus my own summary):

  1. Heatmap regression predicts a separate heatmap for each landmark, so the inherent relations among landmarks are lost at prediction time;
  2. Heatmaps are usually converted to coordinates by taking the argmax of the peak; since this step is non-differentiable, it inevitably introduces quantization error, as the sketch below illustrates.
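
To make point 2 concrete, here is a minimal sketch (mine, not the paper's) of the standard argmax decoding step; the heatmap size and the sub-pixel ground-truth location are illustrative assumptions:

import torch

# Hypothetical 64x64 heatmap whose true landmark lies at a sub-pixel
# location; the peak can only fall on an integer grid cell.
heatmap = torch.zeros(64, 64)
true_xy = torch.tensor([31.6, 42.3])       # assumed sub-pixel ground truth
heatmap[42, 32] = 1.0                      # peak lands on the nearest grid cell

idx = torch.argmax(heatmap)                # non-differentiable step
pred_y, pred_x = idx // 64, idx % 64       # flat index back to 2D integers
pred_xy = torch.tensor([float(pred_x), float(pred_y)])

print(pred_xy - true_xy)                   # tensor([ 0.4000, -0.3000]): quantization error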

Coordinate regression methods have an inherent potential to learn these relations: the final step usually feeds a global feature into a fully connected layer and predicts all landmarks simultaneously, so the relations among landmarks can be taken into account. The authors argue, however, that the latent relations should be learned together with local features, whereas current regression methods feed the global feature directly into the FC layer and thereby lose local appearance information (see the sketch after this paragraph).
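
For contrast, a minimal sketch of the plain coordinate-regression head described above (dimensions are illustrative): one fully connected layer maps the pooled global feature to all landmarks at once, so landmarks are predicted jointly, but the pooling discards the local appearance around each landmark:

import torch
from torch import nn

num_points, feat_dim = 98, 2048            # illustrative dimensions
head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),               # (bs, C, H, W) -> (bs, C, 1, 1)
    nn.Flatten(),                          # -> (bs, C): local detail is gone
    nn.Linear(feat_dim, num_points * 2),   # all landmarks predicted jointly
)

feature_map = torch.randn(4, feat_dim, 8, 8)   # dummy backbone output
coords = head(feature_map).view(4, num_points, 2)
print(coords.shape)                            # torch.Size([4, 98, 2])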

This paper therefore proposes the Sparse Local Patch Transformer (SLPT) to learn the inherent relations among landmarks. In addition, a coarse-to-fine framework is combined with SLPT, so that the initial landmarks gradually converge to the target facial landmarks using the fine-grained features of dynamically resized local patches.

Overall framework

The model framework is a coarse-to-fine process, split into three stages in the code. A pretrained HRNet backbone first extracts the image feature map. Each stage takes the landmarks predicted by the previous stage as input (acting as a prior; the first stage uses the landmark coordinates of a canonical frontal face), crops a local patch around each landmark from the feature map, and passes it through a convolutional layer to produce one feature embedding per landmark. The Transformer in each stage outputs every landmark's coordinates within its local patch; adding these offsets, scaled by the patch size, to the patch's start anchor yields that stage's predicted landmarks.
The main model code is shown below, with some comments added:

import numpy as np
import torch
from torch import nn

# get_roi, interpolation_layer, get_face_alignment_net, and Transformer are
# defined elsewhere in the SLPT repository.

class Sparse_alignment_network(nn.Module):
    def __init__(self, num_point, d_model, trainable,
                 return_interm_layers, dilation, nhead,  feedforward_dim,
                 initial_path, cfg):
        super(Sparse_alignment_network, self).__init__()
        self.num_point = num_point
        self.d_model = d_model
        self.trainable = trainable
        self.return_interm_layers = return_interm_layers
        self.dilation = dilation
        self.nhead = nhead
        self.feedforward_dim = feedforward_dim
        self.initial_path = initial_path
        self.Sample_num = cfg.MODEL.SAMPLE_NUM

        # self.initial_points: landmark coordinates of a canonical frontal face, reshaped to (1, 98, 2)
        self.initial_points = torch.from_numpy(np.load(initial_path)['init_face'] / 256.0).view(1, num_point, 2).float()
        self.initial_points.requires_grad = False       

        # ROI creators, one per stage; the local patch size halves at each stage (8 -> 4 -> 2)
        self.ROI_1 = get_roi(self.Sample_num, 8.0, 64)
        self.ROI_2 = get_roi(self.Sample_num, 4.0, 64)
        self.ROI_3 = get_roi(self.Sample_num, 2.0, 64)

        self.interpolation = interpolation_layer()

        # feature extractor: one conv over each Sample_num x Sample_num patch yields a d_model vector per landmark
        self.feature_extractor = nn.Conv2d(d_model, d_model, kernel_size=self.Sample_num, bias=False)

        self.feature_norm = nn.LayerNorm(d_model)

        # Transformer
        self.Transformer = Transformer(num_point, d_model, nhead, cfg.TRANSFORMER.NUM_DECODER,
                                       feedforward_dim, dropout=0.1)

        self.out_layer = nn.Linear(d_model, 2)

        self._reset_parameters()

        # backbone
        self.backbone = get_face_alignment_net(cfg)

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, image):
        bs = image.size(0)

        output_list = []

        feature_map = self.backbone(image)  # extract the global feature map

        initial_landmarks = self.initial_points.repeat(bs, 1, 1).to(image.device)   # (bs, 98, 2)

        # stage_1 
        ROI_anchor_1, bbox_size_1, start_anchor_1 = self.ROI_1(initial_landmarks.detach())
        ROI_anchor_1 = ROI_anchor_1.view(bs, self.num_point * self.Sample_num * self.Sample_num, 2)
        # sample local features from the feature map via bilinear interpolation
        ROI_feature_1 = self.interpolation(feature_map, ROI_anchor_1.detach()).view(bs, self.num_point, self.Sample_num,
                                                                            self.Sample_num, self.d_model)
        ROI_feature_1 = ROI_feature_1.view(bs * self.num_point, self.Sample_num, self.Sample_num,
                                     self.d_model).permute(0, 3, 2, 1)

        # a single conv layer turns each landmark's local patch into a d_model embedding
        transformer_feature_1 = self.feature_extractor(ROI_feature_1).view(bs, self.num_point, self.d_model)

        offset_1 = self.Transformer(transformer_feature_1)
        offset_1 = self.out_layer(offset_1)

        # start anchor + patch-size-scaled offset = predicted landmark position
        landmarks_1 = start_anchor_1.unsqueeze(1) + bbox_size_1.unsqueeze(1) * offset_1
        output_list.append(landmarks_1)

        # stage_2
        ROI_anchor_2, bbox_size_2, start_anchor_2 = self.ROI_2(landmarks_1[:, -1, :, :].detach())
        ROI_anchor_2 = ROI_anchor_2.view(bs, self.num_point * self.Sample_num * self.Sample_num, 2)
        ROI_feature_2 = self.interpolation(feature_map, ROI_anchor_2.detach()).view(bs, self.num_point, self.Sample_num,
                                                                                 self.Sample_num, self.d_model)
        ROI_feature_2 = ROI_feature_2.view(bs * self.num_point, self.Sample_num, self.Sample_num,
                                           self.d_model).permute(0, 3, 2, 1)

        transformer_feature_2 = self.feature_extractor(ROI_feature_2).view(bs, self.num_point, self.d_model)

        offset_2 = self.Transformer(transformer_feature_2)
        offset_2 = self.out_layer(offset_2)

        landmarks_2 = start_anchor_2.unsqueeze(1) + bbox_size_2.unsqueeze(1) * offset_2
        output_list.append(landmarks_2)

        # stage_3
        ROI_anchor_3, bbox_size_3, start_anchor_3 = self.ROI_3(landmarks_2[:, -1, :, :].detach())
        ROI_anchor_3 = ROI_anchor_3.view(bs, self.num_point * self.Sample_num * self.Sample_num, 2)
        ROI_feature_3 = self.interpolation(feature_map, ROI_anchor_3.detach()).view(bs, self.num_point, self.Sample_num,
                                                                                    self.Sample_num, self.d_model)
        ROI_feature_3 = ROI_feature_3.view(bs * self.num_point, self.Sample_num, self.Sample_num,
                                           self.d_model).permute(0, 3, 2, 1)

        transformer_feature_3 = self.feature_extractor(ROI_feature_3).view(bs, self.num_point, self.d_model)

        offset_3 = self.Transformer(transformer_feature_3)
        offset_3 = self.out_layer(offset_3)

        landmarks_3 = start_anchor_3.unsqueeze(1) + bbox_size_3.unsqueeze(1) * offset_3
        output_list.append(landmarks_3)

        return output_list
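
get_roi and interpolation_layer are repository helpers not reproduced here. From how they are called, get_roi builds a Sample_num x Sample_num grid of anchor points around each input landmark (returning the grid, the patch size, and the patch's start anchor), and interpolation_layer bilinearly samples the feature map at those points. A rough sketch of the sampling step under that reading (not the authors' code), using torch.nn.functional.grid_sample:

import torch
import torch.nn.functional as F

def sample_features(feature_map, anchors):
    # feature_map: (bs, C, H, W); anchors: (bs, N, 2) with x, y in [0, 1].
    # Returns (bs, N, C) bilinearly interpolated feature vectors.
    grid = anchors * 2.0 - 1.0             # grid_sample expects [-1, 1]
    grid = grid.unsqueeze(2)               # (bs, N, 1, 2)
    sampled = F.grid_sample(feature_map, grid, align_corners=False)
    return sampled.squeeze(-1).permute(0, 2, 1)

fmap = torch.randn(2, 256, 64, 64)
anchors = torch.rand(2, 98 * 7 * 7, 2)     # assuming SAMPLE_NUM = 7
print(sample_features(fmap, anchors).shape)    # torch.Size([2, 4802, 256])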

Core of the Transformer (the Inherent_Layer)

Each patch embedding is passed into the Transformer as the src argument. The module defines two learnable parameters, structure_encoding and landmark_query, plus a zero-initialized target tensor tgt (created with torch.zeros_like(landmark_query) in the forward pass), and passes all of them into Transformer_block. Transformer_block is simply a stack of Inherent_Layer modules, so the main work happens inside Inherent_Layer, described below.

After src, tgt, structure_encoding, and landmark_query enter Inherent_Layer, two operations are applied: query-query attention and query-representation attention, which learn the interactions among the queries and the inherent relations among the landmarks, respectively. As the code shows, both are built on nn.MultiheadAttention and differ only in their q, k, and v inputs. The result then passes through an MLP (the feed-forward block) to produce the layer output.

import copy
import torch

from torch import nn, Tensor
from typing import Optional

from utils import _get_activation_fn


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class Inherent_Layer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()

        # Attention Layer
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # normalization & Dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):

        # query-query attention
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # query-representation attention
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # FFN
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):


        # query-query attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # query-representation attention
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        # FFN
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


class Transformer_block(nn.Module):

    def __init__(self, Transformer_block, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(Transformer_block, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output


class Transformer(nn.Module):
    def __init__(self, num_points, d_model=256, nhead=8, num_decoder_layer=6, dim_feedforward=1024,
                 dropout=0.1, activation="relu", normalize_before=True):
        super(Transformer, self).__init__()
        # Dim
        self.d_model = d_model
        # number of head
        self.nhead = nhead
        # structure encoding
        self.structure_encoding = nn.Parameter(torch.randn(1, num_points, d_model))
        # landmark query
        self.landmark_query = nn.Parameter(torch.randn(1, num_points, d_model))

        SLPT_Inherent_Layer = Inherent_Layer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.Transformer_block = Transformer_block(SLPT_Inherent_Layer, num_decoder_layer, decoder_norm, return_intermediate=True)

        self._reset_parameters()


    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src):
        bs, num_feat, len_feat = src.size()

        structure_encoding = self.structure_encoding.repeat(bs, 1, 1).permute(1, 0, 2)
        landmark_query = self.landmark_query.repeat(bs, 1, 1).permute(1, 0, 2)

        src = src.permute(1, 0, 2)

        tgt = torch.zeros_like(landmark_query)
        tgt = self.Transformer_block(tgt, src, pos=structure_encoding, query_pos=landmark_query)

        return tgt.permute(2, 0, 1, 3)
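
A quick shape check of the forward pass above (bs = 2 is arbitrary): because return_intermediate=True, Transformer_block stacks every decoder layer's output, and the final permute yields (bs, num_decoder_layer, num_points, d_model). This is why the main model indexes landmarks with [:, -1, :, :] to read the last layer's prediction:

model = Transformer(num_points=98, d_model=256, nhead=8, num_decoder_layer=6)
src = torch.randn(2, 98, 256)          # one embedding per landmark patch
out = model(src)
print(out.shape)                       # torch.Size([2, 6, 98, 256])
# out[:, -1] is the final decoder layer; out_layer in the main model then
# maps each 256-d vector to a 2-d offset inside its local patch.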