Modifying the box-drawing demo code in Facebook's SlowFast

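This post describes changes to the source code: the TaskInfo class is extended to store detection labels, the Detectron2Predictor class is changed to obtain and attach the predicted labels, and several drawing functions are changed to display them. The changes concentrate on filtering the detection boxes by confidence and class, particularly for human instances, and on showing the detector's label next to the action labels.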

An analysis of the box-drawing part of the source code can be found here.

I Modifications

1 Modify the TaskInfo class in /root/autodl-tmp/SlowFast/slowfast/visualization/utils.py (line 342)

class TaskInfo:
    def __init__(self):
        self.frames = None
        self.id = -1
        self.bboxes = None
        self.action_preds = None
        self.detect_preds = None
        self.num_buffer_frames = 0
        self.img_height = -1
        self.img_width = -1
        self.crop_size = -1
        self.clip_vis_size = -1

    def add_frames(self, idx, frames):
        """
        Add the clip and corresponding id.
        Args:
            idx (int): the current index of the clip.
            frames (list[ndarray]): list of images in "BGR" format.
        """
        self.frames = frames
        self.id = idx

    def add_bboxes(self, bboxes):
        """
        Add corresponding bounding boxes.
        """
        self.bboxes = bboxes

    def add_action_preds(self, preds):
        """
        Add the corresponding action predictions.
        """
        self.action_preds = preds

    def add_detect_preds(self, preds):
        """
        Add the corresponding detection predictions (one formatted label per box).
        """
        self.detect_preds = preds
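
A minimal sketch of how the extended class is meant to be used. `MiniTaskInfo` is a hypothetical, stripped-down stand-in for TaskInfo that keeps only the two fields relevant here, and plain lists stand in for tensors. The point it illustrates is that the detector fills in `bboxes` and `detect_preds` together, so the i-th label always describes the i-th box.

```python
class MiniTaskInfo:
    """Hypothetical, stripped-down stand-in for TaskInfo."""

    def __init__(self):
        self.bboxes = None
        self.detect_preds = None

    def add_bboxes(self, bboxes):
        self.bboxes = bboxes

    def add_detect_preds(self, preds):
        self.detect_preds = preds


task = MiniTaskInfo()
# Boxes are [x1, y1, x2, y2]; the values are made up for illustration.
task.add_bboxes([[10, 20, 110, 220], [150, 30, 260, 240]])
task.add_detect_preds(["[0.9312] Normal", "[0.8120] Normal"])

# One label per box, in the same order.
pairs = list(zip(task.bboxes, task.detect_preds))
```

The drawing functions later index both lists with the same `i`, which is why the two must stay aligned.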

2 Modify the Detectron2 call to add the detection labels

predictor.py (line 157)

from slowfast.visualization.video_visualizer import get_class_names


class Detectron2Predictor:
    """
    Wrapper around Detectron2 to return the required predicted bounding boxes
    as a ndarray.
    """

    def __init__(self, cfg, gpu_id=None):
        """
        Args:
            cfg (CfgNode): configs. Details can be found in
                slowfast/config/defaults.py
            gpu_id (Optional[int]): GPU id.
        """

        self.cfg = get_cfg()
        self.cfg.merge_from_file(
            model_zoo.get_config_file(cfg.DEMO.DETECTRON2_CFG)
        )
        self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = cfg.DEMO.DETECTRON2_THRESH
        self.cfg.MODEL.WEIGHTS = cfg.DEMO.DETECTRON2_WEIGHTS
        self.cfg.INPUT.FORMAT = cfg.DEMO.INPUT_FORMAT
        # Load the detection class names (indexed by class id) from the JSON file.
        self.detect_names, _, _ = get_class_names(
            cfg.DEMO.Detect_File_Path, None, None
        )
        if cfg.NUM_GPUS and gpu_id is None:
            gpu_id = torch.cuda.current_device()
        self.cfg.MODEL.DEVICE = (
            "cuda:{}".format(gpu_id) if cfg.NUM_GPUS > 0 else "cpu"
        )

        logger.info("Initialized Detectron2 Object Detection Model.")

        self.predictor = DefaultPredictor(self.cfg)

    def __call__(self, task):
        """
        Return bounding boxes predictions as a tensor.
        Args:
            task (TaskInfo object): task object that contain
                the necessary information for action prediction. (e.g. frames)
        Returns:
            task (TaskInfo object): the same task info object but filled with
                prediction values (a tensor) and the corresponding boxes for
                action detection task.
        """
        # middle_frame = task.frames[len(task.frames) // 2]
        # outputs = self.predictor(middle_frame)
        # # Get only human instances
        # mask = outputs["instances"].pred_classes == 0
        # pred_boxes = outputs["instances"].pred_boxes.tensor[mask]
        # task.add_bboxes(pred_boxes)

        middle_frame = task.frames[len(task.frames) // 2]
        outputs = self.predictor(middle_frame)
        # Keep only class-0 instances (humans in COCO) with a score of at least 0.7.
        instances = outputs["instances"]
        mask = (instances.scores >= 0.7) & (instances.pred_classes == 0)
        # Predicted boxes.
        pred_boxes = instances.pred_boxes.tensor[mask]
        # Prediction confidences.
        scores = instances.scores[mask].tolist()
        # Class ids of the kept instances.
        pred_labels = instances.pred_classes[mask].tolist()
        # Map class ids to class names and format one label per box.
        preds = [
            "[{:.4f}] {}".format(s, self.detect_names[label])
            for s, label in zip(scores, pred_labels)
        ]
        # Attach the detection labels to the task.
        task.add_detect_preds(preds)
        task.add_bboxes(pred_boxes)

        return task
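
The filtering and label formatting done in `__call__` can be tried without Detectron2. The sketch below assumes plain Python lists in place of the `Instances` tensors; `format_detections` and its parameters are hypothetical names used only for illustration, and the 0.7 threshold and class id 0 mirror the values hard-coded above.

```python
def format_detections(scores, class_ids, class_names, keep_class=0, min_score=0.7):
    """Keep detections of one class above a score threshold and format them.

    Returns the indices of the kept detections and one "[score] name"
    label per kept detection, in the same order.
    """
    kept = []
    labels = []
    for idx, (score, class_id) in enumerate(zip(scores, class_ids)):
        if class_id == keep_class and score >= min_score:
            kept.append(idx)
            labels.append("[{:.4f}] {}".format(score, class_names[class_id]))
    return kept, labels


names = ["Normal", "No_wear", "Playing_phone"]
kept, labels = format_detections(
    scores=[0.95, 0.60, 0.88],
    class_ids=[0, 0, 2],
    class_names=names,
)
```

Here only the first detection survives: the second is below the threshold and the third belongs to another class. Because the same mask selects both boxes and labels, the two lists stay the same length, which the drawing code relies on.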

3 Modify the corresponding drawing functions

(1) async_predictor.py (line 276)

def draw_predictions(task, video_vis):
    """
    Draw prediction for the given task.
    Args:
        task (TaskInfo object): task object that contain
            the necessary information for visualization. (e.g. frames, preds)
            All attributes must lie on CPU devices.
        video_vis (VideoVisualizer object): the video visualizer object.
    """
    boxes = task.bboxes
    frames = task.frames
    preds = task.action_preds
    # Detection labels produced by Detectron2Predictor.
    detect_pred = task.detect_preds
    if boxes is not None:
        img_width = task.img_width
        img_height = task.img_height
        if boxes.device != torch.device("cpu"):
            boxes = boxes.cpu()
        boxes = cv2_transform.revert_scaled_boxes(
            task.crop_size, boxes, img_height, img_width
        )

    keyframe_idx = len(frames) // 2 - task.num_buffer_frames
    draw_range = [
        keyframe_idx - task.clip_vis_size,
        keyframe_idx + task.clip_vis_size,
    ]
    buffer = frames[: task.num_buffer_frames]
    frames = frames[task.num_buffer_frames :]
    if boxes is not None:
        if len(boxes) != 0:
            # draw_clip_range now also takes detect_pred.
            frames = video_vis.draw_clip_range(
                frames,
                preds,
                detect_pred,
                boxes,
                keyframe_idx=keyframe_idx,
                draw_range=draw_range,
            )
    else:
        # detect_pred is a required positional argument, so pass it here too.
        frames = video_vis.draw_clip_range(
            frames,
            preds,
            detect_pred,
            keyframe_idx=keyframe_idx,
            draw_range=draw_range,
        )
    del task

    return buffer + frames

(2) video_visualizer.py (line 514)

    # Added the detect_pred parameter.
    def draw_clip_range(
        self,
        frames,
        preds,
        detect_pred,
        bboxes=None,
        text_alpha=0.5,
        ground_truth=False,
        keyframe_idx=None,
        draw_range=None,
        repeat_frame=1,
    ):
        """
        Draw predicted labels or ground truth classes to clip. Draw bounding boxes to clip
        if bboxes is provided. Boxes will gradually fade in and out the clip, centered around
        the clip's central frame, within the provided `draw_range`.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that contains all of the confidence scores
                of the model. For recognition task or for ground_truth labels, input shape can be (num_classes,).
            detect_pred (list[str]): formatted detection labels from Detectron2Predictor, one per bounding box.
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip.
            draw_range (Optional[list[int]]): only draw frames in range [start_idx, end_idx] inclusively in the clip.
                If None, draw on the entire clip.
            repeat_frame (int): repeat each frame in draw_range for `repeat_frame` time for slow-motion effect.
        """
        if draw_range is None:
            draw_range = [0, len(frames) - 1]
        if draw_range is not None:
            draw_range[0] = max(0, draw_range[0])
            left_frames = frames[: draw_range[0]]
            right_frames = frames[draw_range[1] + 1 :]

        draw_frames = frames[draw_range[0] : draw_range[1] + 1]
        if keyframe_idx is None:
            keyframe_idx = len(frames) // 2

        img_ls = (
            list(left_frames)
            # draw_clip now also takes detect_pred.
            + self.draw_clip(
                draw_frames,
                preds,
                detect_pred,
                bboxes=bboxes,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
                keyframe_idx=keyframe_idx - draw_range[0],
                repeat_frame=repeat_frame,
            )
            + list(right_frames)
        )

        return img_ls
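
The index arithmetic shared by `draw_predictions` and `draw_clip_range` is easy to get wrong, so here is a small sketch of it using frame indices instead of images. `split_clip` is a hypothetical helper, not part of SlowFast; it reproduces the keyframe offset after the buffer frames are dropped and the left/draw/right split with an inclusive end index.

```python
def split_clip(num_frames, num_buffer_frames, clip_vis_size):
    """Return the keyframe index and the (left, draw, right) frame indices.

    Indices refer to the clip after the leading buffer frames are removed.
    """
    # Keyframe position relative to the clip without buffer frames.
    keyframe_idx = num_frames // 2 - num_buffer_frames
    # The start is clamped to 0; the end index is inclusive.
    start = max(0, keyframe_idx - clip_vis_size)
    end = keyframe_idx + clip_vis_size
    indices = list(range(num_frames - num_buffer_frames))
    left = indices[:start]
    draw = indices[start : end + 1]
    right = indices[end + 1 :]
    return keyframe_idx, left, draw, right


result = split_clip(num_frames=10, num_buffer_frames=2, clip_vis_size=2)
```

With 10 frames, 2 buffer frames and a visualization half-width of 2, boxes are drawn on frames 1 to 5 of the remaining 8, and frames 0, 6 and 7 pass through untouched.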

(3) video_visualizer.py (line 568)

    # Added the detect_pred parameter.
    def draw_clip(
        self,
        frames,
        preds,
        detect_pred,
        bboxes=None,
        text_alpha=0.5,
        ground_truth=False,
        keyframe_idx=None,
        repeat_frame=1,
    ):
        """
        Draw predicted labels or ground truth classes to clip. Draw bounding boxes to clip
        if bboxes is provided. Boxes will gradually fade in and out the clip, centered around
        the clip's central frame.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that contains all of the confidence scores
                of the model. For recognition task or for ground_truth labels, input shape can be (num_classes,).
            detect_pred (list[str]): formatted detection labels from Detectron2Predictor, one per bounding box.
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip.
            repeat_frame (int): repeat each frame in draw_range for `repeat_frame` time for slow-motion effect.
        """
        assert repeat_frame >= 1, "`repeat_frame` must be a positive integer."

        repeated_seq = range(0, len(frames))
        repeated_seq = list(
            itertools.chain.from_iterable(
                itertools.repeat(x, repeat_frame) for x in repeated_seq
            )
        )

        frames, adjusted = self._adjust_frames_type(frames)
        if keyframe_idx is None:
            half_left = len(repeated_seq) // 2
            half_right = (len(repeated_seq) + 1) // 2
        else:
            mid = int((keyframe_idx / len(frames)) * len(repeated_seq))
            half_left = mid
            half_right = len(repeated_seq) - mid

        alpha_ls = np.concatenate(
            [
                np.linspace(0, 1, num=half_left),
                np.linspace(1, 0, num=half_right),
            ]
        )
        text_alpha = text_alpha
        frames = frames[repeated_seq]
        img_ls = []
        for alpha, frame in zip(alpha_ls, frames):
            # draw_one_frame now also takes detect_pred.
            draw_img = self.draw_one_frame(
                frame,
                preds,
                detect_pred,
                bboxes,
                alpha=alpha,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
            )
            if adjusted:
                draw_img = draw_img.astype("float32") / 255

            img_ls.append(draw_img)

        return img_ls
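
The fade-in/fade-out of the boxes comes from the `alpha_ls` schedule in `draw_clip`. The sketch below reproduces that schedule in pure Python under the assumption `repeat_frame == 1` (so `mid` equals `keyframe_idx`); `linspace` and `fade_alphas` are hypothetical helpers, with `linspace` standing in for `np.linspace`.

```python
def linspace(start, stop, num):
    """Pure-Python stand-in for np.linspace(start, stop, num=num)."""
    if num <= 0:
        return []
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def fade_alphas(num_frames, keyframe_idx=None):
    """Per-frame box opacity: ramp up to the keyframe, then ramp down."""
    if keyframe_idx is None:
        half_left = num_frames // 2
        half_right = (num_frames + 1) // 2
    else:
        half_left = keyframe_idx
        half_right = num_frames - keyframe_idx
    return linspace(0, 1, half_left) + linspace(1, 0, half_right)


alphas = fade_alphas(5, keyframe_idx=2)
```

Each frame in the drawn range gets one alpha value, and the detection label added in this post is drawn on the same box, so it appears and disappears together with the box.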

(4) video_visualizer.py (line 404)

    # Added the detect_pred parameter.
    def draw_one_frame(
        self,
        frame,
        preds,
        detect_pred,
        bboxes=None,
        alpha=0.5,
        text_alpha=0.7,
        ground_truth=False,
    ):
        """
        Draw labels and bounding boxes for one image. By default, predicted labels are drawn in
        the top left corner of the image or corresponding bounding boxes. For ground truth labels
        (setting True for ground_truth flag), labels will be drawn in the bottom left corner.
        Args:
            frame (array-like): a tensor or numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            preds (tensor or list): If ground_truth is False, provide a float tensor of shape (num_boxes, num_classes)
                that contains all of the confidence scores of the model.
                For recognition task, input shape can be (num_classes,). To plot true label (ground_truth is True),
                preds is a list contains int32 of the shape (num_boxes, true_class_ids) or (true_class_ids,).
            detect_pred (list[str]): formatted detection labels from Detectron2Predictor, one per bounding box.
                Only used when bboxes is provided.
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates of the bounding boxes.
            alpha (Optional[float]): transparency level of the bounding boxes.
            text_alpha (Optional[float]): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
        """
        if isinstance(preds, torch.Tensor):
            if preds.ndim == 1:
                preds = preds.unsqueeze(0)
            n_instances = preds.shape[0]
        elif isinstance(preds, list):
            n_instances = len(preds)
        else:
            logger.error("Unsupported type of prediction input.")
            return

        if ground_truth:
            top_scores, top_classes = [None] * n_instances, preds

        elif self.mode == "top-k":
            top_scores, top_classes = torch.topk(preds, k=self.top_k)
            top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
        elif self.mode == "thres":
            top_scores, top_classes = [], []
            for pred in preds:
                mask = pred >= self.thres
                top_scores.append(pred[mask].tolist())
                top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
                top_classes.append(top_class)

        # Create labels top k predicted classes with their scores.
        text_labels = []
        for i in range(n_instances):
            text_labels.append(
                _create_text_labels(
                    top_classes[i],
                    top_scores[i],
                    self.class_names,
                    ground_truth=ground_truth,
                )
            )
        frame_visualizer = ImgVisualizer(frame, meta=None)
        font_size = min(
            max(np.sqrt(frame.shape[0] * frame.shape[1]) // 35, 5), 9
        )
        top_corner = not ground_truth
        if bboxes is not None:
            assert len(preds) == len(
                bboxes
            ), "Encounter {} predictions and {} bounding boxes".format(
                len(preds), len(bboxes)
            )
            for i, box in enumerate(bboxes):
                text = text_labels[i]
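                # Append the detection label after the action labels.
                text.append(detect_pred[i])
                pred_class = top_classes[i]
                # Default face color for the detection label.
                colors1 = [(0.9019607843137225, 0.9607843137254902, 0.788235294117667)]
                # Face colors for the action labels.
                colors2 = [self._get_color(pred) for pred in pred_class]
                # Same order as `text`: action labels first, detection label last.
                colors = colors2 + colors1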

                box_color = "r" if ground_truth else "g"
                line_style = "--" if ground_truth else "-."
                frame_visualizer.draw_box(
                    box,
                    alpha=alpha,
                    edge_color=box_color,
                    line_style=line_style,
                )
                frame_visualizer.draw_multiple_text(
                    text,
                    box,
                    top_corner=top_corner,
                    font_size=font_size,
                    box_facecolors=colors,
                    alpha=text_alpha,
                )
        else:
            text = text_labels[0]
            pred_class = top_classes[0]
            colors = [self._get_color(pred) for pred in pred_class]
            frame_visualizer.draw_multiple_text(
                text,
                torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
                top_corner=top_corner,
                font_size=font_size,
                box_facecolors=colors,
                alpha=text_alpha,
            )

        return frame_visualizer.output.get_image()
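
When the detection label is added to a box's text list, the list of face colors has to grow in the same position, on the assumption that `draw_multiple_text` pairs each text with the color at the same index. A minimal sketch of that pairing follows; `build_labels_and_colors` and `DETECT_COLOR` are hypothetical names, and the color is a rounded version of the one used above.

```python
# Rounded version of the pale green used for the detection label.
DETECT_COLOR = (0.90, 0.96, 0.79)


def build_labels_and_colors(action_labels, action_colors, detect_label,
                            detect_color=DETECT_COLOR):
    """Return text and color lists of equal length and matching order."""
    assert len(action_labels) == len(action_colors)
    # The detection label goes last in both lists so index i stays paired.
    texts = list(action_labels) + [detect_label]
    colors = list(action_colors) + [detect_color]
    return texts, colors


texts, colors = build_labels_and_colors(
    action_labels=["[0.81] smoking", "[0.64] sitting"],
    action_colors=[(1.0, 0.0, 0.0), (0.0, 0.0, 1.0)],
    detect_label="[0.9312] Normal",
)
```

Building both lists in one place makes it harder for the detection label to end up with an action label's color.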

4 Configuration file

TRAIN:
  ENABLE: False
  DATASET: kinetics
  BATCH_SIZE: 32
  EVAL_PERIOD: 5
  CHECKPOINT_PERIOD: 5
  AUTO_RESUME: True
  CHECKPOINT_FILE_PATH: '/root/autodl-tmp/SlowFast/checkpoints/checkpoint_epoch_00196.pyth'
DATA:
  NUM_FRAMES: 32
  SAMPLING_RATE: 2
  TRAIN_JITTER_SCALES: [256, 320]
  TRAIN_CROP_SIZE: 224
  TEST_CROP_SIZE: 256
  INPUT_CHANNEL_NUM: [3, 3]
DETECTION:
  ENABLE: True
  ALIGNED: False
SLOWFAST:
  ALPHA: 8
  BETA_INV: 8
  FUSION_CONV_CHANNEL_RATIO: 2
  FUSION_KERNEL_SZ: 7
RESNET:
  ZERO_INIT_FINAL_BN: True
  WIDTH_PER_GROUP: 64
  NUM_GROUPS: 1
  DEPTH: 50
  TRANS_FUNC: bottleneck_transform
  STRIDE_1X1: False
  NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
  SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]]
  SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]]
NONLOCAL:
  LOCATION: [[[], []], [[], []], [[], []], [[], []]]
  GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
  INSTANTIATION: dot_product
BN:
  USE_PRECISE_STATS: True
  NUM_BATCHES_PRECISE: 200
SOLVER:
  BASE_LR: 0.1
  LR_POLICY: steps_with_relative_lrs
  MAX_EPOCH: 400
  STEPS: [ 0, 100,185, 255, 355]
  LRS: [ 1, 0.1, 0.01, 0.001,0.0001]
  MOMENTUM: 0.9
  WEIGHT_DECAY: 1e-4
  WARMUP_EPOCHS: 40.0
  WARMUP_START_LR: 0.01
  OPTIMIZING_METHOD: sgd
MODEL:
  NUM_CLASSES: 7
  ARCH: slowfast
  MODEL_NAME: SlowFastDualAttention
  LOSS_FUNC: cross_entropy
  DROPOUT_RATE: 0.5
TEST:
  ENABLE: False
  DATASET: kinetics
  BATCH_SIZE: 32
DATA_LOADER:
  NUM_WORKERS: 32
  PIN_MEMORY: True
NUM_GPUS: 1
NUM_SHARDS: 1
RNG_SEED: 42
OUTPUT_DIR: .
DEMO:
  ENABLE: True
  LABEL_FILE_PATH: "/root/autodl-tmp/SlowFast/meta/7trail_labels.json"
  INPUT_VIDEO: "/root/autodl-tmp/SlowFast/detect/1.avi"
  OUTPUT_FILE: "/root/autodl-tmp/SlowFast/detect/out.avi"
  DETECTRON2_CFG: "COCO-owe/config.yaml"
  DETECTRON2_WEIGHTS: "/root/autodl-tmp/detectron2/output_trainsample/model_final.pth"
  Detect_File_Path: "/root/autodl-tmp/SlowFast/meta/10.json"
  DETECTRON2_THRESH: 0.1



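5 Results before and after the modification

Before the modification:
[image]
After the modification:
[image]
Regarding the detection JSON file: DEMO.Detect_File_Path points to a JSON file that maps each class name to its class id: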

{"Normal": 0, "No_wear": 1, "Playing_phone": 2, "Smoking": 3, "Sleeping": 4, "Playing_phone_smoking": 5,
"No_wear_playing_phone": 6,"No_wear_smoking": 7,"No_wear_sleeping": 8,"No_wear_playing_phone_smoking": 9}
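
The predictor indexes `self.detect_names` by class id, so the name-to-id mapping in this file has to be inverted into a list. The sketch below shows that inversion on a shortened copy of the mapping; `load_class_names` is a hypothetical helper written for illustration and is not the actual `get_class_names` from SlowFast.

```python
import json


def load_class_names(json_text):
    """Turn a {"name": id} JSON mapping into a list indexed by class id."""
    class2idx = json.loads(json_text)
    names = [None] * (max(class2idx.values()) + 1)
    for name, idx in class2idx.items():
        names[idx] = name
    return names


# Shortened copy of the label file, for illustration only.
names = load_class_names('{"Normal": 0, "No_wear": 1, "Playing_phone": 2}')
```

The ids in the file should match the class ids the Detectron2 model was trained with; otherwise the labels drawn on the boxes will be wrong.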