百度飞桨深度学习7日打卡营 课程总结4
课程链接 https://aistudio.baidu.com/aistudio/course/introduce/7073
飞桨官网 https://www.paddlepaddle.org.cn/
课程案例合集 https://aistudio.baidu.com/aistudio/projectdetail/1505799?channelType=0&channel=0
目录
课节3:人脸关键点检测
人脸关键点检测
- 定量问题
- 网络构建:input → backbone → Linear → ReLU → Linear → output
- 损失函数:图像分类 CrossEntropyLoss:交叉熵的计算;人脸关键点检测:L1Loss、L2Loss、SmoothL1Loss:距离的计算。
- 评估指标:NME(Normalized Mean Error),所有预测点和 ground-truth 之间的 L2 Norm,除以(关键点的个数
*
两只眼睛之间的距离)。
『深度学习7日打卡营』人脸关键点检测
流程:
-
问题定义:对现实问题进行分析。直接影响算法的选择、模型评估标准,投入的时间。
-
数据准备:
我们使用paddle.io.Dataset与paddle.vision.transform.*完成数据加载与数据预处理
训练样本量3,462张
验证样本量770张
单个样本形状(3,224,224)(处理后)
加载使用方式paddle.io.Dataset
数据预处理Resize、RandomCrop、Compose -
模型选择和开发:
飞桨框架:
- 方式1:Sequential
- 方式2:Subclass
- 方式3:内置网络
input → backbone → Linear → ReLU → Linear → output
-
模型训练和调优:
- Model封装
- 指定Adam优化器
- 指定Loss计算方法:SmoothL1Loss
- 指定评估指标:NME
- 按照训练的轮次和数据批次迭代训练
-
模型评估测试:
模型评估:model.evaluate()
模型测试:model.predict()- 基于验证样本对模型进行评估验证
- 得到Loss和评价指标值
-
部署上线:模型存储、导出、推理服务部署线上系统对接,指标监控。
预测部署模型存储:model.save(path,training=False)
- 存储模型
- 使用预测引擎部署 → PaddleSlim → Paddlelnference、PaddleLite、PaddleJs
问题定义
人脸关键点检测,是输入一张人脸图片,模型会返回人脸关键点的一系列坐标,从而定位到人脸的关键信息。
# 环境导入
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import cv2
import paddle
paddle.set_device('gpu') # 设置为GPU
import warnings
warnings.filterwarnings('ignore') # 忽略 warning
数据准备
2.1 下载数据集
本次实验所采用的数据集来源为github的开源项目。
目前该数据集已上传到 AI Studio 人脸关键点识别,加载后可以直接使用下面的命令解压。
# !unzip data/data69065/data.zip
解压后的数据集结构为
data/
|—— test
| |—— Abdel_Aziz_Al-Hakim_00.jpg
... ...
|—— test_frames_keypoints.csv
|—— training
| |—— Abdullah_Gul_10.jpg
... ...
|—— training_frames_keypoints.csv
其中,training 和 test 文件夹分别存放训练集和测试集。training_frames_keypoints.csv 和 test_frames_keypoints.csv 存放着训练集和测试集的标签。接下来,我们先来观察一下 training_frames_keypoints.csv 文件,看一下训练集的标签是如何定义的。
key_pts_frame = pd.read_csv('data/training_frames_keypoints.csv') # 读取数据集
print('Number of images: ', key_pts_frame.shape[0]) # 输出数据集大小
key_pts_frame.head(5) # 看前五条数据
输出:
Number of images: 3462
Unnamed: 0 0 1 2 3 4 5 6 \
0 Luis_Fonsi_21.jpg 45.0 98.0 47.0 106.0 49.0 110.0 53.0
1 Lincoln_Chafee_52.jpg 41.0 83.0 43.0 91.0 45.0 100.0 47.0
2 Valerie_Harper_30.jpg 56.0 69.0 56.0 77.0 56.0 86.0 56.0
3 Angelo_Reyes_22.jpg 61.0 80.0 58.0 95.0 58.0 108.0 58.0
4 Kristen_Breitweiser_11.jpg 58.0 94.0 58.0 104.0 60.0 113.0 62.0
7 8 ... 126 127 128 129 130 131 132 133 \
0 119.0 56.0 ... 83.0 119.0 90.0 117.0 83.0 119.0 81.0 122.0
1 108.0 51.0 ... 85.0 122.0 94.0 120.0 85.0 122.0 83.0 122.0
2 94.0 58.0 ... 79.0 105.0 86.0 108.0 77.0 105.0 75.0 105.0
3 120.0 58.0 ... 98.0 136.0 107.0 139.0 95.0 139.0 91.0 139.0
4 121.0 67.0 ... 92.0 117.0 103.0 118.0 92.0 120.0 88.0 122.0
134 135
0 77.0 122.0
1 79.0 122.0
2 73.0 105.0
3 85.0 136.0
4 84.0 122.0
[5 rows x 137 columns]
上表中每一行都代表一条数据,其中,第一列是图片的文件名,之后从第0列到第135列,就是该图的关键点信息。因为每个关键点可以用两个坐标表示,所以 136/2 = 68,就可以看出这个数据集为68点人脸关键点数据集。
目前常用的人脸关键点标注,有如下点数的标注:
- 5点
- 21点
- 68点
- 98点
本次所采用的68标注。
# 计算标签的均值和标准差,用于标签的归一化
key_pts_values = key_pts_frame.values[:,1:] # 取出标签信息
data_mean = key_pts_values.mean() # 计算均值
data_std = key_pts_values.std() # 计算标准差
print('标签的均值为:', data_mean)
print('标签的标准差为:', data_std)
```python
输出:
```python
标签的均值为: 104.4724870017331
标签的标准差为: 43.17302271754281
2.2 查看图像
def show_keypoints(image, key_pts):
"""
Args:
image: 图像信息
key_pts: 关键点信息,
展示图片和关键点信息
"""
plt.imshow(image.astype('uint8')) # 展示图片信息
for i in range(len(key_pts)//2,):
plt.scatter(key_pts[i*2], key_pts[i*2+1], s=20, marker='.', c='b') # 展示关键点信息
# 展示单条数据
n = 14 # n为数据在表格中的索引
image_name = key_pts_frame.iloc[n, 0] # 获取图像名称
key_pts = key_pts_frame.iloc[n, 1:].as_matrix() # 将图像label格式转为numpy.array的格式
key_pts = key_pts.astype('float').reshape(-1) # 获取图像关键点信息
print(key_pts.shape)
plt.figure(figsize=(5, 5)) # 展示的图像大小
show_keypoints(mpimg.imread(os.path.join('data/training/', image_name)), key_pts) # 展示图像与关键点信息
plt.show() # 展示图像
2.3 数据集定义
使用飞桨框架高层API的 paddle.io.Dataset 自定义数据集类,具体可以参考官网文档 自定义数据集。
# 按照Dataset的使用规范,构建人脸关键点数据集
from paddle.io import Dataset
class FacialKeypointsDataset