def draw_predictions(self, task):
    """Render the filtered stdet predictions onto the task's raw frames.

    Only detections accepted by ``self.filter_result`` for the classes
    ``'fight/hit (a person)'`` and ``'fall down'`` are passed on to
    ``self.draw_clip_range``.

    Args:
        task: Object providing ``display_bboxes`` (must support
            ``.cpu().numpy()``), ``frames``, ``clip_vis_length`` and
            ``action_preds``. Its ``frames`` attribute is overwritten with
            the result of ``self.draw_clip_range``.

    Returns:
        The same ``task`` object, with ``frames`` replaced.

    Raises:
        AssertionError: If the visualisation window does not fit inside
            ``task.frames``.
    """
    # Pull the display boxes out of the task as a NumPy array.
    boxes = task.display_bboxes.cpu().numpy()

    # The drawn window is centred on the middle frame of the clip.
    middle = len(task.frames) // 2
    vis_length = task.clip_vis_length
    first_idx = middle - vis_length // 2
    last_idx = middle + (vis_length - 1) // 2
    assert first_idx >= 0 and last_idx < len(task.frames)
    draw_range = [first_idx, last_idx]

    # Restrict what gets drawn to the action classes of interest.
    wanted_classes = ['fight/hit (a person)', 'fall down']
    kept_preds, kept_boxes = self.filter_result(
        task.action_preds, boxes, wanted_classes)

    task.frames = self.draw_clip_range(
        task.frames, kept_preds, kept_boxes, draw_range)
    return task
def filter_result(self, preds, bbox, keep_class_name_list=None):
    """Drop detections whose first predicted class is not in the keep list.

    Only the first prediction of each detection (``pred[0][0]``) is
    compared against ``keep_class_name_list``. Detections whose prediction
    list is empty are always kept.
    NOTE(review): if predictions are not sorted by score upstream, checking
    only the first entry may discard detections that do contain a wanted
    class further down the list -- confirm against the predictor.

    Args:
        preds: ``None`` or a sequence with one entry per detection. Each
            entry is a sized sequence whose items carry the class name at
            index 0 (presumably ``(label, score)`` pairs -- confirm).
        bbox: Boxes aligned with ``preds`` along axis 0.
        keep_class_name_list: Class names to keep. ``None`` (the default)
            means ``['stand']``.

    Returns:
        ``(preds, bbox)`` untouched when ``preds`` is ``None``. Otherwise a
        tuple ``(kept_preds, kept_bbox)`` holding deep copies of the kept
        predictions and the matching boxes; the inputs are never modified.
    """
    if preds is None:
        return preds, bbox
    if keep_class_name_list is None:
        # Sentinel default: avoids a mutable list shared between calls.
        keep_class_name_list = ['stand']

    # Local import keeps this block self-contained; diagnostics go through
    # logging so the per-clip path does not write to stdout unconditionally.
    import logging
    log = logging.getLogger(__name__)
    log.debug('filter before preds: %s', preds)
    log.debug('filter before bbox: %s', bbox)

    kept_preds = []
    dropped_rows = []
    for row, pred in enumerate(preds):
        if len(pred) > 0 and pred[0][0] not in keep_class_name_list:
            dropped_rows.append(row)
        else:
            # Copy so callers can mutate the result without touching preds.
            kept_preds.append(copy.deepcopy(pred))

    if dropped_rows:
        # One deletion for all rows instead of one array copy per row.
        kept_bbox = np.delete(bbox, dropped_rows, 0)
    else:
        kept_bbox = copy.deepcopy(bbox)

    log.debug('filter after preds: %s', kept_preds)
    log.debug('filter after bbox: %s', kept_bbox)
    return kept_preds, kept_bbox
# NOTE: this code is a modification of mmaction2-0.23.0/demo/webcam_demo_spatiotemporal_det.py