IJCV2021 人脸关键点检测器PIPNet

阿联酋起源人工智能研究院(IIAI)科学家提出了一种新颖的人脸关键点检测方法PIPNet,通过融合坐标回归和热力图回归的优势,并结合半监督学习充分利用大量无标注数据提升跨域的泛化性能,最终得到一个又快又准又稳的人脸关键点检测器。相关论文已被IJCV 2021接收。

论文:https://arxiv.org/abs/2003.03771

代码:https://github.com/jhb86253817/PIPNet

预训练地址:

https://drive.google.com/drive/folders/17OwDgJUfuc5_ymQ3QruD8pUnh5zHreP2

gpu上测试30ms左右。

严重侧脸有时也比较飘。

多人脸时,速度比较慢,一个人脸30多ms,4个人脸1百多ms。

 

 

为了得到一个适用于真实应用的人脸关键点模型,本文基于上述挑战提出嵌套网络(PIPNet)模型。该模型主要包含三个重要模块。


首先是一个新颖的检测头,称作嵌套回归(PIP regression)。该方法将关键点定位任务分解成了基于低分辨率特征图的热力图回归和局部特征图上的坐标回归,使模型在不依赖高分辨率特征图的情况下依然具有较高的精度,从而节省了计算量。

此外,我们在检测头上额外设计了近邻回归模块,通过训练每个关键点根据自己的位置定位它的近邻关键点,使得在预测时能得到局部区域的形状约束,从而提升模型的鲁棒性。

最后,我们提出一种基于自训练的半监督学习方法来充分利用大量无标注的不同场景下样本。该方法在对无标注样本估计伪标签时,首先从简单的任务开始,然后在后续迭代中逐渐增加任务的难度,直到变成标准的自训练任务,有效缓解了标准自训练方法在伪标签中引入的噪声问题。

半监督学习:表4展示了STC与基线方法的比较。其中,300W的训练集带有标注,无标注数据集来自COFW和WFLW或CelebA。可以看到,STC无论是与直接跨领域测试,还是与经典的UDA方法DANN以及标准自训练法比较,均取得了更好的结果。

表5展示了与现有方法在跨领域泛化性能上的比较。之前的方法基本遵循在300W上训练,然后直接在测试集上测试(即GSL)。同样遵循这一模式,PIPNet在COFW-68上仅落后于AVS。

而当充分利用CelebA中的无标注数据后,模型在COFW-68上的跨领域性能大幅提高,并超越了AVS,这既表明了STC的有效性,也显示了GSSL范式在实际应用中的可行性。

表4. STC与基线方法的比较

表5. STC与已有方法在同领域及跨领域测试集上的结果比较

速度:为了说明PIPNet在推断速度上的优势,我们与之前的方法比较模型在精度和速度上的平衡。如图1所示,PIPNet在CPU和GPU上均取得了最优的平衡(越靠近右下角越好),尤其是CPU上的优势更为明显。因此,PIPNet很适合计算资源受限的场景。

笔者自己改了一版视频测试脚本:

import cv2, os
import sys

from FaceBoxesV2.faceboxes_detector import FaceBoxesDetector

sys.path.insert(0, 'FaceBoxesV2')
sys.path.insert(0, '..')
import numpy as np
import importlib
from math import floor
# from faceboxes_detector import *
import time

import torch
import torch.nn.parallel
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from networks import *
import data_utils
from functions import *


experiment_name = 'pip_32_16_60_r18_l2_l1_10_1_nb10.py '
data_name = 'WFLW'
config_path = '.experiments.{}.{}'.format(data_name, experiment_name)
video_file = '../videos/002.avi'
video_file = 'camera'

# my_config = importlib.import_module(config_path, package='PIPNet')
# Config = getattr(my_config, 'Config')
class Config():
    def __init__(self):
        self.det_head = 'pip'
        self.net_stride = 32
        self.batch_size = 16
        self.init_lr = 0.0001
        self.num_epochs = 60
        self.decay_steps = [30, 50]
        self.input_size = 256
        self.backbone = 'resnet18'
        self.pretrained = True
        self.criterion_cls = 'l2'
        self.criterion_reg = 'l1'
        self.cls_loss_weight = 10
        self.reg_loss_weight = 1
        self.num_lms = 98
        self.save_interval = self.num_epochs
        self.num_nb = 10
        self.use_gpu = True
        self.gpu_id = 2


cfg = Config()
cfg.experiment_name = experiment_name
cfg.data_name = data_name

save_dir = os.path.join('./snapshots', cfg.data_name, cfg.experiment_name)

meanface_indices, reverse_index1, reverse_index2, max_len = get_meanface(os.path.join('../data', cfg.data_name, 'meanface.txt'), cfg.num_nb)

if cfg.backbone == 'resnet18':
    resnet18 = models.resnet18(pretrained=cfg.pretrained)
    net = Pip_resnet18(resnet18, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride)
device = torch.device("cpu")
if cfg.use_gpu:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = net.to(device)

# weight_file = os.path.join(save_dir, 'epoch%d.pth' % (cfg.num_epochs-1))
weight_file = 'epoch59.pth'
state_dict = torch.load(weight_file, map_location=device)
net.load_state_dict(state_dict)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose([transforms.Resize((cfg.input_size, cfg.input_size)), transforms.ToTensor(), normalize])

def demo_video(video_file, net, preprocess, input_size, net_stride, num_nb, use_gpu, device):
    detector = FaceBoxesDetector('FaceBoxes', '../FaceBoxesV2/weights/FaceBoxesV2.pth', use_gpu, device)
    my_thresh = 0.9
    det_box_scale = 1.2

    net.eval()
    if video_file == 'camera':
        cap = cv2.VideoCapture(0)
    else:
        cap = cv2.VideoCapture(video_file)
    if (cap.isOpened()== False): 
        print("Error opening video stream or file")
    frame_width = int(cap.get(3))
    frame_height = int(cap.get(4))
    count = 0
    while(cap.isOpened()):
        ret, frame = cap.read()
        if ret == True:
            start=time.time()
            detections, _ = detector.detect(frame, my_thresh, 1)
            print('detect time',time.time()-start)
            start = time.time()
            for i in range(len(detections)):
                det_xmin = detections[i][2]
                det_ymin = detections[i][3]
                det_width = detections[i][4]
                det_height = detections[i][5]
                det_xmax = det_xmin + det_width - 1
                det_ymax = det_ymin + det_height - 1

                det_xmin -= int(det_width * (det_box_scale-1)/2)
                # remove a part of top area for alignment, see paper for details
                det_ymin += int(det_height * (det_box_scale-1)/2)
                det_xmax += int(det_width * (det_box_scale-1)/2)
                det_ymax += int(det_height * (det_box_scale-1)/2)
                det_xmin = max(det_xmin, 0)
                det_ymin = max(det_ymin, 0)
                det_xmax = min(det_xmax, frame_width-1)
                det_ymax = min(det_ymax, frame_height-1)
                det_width = det_xmax - det_xmin + 1
                det_height = det_ymax - det_ymin + 1
                cv2.rectangle(frame, (det_xmin, det_ymin), (det_xmax, det_ymax), (0, 0, 255), 2)
                det_crop = frame[det_ymin:det_ymax, det_xmin:det_xmax, :]
                det_crop = cv2.resize(det_crop, (input_size, input_size))
                inputs = Image.fromarray(det_crop[:,:,::-1].astype('uint8'), 'RGB')
                inputs = preprocess(inputs).unsqueeze(0)
                inputs = inputs.to(device)
                lms_pred_x, lms_pred_y, lms_pred_nb_x, lms_pred_nb_y, outputs_cls, max_cls = forward_pip(net, inputs, preprocess, input_size, net_stride, num_nb)
                lms_pred = torch.cat((lms_pred_x, lms_pred_y), dim=1).flatten()
                tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(cfg.num_lms, max_len)
                tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(cfg.num_lms, max_len)
                tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1,1)
                tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1,1)
                lms_pred_merge = torch.cat((tmp_x, tmp_y), dim=1).flatten()
                lms_pred = lms_pred.cpu().numpy()
                lms_pred_merge = lms_pred_merge.cpu().numpy()
                for i in range(cfg.num_lms):
                    x_pred = lms_pred_merge[i*2] * det_width
                    y_pred = lms_pred_merge[i*2+1] * det_height
                    cv2.circle(frame, (int(x_pred)+det_xmin, int(y_pred)+det_ymin), 1, (0, 0, 255), 2)
            print('keypoint time', time.time() - start)
            count += 1
            #cv2.imwrite('video_out2/'+str(count)+'.jpg', frame)
            cv2.imshow('1', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        else:
            break

    cap.release()
    cv2.destroyAllWindows()

demo_video(video_file, net, preprocess, cfg.input_size, cfg.net_stride, cfg.num_nb, cfg.use_gpu, device)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值