Sparse Local Patch Transformer for Robust Face Alignment and Landmarks Inherent Relation Learning[CVPR2022]
paper: https://openaccess.thecvf.com/content/CVPR2022/papers/Xia_Sparse_Local_Patch_Transformer_for_Robust_Face_Alignment_and_Landmarks_CVPR_2022_paper.pdf
code: https://github.com/Jiahao-UTS/SLPT-master
动机
目前人脸对齐的方法已经取得很好的精度,但是大姿态、重度遮挡、光照变化的情况仍不能被很好地处理。人脸的面部具有一个regular structure(规则结构),也就是面部地标之间的内在关系,这在人脸对齐中起着重要的作用。虽然近年来热力图回归方法占据了人脸对齐领域的主导地位,但是基于热力图的方法有两点缺陷(原文+自己总结):
- 热力图回归的方式由于是单独回归每个关键点的热力图,所以在预测时缺失了关键点之间的内在关系;
- 从热力图到关键点坐标值往往采用argmax来获取热力图中最大峰值坐标,由于这一步是不可微的,就不可避免地会引入量化误差。
坐标回归的方法具有学习关键点内在关系的固有潜力,因为最后一步在预测时往往是将全局特征输入至一个全连接层,然后对所有的关键点同时预测,这样可以考虑到关键点之间的内在关系。但是作者认为学习潜在关系应该与局部特征一起学习,但是目前回归的方法是直接将全局特征输入至全连接层中来预测关键点,这样就损失了局部外貌信息。
因此这篇文章提出了Sparse Local Patch Transformer (SLPT)方法来学习关键点之间的内在关系。此外,本文的网络还引入了一个从粗到细的框架与SLPT相结合,使初始地标利用动态调整大小的局部patch的细粒度特征逐渐收敛到目标面部地标。
总体框架
本文的模型框架如下,是一个从粗到细的过程,在模型代码中分成了三个阶段。首先使用HRNet预训练模型作为backbone提取图像特征,然后在每一个阶段将上一个阶段预测到的landmark作为输入(充当一个先验,在最开始阶段使用的是一个标准正面landmark坐标),将上一个阶段预测的landmark在feature map中提取出局部的patch,然后通过一个卷积层生成每个landmark的特征表示,在每一个阶段模型的Transformer的输出都是关键点在相应的local patch中的坐标,与原先的anchor坐标相加后,即得到了该阶段预测的landmark坐标。
模型的主要代码如下所示,添加了部分注释:
class Sparse_alignment_network(nn.Module):
    """Coarse-to-fine Sparse Local Patch Transformer (SLPT) for face alignment.

    A backbone (HRNet) extracts one global feature map; three refinement
    stages then each (1) crop a local patch around every current landmark
    estimate, (2) collapse each patch to a single embedding with a shared
    conv layer, and (3) run a Transformer whose per-landmark output is an
    offset inside that local patch.  The three stages share all modules and
    differ only in the ROI scale (8 -> 4 -> 2), so the prediction converges
    from a canonical frontal-face initialisation to the target landmarks.
    """

    def __init__(self, num_point, d_model, trainable,
                 return_interm_layers, dilation, nhead, feedforward_dim,
                 initial_path, cfg):
        super(Sparse_alignment_network, self).__init__()
        self.num_point = num_point          # number of landmarks (e.g. 98 for WFLW)
        self.d_model = d_model              # embedding dimension
        self.trainable = trainable
        self.return_interm_layers = return_interm_layers
        self.dilation = dilation
        self.nhead = nhead
        self.feedforward_dim = feedforward_dim
        self.initial_path = initial_path
        self.Sample_num = cfg.MODEL.SAMPLE_NUM  # patch sampling resolution S

        # Canonical frontal-face landmark coordinates, normalised to [0, 1];
        # used as the stage-1 prior.  Shape: (1, num_point, 2).
        # NOTE(review): plain attribute, not a registered buffer, so it does
        # not follow `.to(device)` on the module; forward() moves it explicitly.
        self.initial_points = torch.from_numpy(
            np.load(initial_path)['init_face'] / 256.0).view(1, num_point, 2).float()
        self.initial_points.requires_grad = False

        # ROI generators with shrinking patch sizes (coarse -> fine).
        self.ROI_1 = get_roi(self.Sample_num, 8.0, 64)
        self.ROI_2 = get_roi(self.Sample_num, 4.0, 64)
        self.ROI_3 = get_roi(self.Sample_num, 2.0, 64)
        self.interpolation = interpolation_layer()

        # Conv whose kernel equals Sample_num collapses each S x S patch to a
        # single d_model-dim embedding per landmark.
        self.feature_extractor = nn.Conv2d(d_model, d_model,
                                           kernel_size=self.Sample_num, bias=False)
        self.feature_norm = nn.LayerNorm(d_model)

        # Shared Transformer decoder over the per-landmark patch embeddings.
        self.Transformer = Transformer(num_point, d_model, nhead,
                                       cfg.TRANSFORMER.NUM_DECODER,
                                       feedforward_dim, dropout=0.1)
        self.out_layer = nn.Linear(d_model, 2)  # embedding -> (x, y) offset

        self._reset_parameters()

        # Backbone last: its (pretrained) weights are NOT touched by the
        # xavier re-initialisation above.
        self.backbone = get_face_alignment_net(cfg)

    def _reset_parameters(self):
        """Xavier-initialise every weight matrix (dim > 1 skips biases/norms)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _stage(self, feature_map, anchor_landmarks, roi):
        """Run one coarse-to-fine refinement stage.

        Args:
            feature_map: backbone output, shared by all stages.
            anchor_landmarks: (bs, num_point, 2) landmark prior for this stage
                (detached — gradients do not flow into earlier stages).
            roi: the stage's ROI generator (self.ROI_1/2/3).

        Returns:
            (bs, num_decoder_layers, num_point, 2) landmark predictions, one
            set per Transformer decoder layer (the Transformer returns all
            intermediate decoder outputs stacked).
        """
        bs = feature_map.size(0)
        ROI_anchor, bbox_size, start_anchor = roi(anchor_landmarks.detach())
        ROI_anchor = ROI_anchor.view(
            bs, self.num_point * self.Sample_num * self.Sample_num, 2)
        # Bilinear interpolation pulls the local patch features off the map.
        ROI_feature = self.interpolation(feature_map, ROI_anchor.detach()).view(
            bs, self.num_point, self.Sample_num, self.Sample_num, self.d_model)
        # NOTE(review): permute(0, 3, 2, 1) also swaps the two spatial axes
        # (NHWC -> NCWH); harmless with a square kernel but worth confirming
        # against the upstream repo before "fixing".
        ROI_feature = ROI_feature.view(
            bs * self.num_point, self.Sample_num, self.Sample_num,
            self.d_model).permute(0, 3, 2, 1)
        # One conv collapses each patch into a per-landmark embedding.
        transformer_feature = self.feature_extractor(ROI_feature).view(
            bs, self.num_point, self.d_model)
        offset = self.out_layer(self.Transformer(transformer_feature))
        # predicted landmarks = patch start + patch size * in-patch offset
        return start_anchor.unsqueeze(1) + bbox_size.unsqueeze(1) * offset

    def forward(self, image):
        """Return a list of three (bs, num_decoder, num_point, 2) predictions,
        one per refinement stage (each stage keeps every decoder layer's
        output for deep supervision)."""
        bs = image.size(0)
        feature_map = self.backbone(image)  # global image features
        initial_landmarks = self.initial_points.repeat(bs, 1, 1).to(image.device)

        # Each stage is seeded with the LAST decoder layer's prediction of
        # the previous stage ([:, -1]); stage 1 uses the canonical face.
        landmarks_1 = self._stage(feature_map, initial_landmarks, self.ROI_1)
        landmarks_2 = self._stage(feature_map, landmarks_1[:, -1, :, :], self.ROI_2)
        landmarks_3 = self._stage(feature_map, landmarks_2[:, -1, :, :], self.ROI_3)
        return [landmarks_1, landmarks_2, landmarks_3]
Transformer核心部分(Inherent_Layer层)
将每个patch的embedding作为src参数传入Transformer中后,先设置三个可学习参数:structure_encoding、landmark_query和tgt,然后一起传入Transformer_block中。Transformer_block其实就是对Inherent_Layer层的输出进行叠加,主要操作部分都在Inherent_Layer部分完成,下面主要讲解Inherent_Layer层的操作。
将src、tgt、structure_encoding、landmark_query传入Inherent_Layer后,主要有query-query attention、query-representation attention两个操作,分别用来学习query之间的内在交互关系和关键点之间固有的内在关系。但看代码可以知道,两个操作的核心代码都是nn.MultiheadAttention,只是q、k、v不同。最后再经过一个MLP返回输出。
import copy
import torch
from torch import nn, Tensor
from typing import Optional
from utils import _get_activation_fn
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class Inherent_Layer(nn.Module):
    """One SLPT decoder layer.

    Applies, in order: (1) query-query self-attention among the landmark
    queries, (2) query-representation cross-attention from the queries to
    the patch embeddings in ``memory``, and (3) a position-wise FFN.
    ``normalize_before`` selects the pre-norm variant (``forward_pre``)
    instead of the default post-norm one (``forward_post``).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Attention over the queries themselves, and from queries to memory.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # One LayerNorm + Dropout pair per sub-layer (self-attn, cross-attn, FFN).
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add a positional embedding when one is given."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: sublayer -> residual add -> LayerNorm."""
        # (1) query-query attention: interactions among landmark queries.
        qk = self.with_pos_embed(tgt, query_pos)
        attn_out = self.self_attn(qk, qk, value=tgt, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)[0]
        tgt = self.norm1(tgt + self.dropout1(attn_out))
        # (2) query-representation attention against the patch embeddings.
        cross_out = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory, attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)[0]
        tgt = self.norm2(tgt + self.dropout2(cross_out))
        # (3) feed-forward network.
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        return self.norm3(tgt + self.dropout3(ffn_out))

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm -> sublayer -> residual add."""
        # (1) query-query attention.
        normed = self.norm1(tgt)
        qk = self.with_pos_embed(normed, query_pos)
        attn_out = self.self_attn(qk, qk, value=normed, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(attn_out)
        # (2) query-representation attention.
        normed = self.norm2(tgt)
        cross_out = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory, attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(cross_out)
        # (3) feed-forward network.
        normed = self.norm3(tgt)
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(normed))))
        return tgt + self.dropout3(ffn_out)

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm implementation."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, memory, tgt_mask, memory_mask,
                      tgt_key_padding_mask, memory_key_padding_mask,
                      pos, query_pos)
class Transformer_block(nn.Module):
    """A stack of cloned decoder layers (DETR-style).

    With ``return_intermediate`` the (normalised) output of EVERY layer is
    kept and stacked, enabling deep supervision downstream.

    NOTE(review): the first constructor parameter shadows the class name
    ``Transformer_block``; kept unchanged because keyword callers depend on it.
    """

    def __init__(self, Transformer_block, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(Transformer_block, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Run every layer in sequence.

        Returns a (num_layers, *tgt.shape) stack when ``return_intermediate``
        is set, otherwise just the final layer's output.
        """
        output = tgt
        per_layer = []
        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                per_layer.append(self.norm(output))
        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Replace the last entry with the final (normed) output.
                per_layer[-1] = output
        if self.return_intermediate:
            return torch.stack(per_layer)
        return output
class Transformer(nn.Module):
    """SLPT Transformer: learnable structure encoding + landmark queries
    driving a stack of Inherent_Layer decoder layers over the per-landmark
    patch embeddings."""

    def __init__(self, num_points, d_model=256, nhead=8, num_decoder_layer=6,
                 dim_feedforward=1024, dropout=0.1, activation="relu",
                 normalize_before=True):
        super(Transformer, self).__init__()
        self.d_model = d_model  # embedding dimension
        self.nhead = nhead      # number of attention heads
        # Learnable "structure" positional encoding and landmark queries,
        # each of shape (1, num_points, d_model).
        self.structure_encoding = nn.Parameter(torch.randn(1, num_points, d_model))
        self.landmark_query = nn.Parameter(torch.randn(1, num_points, d_model))
        shared_layer = Inherent_Layer(d_model, nhead, dim_feedforward, dropout,
                                      activation, normalize_before)
        self.Transformer_block = Transformer_block(
            shared_layer, num_decoder_layer, nn.LayerNorm(d_model),
            return_intermediate=True)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise all matrices; this also covers the two learnable
        embeddings above (their dim is 3 > 1)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src):
        """Decode landmark embeddings from patch embeddings.

        Args:
            src: (bs, num_points, d_model) per-landmark patch embeddings.

        Returns:
            (bs, num_decoder_layers, num_points, d_model) — every decoder
            layer's output is kept (return_intermediate=True).
        """
        bs, num_feat, len_feat = src.size()
        # Switch to the sequence-first layout used by nn.MultiheadAttention.
        pos_embed = self.structure_encoding.repeat(bs, 1, 1).permute(1, 0, 2)
        query_embed = self.landmark_query.repeat(bs, 1, 1).permute(1, 0, 2)
        memory = src.permute(1, 0, 2)
        # Decoder input starts at zero; all information enters through the
        # queries, the structure encoding and the memory.
        target = torch.zeros_like(query_embed)
        decoded = self.Transformer_block(target, memory,
                                         pos=pos_embed, query_pos=query_embed)
        # (layers, points, bs, d) -> (bs, layers, points, d)
        return decoded.permute(2, 0, 1, 3)