1.摘要
MediaPipe Objectron 是一种用于日常物体的移动实时 3D 物体检测解决方案。它检测 2D 图像中的物体,并通过机器学习 (ML) 模型估计它们的姿势,该模型在 Objectron 数据集上训练。
对象检测是一个被广泛研究的计算机视觉问题,但大部分研究都集中在二维对象预测上。虽然 2D 预测仅提供 2D 边界框,但通过将预测扩展到 3D,人们可以捕捉物体在世界中的大小、位置和方向,从而在机器人、自动驾驶汽车、图像检索和增强现实等领域得到广泛应用。尽管 2D 物体检测相对成熟并已在行业中得到广泛应用,但由于缺乏数据且类别内物体的外观和形状具有多样性,因此从 2D 图像中检测 3D 物体是一个具有挑战性的问题。
2.获取真实世界的3D训练数据
虽然街景有大量 3D 数据,但由于对依赖于 LIDAR 等 3D 捕获传感器的自动驾驶汽车的研究的普及,具有用于更精细日常物体的真实 3D 标注的数据集极为有限。为了克服这个问题,我们开发了一种使用移动增强现实 (AR) 会话数据的新型数据管道。随着 ARCore 和 ARKit 的到来,数以亿计的智能手机现在具有 AR 功能,并且能够在 AR 会话期间捕获附加信息,包括相机姿势、稀疏 3D 点云、估计照明和平面。
为了标记真实数据,我们构建了一个新颖的注释工具,用于 AR 会话数据,它允许注释者快速标记对象的 3D 边界框。此工具使用分屏视图来显示 2D 视频帧,其左侧是重叠的 3D 边界框,右侧是显示 3D 点云、相机位置和检测到的平面的视图。注释者在 3D 视图中绘制 3D 边界框,并通过查看 2D 视频帧中的投影来验证其位置。对于静态对象,我们只需要在单个帧中注释一个对象,并使用来自 AR 会话数据的真实相机姿势信息将其位置传播到所有帧,这使得该过程非常高效。
此处使用到了PIL来调节GIF图像大小:
from PIL import Image
def crop_gif_short(gif, gif_out, size):
im = Image.open(gif)
transparency = im.info["transparency"]
frames = [im.resize(size) for frame in range(0, im.n_frames) if not im.seek(frame)]
frames[0].save(gif_out, save_all=True, append_images=frames, loop=0, duration=im.info['duration'],
transparency=transparency)
if __name__ == '__main__':
gif_input = 'objectron_data_annotation.gif'
gif_output = 'objectron_data_annotation_resized.gif'
size = (225, 150)
crop_gif_short(gif_input, gif_output, size)
3.AR合成数据生成
一种流行的方法是用合成数据补充现实世界的数据,以提高预测的准确性。然而,这样做的尝试通常会产生糟糕的、不真实的数据,或者在照片级渲染的情况下,需要大量的努力和计算。我们的新方法称为 AR 合成数据生成,将虚拟对象放置到具有 AR 会话数据的场景中,这使我们能够利用相机姿势、检测到的平面和估计的照明来生成物理上可能的位置,并使用与场景匹配的照明。这种方法产生了高质量的合成数据,并渲染了符合场景几何形状的物体,并无缝地融入真实背景中。通过结合真实世界数据和 AR 合成数据,我们能够将准确率提高约 10%。
4.用于 3D 对象检测的 ML 管道
我们构建了两个 ML 管道来从单个 RGB 图像预测对象的 3D 边界框:一个是两级管道,另一个是单级管道。两级管道比具有相似或更好精度的单级管道快 3 倍。单级管道擅长检测多个对象,而两级管道适合单个对象。
4.1两级管道
我们的两级管道由下图中的图表说明。第一阶段使用对象检测器来查找对象的 2D 裁剪。第二阶段进行图像裁剪并估计 3D 边界框。同时,它还为下一帧计算对象的 2D 裁剪,这样对象检测器就不需要每帧都运行。
我们可以在第一阶段使用任何 2D 物体检测器。在此解决方案中,我们使用 TensorFlow Object Detection训练 Open Images 数据集。我们发布的第二阶段 3D 边界框预测器在 Adreno 650 移动 GPU 上运行 83FPS。
4.2单级管道
我们的单级管道如上图所示,模型主干具有编码器-解码器架构,构建于 MobileNetv2 之上。我们采用多任务学习方法,通过检测和回归联合预测对象的形状。形状任务根据可用的标注信息(例如分割)来预测对象的形状信号。如果训练数据中没有形状注释,则这是可选的。对于检测任务,我们使用带注释的边界框并将高斯拟合到框,中心位于框质心,标准偏差与框大小成正比。检测的目标是预测该分布,其峰值代表对象的中心位置。回归任务估计八个边界框顶点的二维投影。为了获得边界框的最终 3D 坐标,我们利用了完善的姿态估计算法 (EPnP)。它可以恢复对象的 3D 边界框,而无需先验了解对象尺寸。给定 3D 边界框,我们可以轻松计算对象的姿势和大小。该模型足够轻,可以在移动设备上实时运行(在 Adreno 650 移动 GPU 上以 26 FPS 运行)。
我
们
网
络
的
样
本
结
果
:
(
左
)
带
有
估
计
边
界
框
的
原
始
2
D
图
像
,
(
中
)
通
过
高
斯
分
布
检
测
目
标
,
(
右
)
预
测
的
分
割
掩
码
。
我们网络的样本结果:(左)带有估计边界框的原始 2D 图像,(中)通过高斯分布检测目标,(右)预测的分割掩码。
我们网络的样本结果:(左)带有估计边界框的原始2D图像,(中)通过高斯分布检测目标,(右)预测的分割掩码。
4.3检测与跟踪
当模型应用于移动设备捕获的每一帧时,由于在每一帧中估计的 3D 边界框的模糊性,它可能会受到抖动的影响。为了缓解这种情况,我们在 MediaPipe Box Tracking 中的 2D 对象检测和跟踪管道中采用了相同的检测 + 跟踪策略。这减少了在每一帧上运行网络的需要,允许使用更重且因此更准确的模型,同时在移动设备上保持管道实时。它还保留了帧之间的对象身份,并确保预测是时间一致的,减少了抖动。
Objectron 3D 对象检测和跟踪管道作为 MediaPipe 图实现,它在内部使用检测子图和跟踪子图。检测子图每几帧仅执行一次 ML 推理以减少计算负载,并将输出张量解码为包含 9 个关键点的 FrameAnnotation:3D 边界框的中心及其八个顶点。跟踪子图每帧运行一次,使用 MediaPipe Box Tracking 中的 box traker 跟踪紧密包围 3D 边界框投影的 2D 框,并使用 EPnP 将跟踪的 2D 关键点提升到 3D。当检测子图中有新的检测可用时,跟踪子图还负责根据重叠区域在检测和跟踪结果之间进行合并。
5.解决方案API
5.1跨平台的配置选项
STATIC_IMAGE_MODE
:如果设置为 false,该解决方案会将输入图像视为视频流。它将尝试在第一张图像中检测对象,并在成功检测后进一步定位 3D 边界框地标。在随后的图像中,一旦检测到所有 max_num_objects 对象并且定位了相应的 3D 边界框地标,它就会简单地跟踪这些地标,而不会调用另一个检测,直到它失去对任何对象的跟踪。这减少了延迟,非常适合处理视频帧。如果设置为 true,对象检测会运行每个输入图像,非常适合处理一批静态的、可能不相关的图像。默认为false。MAX_NUM_OBJECTS
:要检测的对象的最大数目。默认为5。MIN_DETECTION_CONFIDENCE
:来自对象检测模型的最小置信值 ([0.0, 1.0]),以便将检测视为成功。默认为 0.5。MIN_TRACKING_CONFIDENCE
:来自地标跟踪模型的最小置信值 ([0.0, 1.0]),用于将 3D 边界框地标视为成功跟踪,否则将在下一个输入图像上自动调用对象检测。将其设置为更高的值可以提高解决方案的稳健性,但代价是更高的延迟。如果 static_image_mode 为 true,则忽略,对象检测在每个图像上运行。默认为 0.99。MODEL_NAME
:用于预测3D边界框地标的模型名称。目前支持{'Shoe', 'Chair', 'Cup', 'Camera'}
。默认的Shoe
。FOCAL_LENGTH
:默认情况下,相机焦距在 NDC 空间中定义,即 (fx, fy)。默认为 (1.0, 1.0)。要在像素空间中指定焦距,即 (fx_pixel, fy_pixel),用户应提供 image_size = (image_width, image_height) 以启用 API 内部的转换。PRINCIPAL_POINT
:默认情况下,相机主点定义在 NDC 空间中,即 (px, py)。默认为 (0.0, 0.0)。要指定像素空间中的主要点,即(px_pixel, py_pixel),用户应提供 image_size = (image_width, image_height) 以启用 API 内部的转换。IMAGE_SIZE
:仅当在像素空间中指定了 focus_length 和 principal_point 时才指定。 输入图像的大小,即 (image_width, image_height)。
5.2输出
DETECTED_OBJECTS
:检测到的3D包围框列表。每个3D包围盒由以下内容组成:landmarks_2d
:对象 3D 边界框的 2D 地标。地标坐标分别通过图像宽度和高度归一化为 [0.0, 1.0]。landmarks_3d
:对象 3D 边界框的 3D 地标。地标坐标在相机坐标系中表示。rotation
:从物体坐标系到相机坐标系的旋转矩阵。translation
:从物体坐标系到相机坐标系的平移向量。scale
:对象沿 x、y 和 z 方向的相对比例。
6.Python解决方案API
支持配置选项:
- static_image_mode
- max_num_objects
- min_detection_confidence
- min_tracking_confidence
- model_name
- focal_length
- principal_point
- image_size
# python3.6.5 mediapipe= 0.8.3
import cv2
import mediapipe as mp
mp_drawing = mp.solutions.drawing_utils
mp_objectron = mp.solutions.objectron
# For static images:
IMAGE_FILES = []
with mp_objectron.Objectron(static_image_mode=True,
max_num_objects=5,
min_detection_confidence=0.5,
model_name='Shoe') as objectron:
for idx, file in enumerate(IMAGE_FILES):
image = cv2.imread(file)
# Convert the BGR image to RGB and process it with MediaPipe Objectron.
results = objectron.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# Draw box landmarks.
if not results.detected_objects:
print(f'No box landmarks detected on {file}')
continue
print(f'Box landmarks of {file}:')
annotated_image = image.copy()
for detected_object in results.detected_objects:
mp_drawing.draw_landmarks(
annotated_image, detected_object.landmarks_2d, mp_objectron.BOX_CONNECTIONS)
mp_drawing.draw_axis(annotated_image, detected_object.rotation,
detected_object.translation)
cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
# For webcam input:
cap = cv2.VideoCapture(0)
with mp_objectron.Objectron(static_image_mode=False,
max_num_objects=5,
min_detection_confidence=0.5,
min_tracking_confidence=0.99,
model_name='Shoe') as objectron:
while cap.isOpened():
success, image = cap.read()
if not success:
print("Ignoring empty camera frame.")
# If loading a video, use 'break' instead of 'continue'.
continue
# Convert the BGR image to RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# To improve performance, optionally mark the image as not writeable to
# pass by reference.
image.flags.writeable = False
results = objectron.process(image)
# Draw the box landmarks on the image.
image.flags.writeable = True
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
if results.detected_objects:
for detected_object in results.detected_objects:
mp_drawing.draw_landmarks(
image, detected_object.landmarks_2d, mp_objectron.BOX_CONNECTIONS)
mp_drawing.draw_axis(image, detected_object.rotation,
detected_object.translation)
cv2.imshow('MediaPipe Objectron', image)
if cv2.waitKey(5) & 0xFF == 27:
break
cap.release()