Mediapipe框架(一)人手关键点检测
-
MediaPipe
是一款由 Google Research 开发并开源的多媒体机器学习模型应用框架
。谷歌的一系列重要产品,如Google Lens、ARCore、Google Home等都已深度整合了 MediaPipe。 -
MediaPipe
目前支持的解决方案(Solution)及支持的平台如下图所示,除了视觉任务,还支持文本、语音及生成式AI任务(实验中)。作为一款跨平台框架,MediaPipe 不仅可以被部署在Web端,更可以在多个移动端 (Android和苹果 iOS)和嵌入式平台(Google Coral 和树莓派)中作为设备端机器学习推理框架。 -
MediaPipe
的每个解决方案(Solution)包括一个或多个模型,一些解决方案还可以自定义模型(使用Model Maker)。 -
今天,我们主要来了解下人手关键点检测(Hand landmark detection)。
官网地址: MediaPipe | Google for Developers
1 MediaPipe核心概念
1.1 MediaPipe 的主要概念
MediaPipe 的核心框架由 C++ 实现。
MediaPipe 的主要概念包括:
-
计算单元(Calculator)
- MediaPipe 在开源了多个由谷歌内部团队实现的计算单元(Calculator)的同时,也向用户提供定制新计算单元的接口。
- 创建一个新的 Calculator,需要用户实现 Open(),Process(),Close() 去分别定义 Calculator 的初始化,针对数据流的处理方法,以及 Calculator 在完成所有运算后的关闭步骤。
-
图(Graph)以及子图(Subgraph):
- MediaPipe 的图是有向的。数据包从数据源(Source Calculator或者 Graph Input Stream)流入图直至在汇聚结点(Sink Calculator 或者 Graph Output Stream) 离开。
- 为了方便用户在多个图中复用已有的通用组件,例如图像数据的预处理、模型的推理以及图像的渲染等, MediaPipe 引入了子图(Subgraph)的概念。因此,一个 MediaPipe 图中的节点既可以是计算单元,亦可以是子图。子图在不同图内的复用,方便了大规模模块化的应用搭建。
-
数据包(Packet)
- 数据包是最基础的数据单位,一个数据包代表了在某一特定时间节点的数据,例如一帧图像或一小段音频信号。
-
数据流(Stream)
- 数据流是由按时间顺序升序排列的多个数据包组成,一个数据流的某一特定时间戳(Timestamp)只允许至多一个数据包的存在;
- 而数据流则是在多个计算单元构成的图中流动。
1.2 通过手掌检测模型实例理解Mediapipe
使用 MediaPipe 来做手掌检测模型移动端模型推理的框架如下图所示:
- input_video 为输入图像,output_video 为输出图像。
- 为了保证整个应用的实时运算,使用 FlowLimiter计算单元来筛选进行运算的输入帧数,只有当前一帧的运算完成后,才会将下一帧图像送入模型。
- 当模型推理完成后,使用 MediaPipe 提供的一系列计算单元来进行输出的渲染和展示——结合使用 DetectionsToRenderData计算单元, RectToRenderData计算单元及AnnotationOverlay计算单元将检测结果渲染在输出图像上。
手掌检测应用的核心部分为上图中的蓝紫色模块(HandDetection子图)。
如下图所示,HandDetection 子图包含了一系列图像处理的计算单元和机器学习模型推理的模块。
- ImageTransformation计算单元将输入的图像调整到模型可以接受的尺寸,用以送入 TF Lite 模型的推理模块;
- 使用 TfLiteTensorsToDetections计算单元将模型输出的 Tensor 转换成检测结果;
- 运用 NonMaxSuppression计算单元等计算单元做后处理;
- 最终从HandDetection子图输出检测结果给主图。
2 人手关键点检测
2.1 概述
-
HandLandmarker由2个相互配合的模型组成:
-
一个手掌检测模型,可在完整图像上运行并返回定向的手边界框。
-
一个手部关键点检测模型,该模型在由手掌检测器定义的裁剪图像区域上操作并返回高保真3D手部关键点。
-
模型信息具体可以参考:Model Card Hand Tracking (Lite/Full) with Fairness Oct 2021
-
相关博客参考:使用 MediaPipe 实现设备端实时手部追踪
-
-
将精确裁剪的手部图像提供给手部关键点检测模型可以极大地减少对数据增强(例如旋转,平移和缩放)的需求,并且可以使网络将其大部分功能专用于坐标预测精度。
-
在整个图像上进行手掌检测之后,手部关键点检测模型将通过回归(即直接坐标预测)对检测到的手区域内部的21个3D手关节坐标进行精确的关键点定位。该模型甚至对于部分可见的手部和自我遮挡也具有鲁棒性。
- 下图为该模型的推理时间,在CPU仅花费17.12ms。
- 该任务配置项如下图所示:
选项 | 描述 | 取值范围 | 默认值 |
---|---|---|---|
`running_mode | IMAGE: 单张图像 VIDEO: 视频帧 LIVE_STREAM: 用于输入数据的实时流模式,例如来自摄像头的数据。在此模式下,必须调用resultListener来设置一个异步接收结果的监听器。 | {IMAGE, VIDEO, LIVE_STREAM } | IMAGE |
num_hands | 最多检测的人手数 | Any integer > 0 | 1 |
min_hand_detection_confidence | 在手掌检测模型中,用于被视为成功的手部检测的最小置信度得分。 | 0.0 - 1.0 | 0.5 |
min_hand_presence_confidence | 手部关键点检测模型中手部存在得分的最小置信度阈值。在视频模式和实时流模式下,如果手部关键点模型返回的手部存在置信度得分低于该阈值,Hand Landmarker将触发手掌检测模型。否则,一个轻量级的手部跟踪算法将确定手的位置,用于后续的关键点检测。 | 0.0 - 1.0 | 0.5 |
min_tracking_confidence | 手部追踪被认为成功的最小置信度得分。这是当前帧和上一帧中手部之间的边界框IoU阈值。在Hand Landmarker的视频模式和流模式中,如果追踪失败,Hand Landmarker将触发手部检测。否则,它将跳过手部检测。 | 0.0 - 1.0 | 0.5 |
result_callback | 异步回调结果(仅用于LIVE_STREAM模式) | N/A | N/A |
2.2 python代码实现
1、安装mediapipe包
pip install mediapipe
2、下载预训练模型(hand_landmarker.task):
3、检测结果
- Handednessd:表示检测到的手是左手还是右手
- Landmarks:手部关键点共有21个,每个关键点由x、y和z坐标组成。x和y坐标通过图像的宽度和高度进行了归一化,范围在[0.0, 1.0]之间。z坐标表示关键点的深度,手腕处的深度被定义为原点。数值越小,表示关键点离摄像机越近。z的大小与x的大小大致相同。
- WorldLandmarks:以世界坐标的形式呈现21个手部关键点。每个关键点由x、y和z组成,表示以米为单位的真实世界三维坐标,原点位于手的几何中心。
HandLandmarkerResult:
Handedness:
Categories #0:
index : 0
score : 0.98396
categoryName : Left # Left代表左手
Landmarks:
Landmark #0:
x : 0.638852
y : 0.671197
z : -3.41E-7
Landmark #1:
x : 0.634599
y : 0.536441
z : -0.06984
... (21 landmarks for a hand)
WorldLandmarks:
Landmark #0:
x : 0.067485
y : 0.031084
z : 0.055223
Landmark #1:
x : 0.063209
y : -0.00382
z : 0.020920
... (21 world landmarks for a hand)
2.2.1 单张图像的人手关键点检测
- 检测结果保存在results中,其会保存一个人手的坐标list以及左右手的标签以及标签的置信度等
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
from utils import draw_landmarks_on_image
def detect_hands_from_image(img_path):
# 1、 创建人手坐标点检测器
# 下载人手关键点检测模型hand_landmarker.task
# https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task
base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2)
detector = vision.HandLandmarker.create_from_options(options)
# 2、 加载输入图片
image = mp.Image.create_from_file(img_path)
# 3、 使用下载好的模型进行人手坐标点检测
detection_result = detector.detect(image)
print(detection_result)
# 4、 可视化人手检测
annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result)
imageRGB = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
# cv2.imwrite('new_img.jpg', imageRGB)
# 在使用OpenCV的cv2.imshow函数显示图像时,它会默认将传入的图像数据解释为BGR格式
# 如果你传入的是RGB格式的图像数据,OpenCV会在显示时进行颜色通道的调整,使图像以BGR格式进行显示。
cv2.imshow('women_hands', imageRGB)
# 输入esc结束捕获
if cv2.waitKey(0) == 27:
cv2.destroyAllWindows()
if __name__ == '__main__':
detect_hands_from_image(img_path="image.jpg")
检测结果如下:
2.2.2 视频帧的人手关键点检测
- 这里利用python自带的GUI库tkinter绘制一个简单页面
- 可以调用本地电脑摄像头进行在线检测,也可进行离线检测
- 需要设置running_mode为VisionRunningMode.VIDEO
import time
from utils import draw_landmarks_on_image
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
# python的图形界面库tkinter
from tkinter import filedialog
from tkinter import *
from PIL import Image, ImageTk
#####################相关GUI页面设置##########################
def SimpleGUI():
global cap
global video_path
# 定义变量来表示窗口是否关闭
global window_closed
# 初始化
video_path = 0
cap = cv2.VideoCapture(video_path)
window_closed = False
global root
global Video_Label
root = Tk()
root.geometry("1920x1080+0+0") # 设置窗口大小
root.state("zoomed")
root.config(bg="#3a3b3c") # 设置窗口的背景颜色
root.title("人手关键点检测简易GUI") # 设置窗口标题
# 1、使用摄像头进行在线检测
live_btn = Button(root, height=1, text='ONLINE', width=8, fg='magenta', font=("Calibri", 14, "bold"),
command=lambda: video_live())
live_btn.place(x=1200, y=30)
text = Label(root, text="For Online Video", bg="#3a3b3c", fg="#ffffff", font=("Calibri", 20))
text.place(x=1000, y=30)
# live_btn绑定事件,进行在线检测
def video_live():
global video_path, cap
video_path = 0
cap = cv2.VideoCapture(video_path)
text = Label(root, text="Live Video Feed", bg="#3a3b3c", fg="#ffffff", font=("Calibri", 20))
text.place(x=250, y=150)
# 2、检测离线的视频
browse_btn = Button(root, height=1, width=8, text='OFFLINE', fg='magenta', font=("Calibri", 14, "bold"),
command=lambda: path_select())
browse_btn.place(x=1200, y=90)
text = Label(root, text="For Offline Video", bg="#3a3b3c", fg="#ffffff", font=("Calibri", 20))
text.place(x=1000, y=90)
# 给browse_btn绑定事件,检测离线视频
def path_select():
global video_path, cap
video_path = filedialog.askopenfilename()
print(video_path)
cap = cv2.VideoCapture(video_path)
text = Label(root, text="Recorded Video ", bg="#3a3b3c", fg="#ffffff", font=("Calibri", 20))
text.place(x=250, y=150)
# 检测系统标题
ttl = Label(root, text="人手关键点检测", bg="#4f4d4a", fg="#fffbbb", font=("Calibri", 40))
ttl.place(x=100, y=50)
Video_frame = Frame(root, height=480, width=640, bg="red")
Video_Label = Label(root, height=480, width=640, bg="#4f4d4a")
Video_frame.place(x=350, y=200)
Video_Label.place(x=350, y=200)
# 处理窗口关闭事件
def on_closing():
global window_closed
window_closed = True
root.destroy()
root.protocol("WM_DELETE_WINDOW", on_closing)
def detect_hands_from_video():
VisionRunningMode = mp.tasks.vision.RunningMode
base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2, running_mode=VisionRunningMode.VIDEO)
detector = vision.HandLandmarker.create_from_options(options)
frame_timestamp_ms = int(time.time())
while True:
# 从相机从捕获一帧
ret, img = cap.read()
# 如果窗口关闭,就结束循环
if window_closed:
break
if ret:
# 将图像从BGR颜色空间转换为Lab颜色空间
numpy_frame_from_opencv = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=numpy_frame_from_opencv)
# Perform hand landmarks detection on the provided single image.
# The hand landmarker must be created with the video mode.
# 需要一个单调递增的视频帧的时间戳frame_timestamp_ms
detection_result = detector.detect_for_video(mp_image, frame_timestamp_ms)
annotated_image = draw_landmarks_on_image(img, detection_result)
annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGRA2RGB)
# 创建一个Tkinter兼容的照片图像(photo image),它可在Tkinter期望一个图像对象的任何地方使用
image = ImageTk.PhotoImage(Image.fromarray(annotated_image), Image.LANCZOS)
# 在Label中显示图片
Video_Label["image"] = image
cv2.waitKey(25)
frame_timestamp_ms += 25
else:
# 将视频的播放位置移动到第一帧,从视频的开头重新开始播放视频
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
root.update()
print('*'*50)
cap.release()
cv2.destroyAllWindows()
if __name__ == '__main__':
SimpleGUI()
detect_hands_from_video()
检测结果如下:
2.2.3 实时流的人手关键点检测
- 需要设置running_mode为VisionRunningMode.LIVE_STREAM
- 这里实现一个回调函数result_callback,用来将检测结果输入到output.avi文件中
- 检测是异步检测detector.detect_async()
import time
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
from utils import draw_landmarks_on_image
VisionRunningMode = mp.tasks.vision.RunningMode
HandLandmarkerResult = mp.tasks.vision.HandLandmarkerResult
# 创建一个视频编解码器
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(filename='output.avi', fourcc=fourcc, fps=20.0, frameSize=(640, 480))
# 回调函数,将检测的结果保存为视频
def print_result(result: HandLandmarkerResult, output_image: mp.Image, timestamp_ms: int):
if cap.isOpened():
annotated_image = draw_landmarks_on_image(img, result)
out.write(annotated_image)
# 基础配置,并创建检测器
base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
options = vision.HandLandmarkerOptions(base_options=base_options
, num_hands=2
, running_mode=VisionRunningMode.LIVE_STREAM
, result_callback=print_result
)
detector = vision.HandLandmarker.create_from_options(options)
cap = cv2.VideoCapture(0)
frame_timestamp_ms = int(time.time())
while True:
if cap.isOpened():
# 捕获一帧图片
_, img = cap.read()
cv2.imshow('original', img)
# 获取视频的FPS(每秒帧数)
fps = cap.get(cv2.CAP_PROP_FPS)
print('fps = ', fps)
# 将图像从BGR颜色空间转换为RGB颜色空间
numpy_frame_from_opencv = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=numpy_frame_from_opencv)
# 异步检测
detector.detect_async(mp_image, int(frame_timestamp_ms))
# 按esc键结束捕获
if cv2.waitKey(25) == 27:
print('exited')
break
frame_timestamp_ms += 25
cap.release()
out.release()
cv2.destroyAllWindows()
上文代码中所用的检测模型、图片及视频素材也已经上传到Git仓库。