因为项目需要，画出来的关键点检测图中不需要脸部和手部的关键点。找了找，网上没有相关介绍，于是自己摸索了一下，方法主要是控制 draw_landmarks 函数的传入参数。
draw_landmarks 函数的源代码如下：
def draw_landmarks(
        image: np.ndarray,
        landmark_list: landmark_pb2.NormalizedLandmarkList,
        connections: Optional[List[Tuple[int, int]]] = None,
        landmark_drawing_spec: Union[DrawingSpec,
                                     Mapping[int, DrawingSpec]] = DrawingSpec(
                                         color=RED_COLOR),
        connection_drawing_spec: Union[DrawingSpec,
                                       Mapping[Tuple[int, int],
                                               DrawingSpec]] = DrawingSpec()):
    """Draws the landmarks and the connections on the image.

    Args:
        image: A three channel BGR image represented as numpy ndarray.
        landmark_list: A normalized landmark list proto message to be
            annotated on the image. If it is falsy (e.g. None), nothing is
            drawn.
        connections: A list of landmark index tuples that specifies how
            landmarks are to be connected in the drawing.
        landmark_drawing_spec: Either a DrawingSpec object or a mapping from
            landmark indices to the DrawingSpecs that specify the landmarks'
            drawing settings such as color, line thickness, and circle radius.
            If this argument is explicitly set to None, no landmarks will be
            drawn.
        connection_drawing_spec: Either a DrawingSpec object or a mapping from
            connection tuples to the DrawingSpecs that specify the
            connections' drawing settings such as color and line thickness.
            If this argument is explicitly set to None, no landmark
            connections will be drawn (connection indices are still
            validated).

    Raises:
        ValueError: If one of the following:
            a) The input image is not a three channel BGR image (including
               2-D arrays).
            b) Any connection contains an invalid landmark index.
    """
    if not landmark_list:
        return
    # Check ndim first: a 2-D (grayscale) array has no shape[2], so indexing it
    # would raise IndexError instead of the documented ValueError.
    if image.ndim != 3 or image.shape[2] != _BGR_CHANNELS:
        raise ValueError('Input image must contain three channel bgr data.')
    image_rows, image_cols, _ = image.shape

    # Map landmark index -> pixel coordinates, skipping landmarks whose
    # visibility/presence is below threshold or that
    # _normalized_to_pixel_coordinates cannot place (falsy result).
    idx_to_coordinates = {}
    for idx, landmark in enumerate(landmark_list.landmark):
        if ((landmark.HasField('visibility') and
             landmark.visibility < _VISIBILITY_THRESHOLD) or
                (landmark.HasField('presence') and
                 landmark.presence < _PRESENCE_THRESHOLD)):
            continue
        landmark_px = _normalized_to_pixel_coordinates(landmark.x, landmark.y,
                                                       image_cols, image_rows)
        if landmark_px:
            idx_to_coordinates[idx] = landmark_px

    if connections:
        num_landmarks = len(landmark_list.landmark)
        # Draws the connections if the start and end landmarks are both visible.
        for connection in connections:
            start_idx = connection[0]
            end_idx = connection[1]
            if not (0 <= start_idx < num_landmarks and
                    0 <= end_idx < num_landmarks):
                raise ValueError(f'Landmark index is out of range. Invalid connection '
                                 f'from landmark #{start_idx} to landmark #{end_idx}.')
            # None means "do not draw connections" per the docstring; without
            # this guard the code below would fail on None.color.
            if connection_drawing_spec is None:
                continue
            if start_idx in idx_to_coordinates and end_idx in idx_to_coordinates:
                drawing_spec = connection_drawing_spec[connection] if isinstance(
                    connection_drawing_spec, Mapping) else connection_drawing_spec
                cv2.line(image, idx_to_coordinates[start_idx],
                         idx_to_coordinates[end_idx], drawing_spec.color,
                         drawing_spec.thickness)

    # Draws landmark points after finishing the connection lines, which is
    # aesthetically better.
    if landmark_drawing_spec:
        for idx, landmark_px in idx_to_coordinates.items():
            drawing_spec = landmark_drawing_spec[idx] if isinstance(
                landmark_drawing_spec, Mapping) else landmark_drawing_spec
            # White circle border
            circle_border_radius = max(drawing_spec.circle_radius + 1,
                                       int(drawing_spec.circle_radius * 1.2))
            cv2.circle(image, landmark_px, circle_border_radius, WHITE_COLOR,
                       drawing_spec.thickness)
            # Fill color into the circle
            cv2.circle(image, landmark_px, drawing_spec.circle_radius,
                       drawing_spec.color, drawing_spec.thickness)
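这里主要通过修改 landmark_list 和 connections 这两个参数来实现。
landmark_list 一般直接传入 self.results.pose_landmarks，但它包含全部 33 个关键点（索引 0–32），
所以通过 self.results.pose_landmarks.landmark.pop()
把不需要的关键点全部 pop 出去。注意 pop() 会直接修改这个对象本身，如果之后还要用到完整的检测结果，应先复制一份再删除。
删除之后剩下的关键点会从 0 开始重新编号（下文保留原索引 11–16 和 23–32，依次变为新索引 0–5 和 6–15），所以还要同时修改对应的 connections，也就是不能直接传 pose.POSE_CONNECTIONS 了，需要按新编号自己写一个 POSE_CONNECTIONS，
形式如下：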
# Connections between the 16 kept body landmarks, written against the
# renumbered indices: 0-5 = original 11-16 (shoulders, elbows, wrists) and
# 6-15 = original 23-32 (hips, knees, ankles, heels, foot tips), following
# the MediaPipe Pose landmark numbering.
POSE_CONNECTIONS = frozenset([
    # Shoulders and arms.
    (0, 1), (0, 2), (2, 4), (1, 3), (3, 5),
    # Torso; (6, 7) joins the two hips (original 23-24) so the torso is closed.
    (0, 6), (1, 7), (6, 7),
    # Legs.
    (6, 8), (8, 10), (7, 9), (9, 11),
    # Feet.
    (10, 12), (10, 14), (12, 14),
    (11, 13), (11, 15), (13, 15),
])
最终修改的地方如下：
# Added variable: connections between the 16 kept body landmarks, written
# against the renumbered indices 0-5 = original 11-16 (shoulders, elbows,
# wrists) and 6-15 = original 23-32 (hips, knees, ankles, heels, foot tips),
# following the MediaPipe Pose landmark numbering.
POSE_CONNECTIONS = frozenset([
    # Shoulders and arms.
    (0, 1), (0, 2), (2, 4), (1, 3), (3, 5),
    # Torso; (6, 7) joins the two hips (original 23-24) so the torso is closed.
    (0, 6), (1, 7), (6, 7),
    # Legs.
    (6, 8), (8, 10), (7, 9), (9, 11),
    # Feet.
    (10, 12), (10, 14), (12, 14),
    (11, 13), (11, 15), (13, 15),
])
# Here the face keypoints (original indices 0-10) and hand keypoints
# (original indices 17-22) are removed before drawing.
# Original indices of the landmarks that are kept. After selection they are
# renumbered 0..15 in this order, which is the numbering POSE_CONNECTIONS
# is written against.
body_landmark_indices = (11, 12, 13, 14, 15, 16,
                         23, 24, 25, 26, 27, 28, 29, 30, 31, 32)
if self.results.pose_landmarks:
    if draw:
        pose_landmarks = self.results.pose_landmarks
        # Copy the kept landmarks into a new list of the same message type
        # instead of pop()-ing from self.results.pose_landmarks: pop() would
        # mutate the detection result itself, leaving only 16 landmarks for
        # any later code that reads self.results. The local name also no
        # longer shadows draw_landmarks.
        body_landmarks = type(pose_landmarks)()
        for idx in body_landmark_indices:
            body_landmarks.landmark.add().CopyFrom(pose_landmarks.landmark[idx])
        self.mpDraw.draw_landmarks(img, body_landmarks,
                                   POSE_CONNECTIONS)