【ECCV2024-CVPPA-MOT 3rd 解决方案】

零、快速链接

codalab
dataset description
workshop
ByteTrack
my GitHub

一、任务介绍

使用官方提供的 RGB-D 数据集以及对应的 mask2former 实例分割结果、训练集弱标签实现对甜椒的跟踪。

| 示例 1 | 示例 2 | 示例 3 |
| --- | --- | --- |
| (图片) | (图片) | (图片) |

二、解决方案:ByteTrack + 后处理

刚开始直接用 bytetrack_x_mot17.pth.tar 预训练权重(配置文件为 yolox_x_mix_det.py)进行推理,先只推理了一个样例看效果,结果完全没有检测到目标,所以需要在本数据集上训练。训练流程为:先将数据转换为 MOT20 格式,再进行预处理、修改配置文件,并加载预训练权重进行训练。

1. 配置环境

## 创建环境
conda create -n bytetrack python=3.7
## torch 版本
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
## 克隆仓库
git clone https://github.com/ifzhang/ByteTrack.git
cd ByteTrack
## 其他包
pip3 install -r requirements.txt # 注释掉里面的torch和torchvision
python3 setup.py develop
pip3 install cython; pip3 install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
pip3 install cython_bbox

2. 处理数据集

具体参考my GitHub。大致流程是:

  • CVPPA 数据集处理成 MOT20 数据格式,需要使用 mask2former 生成 det.txt,weak_label 生成 gt.txt。
    MOT20 数据格式参考链接
  • ByteTrack 原本的 MOT20 训练数据是由 MOT20 与 CrowdHuman 数据集 mix 得到的,而本任务只有跟踪数据,如果不进行预处理,两者格式不一致。因此先将数据转为 COCO 格式并生成对应的 json 文件,然后再进行 mix。

3. 训练

按照 mix mot20 配置进行更改得到 CVPPA 的训练配置文件:

# encoding: utf-8
import os
import random
import torch
import torch.nn as nn
import torch.distributed as dist

from yolox.exp import Exp as MyExp
from yolox.data import get_yolox_datadir

class Exp(MyExp):
    """ByteTrack/YOLOX-X experiment config for the CVPPA sweet-pepper tracking data.

    Adapted from the mix-MOT20 config: a single object class, YOLOX-X depth/width,
    training on the ``mix_mot20_cvppa`` dataset and evaluating with ``COCOEvaluator``
    on ``MOT20_CVPPA``. Attributes used below but not assigned here (``degrees``,
    ``translate``, ``scale``, ``shear``, ``perspective``, ``enable_mixup``, ``seed``,
    ``data_num_workers``, ...) are inherited from the base ``Exp``.
    """

    def __init__(self):
        # Base constructor first, so the assignments below override its defaults.
        super(Exp, self).__init__()
        self.num_classes = 1
        # YOLOX-X depth / width multipliers.
        self.depth = 1.33
        self.width = 1.25
        # Experiment name = this file's name without its extension.
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
        self.train_ann = "train.json"
        self.val_ann = "train.json"    # evaluation annotations; "train.json" evaluates on the training split
        # (height, width), as used for the NCHW dummy input in the tracking script.
        self.input_size = (800, 1440)
        self.test_size = (800, 1440)
        # Multi-scale training range; presumably multiplied by 32 by the YOLOX trainer — confirm.
        self.random_size = (18, 32)
        self.max_epoch = 80
        self.print_interval = 20
        self.eval_interval = 1
        self.test_conf = 0.001
        self.nmsthre = 0.7
        # Number of final epochs trained without mosaic/mixup augmentation (base-class semantics).
        self.no_aug_epochs = 10
        # Learning rate per image; presumably scaled by the total batch size in the base Exp.
        self.basic_lr_per_img = 0.001 / 64.0
        self.warmup_epochs = 1

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        """Build the infinite, mosaic-augmented training loader.

        Args:
            batch_size: total batch size across all processes.
            is_distributed: split ``batch_size`` over ``dist.get_world_size()`` when True.
            no_aug: disable mosaic (used for the last ``no_aug_epochs``).
        """
        from yolox.data import (
            MOTDataset,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
        )

        # COCO-style json under <yolox datadir>/mix_mot20_cvppa (see section 2 of the write-up).
        dataset = MOTDataset(
            data_dir=os.path.join(get_yolox_datadir(), "mix_mot20_cvppa"),
            json_file=self.train_ann,
            name='',
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=500,
            ),
        )

        # Wrap with mosaic/affine augmentation; the augmentation params come from the base Exp.
        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=1000,
            ),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
            enable_mixup=self.enable_mixup,
        )

        self.dataset = dataset

        # Per-process batch size under distributed training.
        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        sampler = InfiniteSampler(
            len(self.dataset), seed=self.seed if self.seed else 0
        )

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader

    def get_eval_loader(self, batch_size, is_distributed, testdev=False):
        """Build the sequential evaluation loader over ``MOT20_CVPPA``.

        ``testdev`` is accepted for interface compatibility but not used here.
        """
        from yolox.data import MOTDataset, ValTransform

        valdataset = MOTDataset(
            data_dir=os.path.join(get_yolox_datadir(), "MOT20_CVPPA"),
            json_file=self.val_ann,
            img_size=self.test_size,
            name='train',   # presumably the image sub-folder of MOT20_CVPPA; 'train' = evaluate on the training split
            preproc=ValTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(
                valdataset, shuffle=False
            )
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)

        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False):
        """Return a ``COCOEvaluator`` over ``get_eval_loader`` using the test-time thresholds."""
        from yolox.evaluators import COCOEvaluator

        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
        evaluator = COCOEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
            testdev=testdev,
        )
        return evaluator

运行脚本:

CUDA_VISIBLE_DEVICES=0,1,2,3 python3 tools/train.py -f exps/example/mot/yolox_x_mix_det_cvppa.py -d 4 -b 16 --fp16 -o \
-c /data/ChaiJM/Competition/CVPPA-DMOT/Code/ByteTrack-main/YOLOX_outputs/yolox_x_mix_det_cvppa/latest_ckpt.pth.tar \
-expn cvppa_resume  -e 4

4. 推理

首先将 demo_track 修改为推理测试集所有数据:

import argparse
import os
import os.path as osp
import time
import cv2
import torch

from loguru import logger

from yolox.data.data_augment import preproc
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png", ".tiff"]

def make_parser():
    """Build the command-line parser for batch tracking on the CVPPA test set.

    Defaults point at the author's local dataset, exp file and checkpoint; the
    tracking options are passed straight to ``BYTETracker``.
    """
    parser = argparse.ArgumentParser("ByteTrack Demo!")
    add = parser.add_argument

    # --- input / output ---
    add("--demo", default="image", help="demo type, eg. image, video and webcam")
    add("-expn", "--experiment-name", type=str, default='cvppa-test2')
    add("-n", "--name", type=str, default='yolox-x', help="model name")
    add(
        "--path",
        default="/data/ChaiJM/Competition/CVPPA-DMOT/Dataset/MOT_CVPPA24_DATA/test/rgb",
        help="path to images or video sequences",
    )
    add("--camid", type=int, default=0, help="webcam demo camera id")
    add(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )

    # --- experiment / model ---
    add(
        "-f", "--exp_file",
        default='/data/ChaiJM/Competition/CVPPA-DMOT/Code/ByteTrack-main/exps/example/mot/yolox_x_mix_det_cvppa.py',
        type=str,
        help="pls input your expriment description file",
    )
    add(
        "-c", "--ckpt",
        default='/data/ChaiJM/Competition/CVPPA-DMOT/Code/ByteTrack-main/YOLOX_outputs/yolox_x_mix_det_cvppa/best_ckpt.pth.tar',
        type=str,
        help="ckpt for eval",
    )
    add("--device", default="gpu", type=str,
        help="device to run our model, can either be cpu or gpu")
    add("--conf", default=0.001, type=float, help="test conf")
    add("--nms", default=0.7, type=float, help="test nms threshold")
    add("--tsize", default=1440, type=int, help="test img size")
    add("--fps", default=30, type=int, help="frame rate (fps)")
    # Three on/off switches that share the same shape.
    for flag, text in (
        ("fp16", "Adopting mix precision evaluating."),
        ("fuse", "Fuse conv and bn for testing."),
        ("trt", "Using TensorRT model for testing."),
    ):
        add("--" + flag, dest=flag, default=False, action="store_true", help=text)

    # --- tracking ---
    add("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    add("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
    add("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    add(
        "--aspect_ratio_thresh", type=float, default=1.6,
        help="threshold for filtering out boxes of which aspect ratio are above the given value.",
    )
    add('--min_box_area', type=float, default=10, help='filter out tiny boxes')
    add("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser

def get_image_list(path):
    """Recursively collect image files below ``path``.

    A file counts as an image when its (case-sensitive) extension is listed in
    ``IMAGE_EXT``. Results are full paths, in ``os.walk`` order; a missing
    ``path`` yields an empty list.
    """
    return [
        osp.join(root, name)
        for root, _subdirs, names in os.walk(path)
        for name in names
        if osp.splitext(name)[1] in IMAGE_EXT
    ]

def write_results(filename, results):
    """Write tracks as ``frame,id,x1,y1,w,h,score`` lines to ``filename`` (overwritten).

    Args:
        filename: output text file.
        results: iterable of ``(frame_id, tlwhs, track_ids, scores)`` tuples. Entries
            with a negative track id are skipped; box values are rounded to one
            decimal and scores to two.
    """
    template = '{frame},{id},{x1},{y1},{w},{h},{s}\n'

    def format_row(frame, box, tid, score):
        x1, y1, w, h = box
        return template.format(
            frame=frame, id=tid,
            x1=round(x1, 1), y1=round(y1, 1), w=round(w, 1), h=round(h, 1),
            s=round(score, 2),
        )

    with open(filename, 'w') as out:
        for frame, tlwhs, track_ids, scores in results:
            out.writelines(
                format_row(frame, box, tid, score)
                for box, tid, score in zip(tlwhs, track_ids, scores)
                if not (tid < 0)
            )
    logger.info('save results to {}'.format(filename))

class Predictor(object):
    """Wrap a YOLOX detector: preprocess one frame, run the model, apply NMS.

    Args:
        model: detector module (moved to ``device`` by the caller).
        exp: experiment config; ``num_classes``, ``test_conf``, ``nmsthre`` and
            ``test_size`` are copied from it.
        trt_file: optional path to a serialized ``torch2trt`` module; when given it
            replaces ``model`` for inference.
        decoder: optional callable applied to the raw model outputs before NMS
            (set when TensorRT is used and the head's in-graph decoding is disabled).
        device: device the input tensor is moved to.
        fp16: cast the input tensor to half precision.
    """

    def __init__(
        self,
        model,
        exp,
        trt_file=None,
        decoder=None,
        device=torch.device("cpu"),
        fp16=False
    ):
        self.model = model
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            # One forward pass of the PyTorch model on a dummy (1, 3, H, W) input
            # before it is swapped out for the TensorRT module.
            x = torch.ones((1, 3, exp.test_size[0], exp.test_size[1]), device=device)
            self.model(x)
            self.model = model_trt
        # Normalization mean/std; same values as the Train/ValTransform in the exp config.
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img, timer):
        """Run detection on one image.

        Args:
            img: image file path, or an already-decoded image array (e.g. a video frame
                from ``cv2.VideoCapture.read``).
            timer: timer whose ``tic()`` is called here; the caller calls ``toc()``.

        Returns:
            ``(outputs, img_info)``: ``outputs`` is the ``postprocess`` result, indexed
            per batch image (callers treat ``outputs[0] is None`` as "no detections");
            ``img_info`` holds ``file_name``, original ``height``/``width``, ``raw_img``
            and the resize ``ratio``.

        Raises:
            FileNotFoundError: if ``img`` is a path that OpenCV cannot read.
        """
        img_info = {"id": 0}
        if isinstance(img, str):
            img_path = img
            img_info["file_name"] = osp.basename(img_path)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports failure by returning None instead of raising.
                raise FileNotFoundError(f"cannot read image: {img_path}")
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
        if self.fp16:
            img = img.half()  # to FP16

        with torch.no_grad():
            timer.tic()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre, self.nmsthre
            )
        return outputs, img_info

def image_demo(predictor, vis_folder, current_time, args):
    """Track objects through one image folder (or a single image) given by ``args.path``.

    Images are processed in sorted path order with a fresh ``BYTETracker``. A track is
    kept when its box area exceeds ``args.min_box_area`` and its width/height ratio does
    not exceed ``args.aspect_ratio_thresh``. With ``args.save_result``, annotated frames go
    to ``<vis_folder>/<timestamp>/`` and tracks to ``<vis_folder>/<timestamp>.txt`` as
    ``frame,id,x,y,w,h,score`` lines with 1-based frame numbers.
    """
    if osp.isdir(args.path):
        files = get_image_list(args.path)
    else:
        files = [args.path]
    files.sort()
    tracker = BYTETracker(args, frame_rate=args.fps)
    timer = Timer()
    results = []

    # Computed once, before the loop, so it is defined even when ``files`` is empty
    # (previously the final save raised NameError in that case).
    timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    save_folder = None
    if args.save_result:
        save_folder = osp.join(vis_folder, timestamp)
        os.makedirs(save_folder, exist_ok=True)

    for frame_id, img_path in enumerate(files, 1):
        outputs, img_info = predictor.inference(img_path, timer)
        if outputs[0] is not None:
            # predictor.test_size is exp.test_size (copied in Predictor.__init__); using it
            # removes the dependency on the module-level ``exp`` defined under __main__.
            online_targets = tracker.update(
                outputs[0], [img_info['height'], img_info['width']], predictor.test_size
            )
            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                # Drop boxes that are wider than aspect_ratio_thresh * height.
                vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)
                    results.append(
                        f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f}\n"
                    )
            timer.toc()
            online_im = plot_tracking(
                img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id, fps=1. / timer.average_time
            )
        else:
            timer.toc()
            online_im = img_info['raw_img']

        if save_folder is not None:
            cv2.imwrite(osp.join(save_folder, osp.basename(img_path)), online_im)

        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))

        # Esc / q / Q stops early (only effective when a HighGUI window is open).
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break

    if args.save_result:
        res_file = osp.join(vis_folder, f"{timestamp}.txt")
        with open(res_file, 'w') as f:
            f.writelines(results)
        logger.info(f"save results to {res_file}")

def imageflow_demo(predictor, vis_folder, current_time, args):
    """Track objects in a video file (``args.demo == "video"``) or a webcam stream.

    With ``args.save_result``, annotated frames are written to an mp4 under
    ``<vis_folder>/<timestamp>/`` and tracks to ``<vis_folder>/<timestamp>.txt`` as
    ``frame,id,x,y,w,h,score`` lines. Frame numbers are 1-based, matching the on-image
    label and ``image_demo`` (they were previously written 0-based).
    The capture device and video writer are always released, even on error.
    """
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    save_folder = osp.join(vis_folder, timestamp)
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = osp.join(save_folder, args.path.split("/")[-1])
    else:
        save_path = osp.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    results = []
    try:
        while True:
            if frame_id % 20 == 0:
                logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
            ret_val, frame = cap.read()
            if not ret_val:
                break
            outputs, img_info = predictor.inference(frame, timer)
            if outputs[0] is not None:
                # predictor.test_size is exp.test_size; avoids the module-level ``exp`` global.
                online_targets = tracker.update(
                    outputs[0], [img_info['height'], img_info['width']], predictor.test_size
                )
                online_tlwhs = []
                online_ids = []
                online_scores = []
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                    if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                        online_tlwhs.append(tlwh)
                        online_ids.append(tid)
                        online_scores.append(t.score)
                        results.append(
                            f"{frame_id + 1},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f}\n"
                        )
                timer.toc()
                online_im = plot_tracking(
                    img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1, fps=1. / timer.average_time
                )
            else:
                timer.toc()
                online_im = img_info['raw_img']
            if args.save_result:
                vid_writer.write(online_im)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
            frame_id += 1
    finally:
        # Previously neither handle was released, which can leave the mp4 unfinalized.
        cap.release()
        vid_writer.release()

    if args.save_result:
        res_file = osp.join(vis_folder, f"{timestamp}.txt")
        with open(res_file, 'w') as f:
            f.writelines(results)
        logger.info(f"save results to {res_file}")

def process_image_sequence(predictor, vis_folder, current_time, args, sequence_path):
    """Run detection + ByteTrack over one image-sequence directory.

    Images under ``sequence_path`` are processed in sorted path order with a fresh
    ``BYTETracker``. A track is kept when its box area exceeds ``args.min_box_area`` and
    its width/height ratio does not exceed ``args.aspect_ratio_thresh``.

    With ``args.save_result``, annotated frames go to ``<vis_folder>/<sequence name>/`` and
    tracks to ``<vis_folder>/<sequence name>.txt``, one ``image_id,track_id,x,y,w,h,score``
    line per track, where ``image_id`` is the image file name up to its first dot.
    ``current_time`` is not used here (kept for signature parity with ``image_demo``).
    """
    files = get_image_list(sequence_path)
    files.sort()
    tracker = BYTETracker(args, frame_rate=args.fps)
    timer = Timer()
    results = []

    sequence_name = osp.basename(sequence_path)
    # Create the visualization folder once instead of on every frame.
    save_folder = None
    if args.save_result:
        save_folder = osp.join(vis_folder, sequence_name)
        os.makedirs(save_folder, exist_ok=True)

    for frame_id, img_path in enumerate(files, 1):
        outputs, img_info = predictor.inference(img_path, timer)
        if outputs[0] is not None:
            # predictor.test_size is exp.test_size (copied in Predictor.__init__); using it
            # removes the dependency on the module-level ``exp`` defined under __main__.
            online_targets = tracker.update(
                outputs[0], [img_info['height'], img_info['width']], predictor.test_size
            )
            image_id = osp.basename(img_path).split('.')[0]
            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                # Drop boxes that are wider than aspect_ratio_thresh * height.
                vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)
                    results.append(
                        f"{image_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f}\n"
                    )
            timer.toc()
            online_im = plot_tracking(
                img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id, fps=1. / timer.average_time
            )
        else:
            timer.toc()
            online_im = img_info['raw_img']

        if save_folder is not None:
            cv2.imwrite(osp.join(save_folder, osp.basename(img_path)), online_im)

        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))

        # Esc / q / Q stops early (only effective when a HighGUI window is open).
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break

    if args.save_result:
        res_file = osp.join(vis_folder, f"{sequence_name}.txt")
        with open(res_file, 'w') as f:
            f.writelines(results)
        logger.info(f"save results to {res_file}")

def main(exp, args):
    """Track every image sequence under ``args.path`` and optionally save the results.

    Each sub-directory of ``args.path`` is tracked independently with
    ``process_image_sequence``. If ``args.path`` has no sub-directories, it is treated
    as a single sequence. The CLI overrides ``--conf``, ``--nms`` and ``--tsize`` are
    written into ``exp`` before the model is built.
    """
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    output_dir = osp.join(exp.output_dir, args.experiment_name)
    os.makedirs(output_dir, exist_ok=True)

    # Always defined: it is passed to process_image_sequence even without
    # --save_result (previously this raised UnboundLocalError in that case).
    vis_folder = osp.join(output_dir, "track_vis")
    if args.save_result:
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"
    args.device = torch.device("cuda" if args.device == "gpu" else "cpu")

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model().to(args.device)
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = osp.join(output_dir, "best_ckpt.pth.tar")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.fp16:
        model = model.half()

    if args.trt:
        trt_file = osp.join(output_dir, "model_trt.pth")
        assert osp.exists(trt_file), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(model, exp, trt_file, decoder, args.device, args.fp16)
    current_time = time.localtime()

    # Every sub-directory of args.path is one image sequence; fall back to treating
    # args.path itself as a sequence when it has no sub-directories.
    sequence_paths = [
        osp.join(args.path, sequence)
        for sequence in sorted(os.listdir(args.path))
        if osp.isdir(osp.join(args.path, sequence))
    ]
    if not sequence_paths:
        sequence_paths = [args.path]
    for sequence_path in sequence_paths:
        logger.info(f"Processing sequence: {sequence_path}")
        process_image_sequence(predictor, vis_folder, current_time, args, sequence_path)

if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)

运行脚本:

CUDA_VISIBLE_DEVICES=7 python tools/demo_track_cvppa.py --save_result -expn cvppa-test5 \
-c pth

5. 后处理

官方有提到前景与背景,可以通过深度图确定前景或者背景的范围,同时提供了 mask2former 的分割结果,所以可以将处于背景部分的预测结果去除,提高准确率。具体思路如下:
对于每一个图片序列会生成一个如下结构的txt

1600938533420040,1,699.95,187.54,18.34,66.91,0.82
1600938533486674,1,692.61,183.40,20.66,74.16,0.90
1600938533553308,1,683.85,181.84,23.04,79.69,0.92
1600938533553308,2,698.09,235.79,16.88,81.49,0.92
1600938533619944,1,676.24,176.31,25.18,81.87,0.93
1600938533619944,2,692.14,229.69,18.18,85.75,0.92

其中,第一列是图片 id(即图片文件名去掉扩展名),第二列是目标(轨迹)id,之后依次是 bbox 左上角的 x、y 以及宽 width、高 height,最后一列是置信度。

同时对一个图片序列的每一张图,存在一个 mask2former 的 pkl 结果以及对应的深度图,通过深度图可以大致划分前景背景,最终的结果只需要前景的目标,所以可以使用深度图进行过滤。

实现的基本思路是:对于一个图片序列中每张图的每个预测目标,与对应 pkl 文件中的每个实例计算 bbox 的 IoU。如果 IoU 大于阈值(文中最初取 0.8,最终代码中为 MAX_IoU_Thr = 0.5),则认为是同一个目标。然后根据 pkl 中该实例的二值掩码,统计掩码内深度值在 (0, 1200] 范围内(即有效的前景深度,深度为 0 视为缺失)的像素占比。如果该占比不超过 0.5,则认为该目标属于背景并将其过滤,即删除 txt 中对应的检测结果。

而后考虑到 mask2former 的分割结果可能更准确,匹配成功时直接用 mask2former 实例的 bbox 替换预测框。同时,如果一个 bbox 与 mask2former 所有实例的 IoU 都小于阈值(文中取 0.4,最终代码中为 MIN_IoU_Thr = 0.3),则认为是误检并删除。最终的后处理代码如下:

import os
import pickle
import cv2
import numpy as np
import mmcv

# Input: per-scene tracking results (<scene>.txt) produced by the tracking script.
TXT_ROOT = '/data/ChaiJM/Competition/CVPPA-DMOT/Code/ByteTrack-main/results/4ep'
# Per-scene depth maps (<scene>/<image_id>.tiff); the list of scenes is also taken from here.
DEPTH_ROOT = '/data/ChaiJM/Competition/CVPPA-DMOT/Dataset/MOT_CVPPA24_DATA/test/depth'
# Per-scene mask2former instance results (<scene>/<image_id>.pkl).
PKL_ROOT = '/data/ChaiJM/Competition/CVPPA-DMOT/Dataset/MOT_CVPPA24_DATA/test/mask2former_output'
# Mask pixels with 0 < depth <= DEPTH_Thr count as valid (foreground) depth.
DEPTH_Thr = 1200
# A matched instance is foreground only if more than this fraction of its mask pixels is valid.
Valid_Rate = 0.5
# Output directory. NOTE: the suffix is hard-coded, not derived from the constants;
# update it by hand when the thresholds change.
SAVE_ROOT = TXT_ROOT + '_filter_mask2former_1200_0.5_0.3'
# NOTE: runs at import time.
os.makedirs(SAVE_ROOT, exist_ok=True)
# IoU above which a mask2former instance is treated as the same object as a detection.
MAX_IoU_Thr = 0.5
# Detections whose best IoU with every instance is below this are dropped as false positives.
MIN_IoU_Thr = 0.3

def load_txt(file_path):
    """Load a tracking-result txt file, one detection per line.

    Each non-empty line is ``image_id,target_id,x,y,w,h,confidence``. ``image_id`` and
    ``target_id`` are kept as strings (``image_id`` is used to build the .pkl/.tiff
    file names); the box fields and confidence are parsed as floats. Blank or
    whitespace-only lines are skipped instead of raising ``IndexError``.

    Returns:
        list of dicts with keys image_id, target_id, x, y, w, h, confidence.
    """
    detections = []
    with open(file_path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            parts = line.split(',')
            detections.append({
                'image_id': parts[0],
                'target_id': parts[1],
                'x': float(parts[2]),
                'y': float(parts[3]),
                'w': float(parts[4]),
                'h': float(parts[5]),
                'confidence': float(parts[6]),
            })
    return detections

def save_txt(detections, file_path):
    """Write detections in the same comma-separated layout that ``load_txt`` reads.

    Each value is written with its default string formatting; the file is overwritten.
    """
    columns = ('image_id', 'target_id', 'x', 'y', 'w', 'h', 'confidence')
    with open(file_path, 'w') as out:
        out.writelines(
            ','.join(f"{det[col]}" for col in columns) + '\n'
            for det in detections
        )

def load_pkl(pkl_file_path):
    """Unpickle and return the object stored in ``pkl_file_path``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on trusted files
    (here, presumably the provided mask2former outputs).
    """
    with open(pkl_file_path, 'rb') as handle:
        return pickle.load(handle)

def load_depth_map(depth_file_path):
    """Read a depth image unchanged (original bit depth and channels).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread`` returns
            ``None`` instead of raising, which would otherwise surface later as an
            unrelated ``TypeError`` when the map is indexed.
    """
    depth_map = cv2.imread(depth_file_path, cv2.IMREAD_UNCHANGED)
    if depth_map is None:
        raise FileNotFoundError(f"could not read depth map: {depth_file_path}")
    return depth_map

def calculate_iou(boxA, boxB):
    """Return the IoU of two axis-aligned boxes given as dicts with keys x, y, w, h.

    ``(x, y)`` is the top-left corner and ``(w, h)`` the size, in continuous
    coordinates (the tracker writes fractional values). The intersection and the
    areas use the same convention. The original added ``+1`` only to the
    intersection, which inflated IoU (identical 10x10 boxes scored 1.53).
    Returns 0.0 when the union is empty (degenerate boxes).
    """
    ax2 = boxA['x'] + boxA['w']
    ay2 = boxA['y'] + boxA['h']
    bx2 = boxB['x'] + boxB['w']
    by2 = boxB['y'] + boxB['h']

    inter_w = max(0.0, min(ax2, bx2) - max(boxA['x'], boxB['x']))
    inter_h = max(0.0, min(ay2, by2) - max(boxA['y'], boxB['y']))
    inter_area = inter_w * inter_h

    union = boxA['w'] * boxA['h'] + boxB['w'] * boxB['h'] - inter_area
    if union <= 0:
        return 0.0
    return inter_area / float(union)

def is_background(mask, depth_map, depth_thr=None, valid_rate=None):
    """Decide from depth whether a mask2former instance lies in the background.

    A mask pixel has valid (foreground) depth when ``0 < depth <= depth_thr``; a depth
    of 0 is treated as missing. The instance is foreground only when the fraction of
    valid pixels among the mask pixels is strictly greater than ``valid_rate``.

    Args:
        mask: instance mask indexed like ``depth_map``; pixels > 0 belong to the instance.
        depth_map: depth image (units as stored in the .tiff files — presumably mm, confirm).
        depth_thr: max foreground depth; defaults to the module-level ``DEPTH_Thr``.
        valid_rate: required valid-pixel fraction; defaults to the module-level ``Valid_Rate``.

    Returns:
        True for background, False for foreground. An empty mask has no valid pixels
        and is treated as background (previously it raised ZeroDivisionError).
    """
    if depth_thr is None:
        depth_thr = DEPTH_Thr
    if valid_rate is None:
        valid_rate = Valid_Rate

    indices = np.where(mask > 0)
    total = len(indices[0])
    if total == 0:
        return True
    depths = depth_map[indices]
    valid = np.count_nonzero(np.logical_and(depths <= depth_thr, depths > 0))
    return not (valid / total > valid_rate)

def process_one_scene(scene_id):
    """Filter the tracking results of one scene (image sequence) and save them.

    Reads ``TXT_ROOT/<scene_id>.txt``. For every detection, its box is compared with
    each mask2former instance of the same frame (``PKL_ROOT/<scene_id>/<image_id>.pkl``):

    * every instance with IoU > ``MAX_IoU_Thr`` overwrites the detection's box with the
      instance bbox; if that instance's mask is background according to the depth map
      (``DEPTH_ROOT/<scene_id>/<image_id>.tiff``), the detection is dropped;
    * detections whose best IoU is below ``MIN_IoU_Thr`` are dropped as false positives.

    Kept detections are written to ``SAVE_ROOT/<scene_id>.txt``.
    """
    scene_id = str(scene_id)
    txt_path = os.path.join(TXT_ROOT, scene_id + '.txt')
    new_txt_path = os.path.join(SAVE_ROOT, scene_id + '.txt')
    detections = load_txt(txt_path)

    # The tracker writes detections frame by frame, so consecutive detections usually
    # share an image_id. Caching the last loaded frame avoids re-reading the .pkl and
    # .tiff once per detection; results are identical either way.
    cached_image_id = None
    pkl_data = None
    depth_map = None

    filtered_detections = []
    for det in detections:
        image_id = det['image_id']
        if image_id != cached_image_id:
            pkl_data = load_pkl(os.path.join(PKL_ROOT, scene_id, image_id + '.pkl'))
            depth_map = load_depth_map(os.path.join(DEPTH_ROOT, scene_id, image_id + '.tiff'))
            cached_image_id = image_id

        # Copy of the tracker box: IoU is always measured against the original
        # prediction, even after det's box is overwritten below.
        boxA = {'x': det['x'], 'y': det['y'], 'w': det['w'], 'h': det['h']}
        keep = True
        max_iou = 0.0
        for instance_id, instance_data in pkl_data.items():
            # Assumes instance bbox is [x, y, w, h] — TODO confirm against the mask2former output.
            bbox = instance_data['bbox']
            boxB = {'x': bbox[0], 'y': bbox[1], 'w': bbox[2], 'h': bbox[3]}
            iou = calculate_iou(boxA, boxB)
            max_iou = max(max_iou, iou)

            if iou > MAX_IoU_Thr:
                # Trust the segmentation box for a matched object.
                det['x'], det['y'], det['w'], det['h'] = bbox[0], bbox[1], bbox[2], bbox[3]
                if is_background(instance_data['instance_mask'], depth_map):
                    keep = False
                    break

        if max_iou < MIN_IoU_Thr:
            keep = False

        if keep:
            filtered_detections.append(det)

    save_txt(filtered_detections, new_txt_path)

if __name__ == "__main__":
    # Each entry of DEPTH_ROOT is treated as a scene id; scenes are filtered with
    # 16 parallel workers via mmcv's progress-tracking helper.
    scenes = os.listdir(DEPTH_ROOT)
    mmcv.track_parallel_progress(process_one_scene, scenes, 16)

另外还注意到,预测框的坐标存在负值,可以将负坐标裁剪为 0。但裁剪的同时也要相应调整宽高(例如 x < 0 时令 w = w + x、x = 0,y 与 h 同理),否则框的右/下边界会被整体右移/下移。起初没有注意到需要调整宽高,结果反而变差;调整宽高之后是否会变好,也还不确定。

三、经验教训

刚开始一直在想办法提高检测准确率,还奇怪为什么模型越调越差,最后才发现比赛的评价指标是 HOTA(同时衡量检测质量与关联/ID 保持),而不只是检测准确率。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值