文章目录
- 数据集描述
- 数据集下载
- 生成跟踪结果
- 精度评定
- 结果展示
数据集描述
DanceTrack 是一个多目标跟踪（MOT）基准数据集，其特点是场景中各目标外观高度相似（统一），而运动模式多样且复杂。
DanceTrack 提供边界框和身份（ID）标注。它共包含 100 个视频：40 个用于训练（标注公开），25 个用于验证（标注公开），35 个用于测试（标注不公开）。

数据集下载
https://github.com/DanceTrack/DanceTrack
生成跟踪结果
以我的代码为例
from infer_picodet import ObjectDetector
import sort
import os
import glob
import argparse
import tqdm
import cv2
import re
from boxmot import DeepOCSORT, OCSORT
def parse_arguments():
    """Parse the command-line options for the detection + tracking script.

    Returns:
        argparse.Namespace with the attributes ``det_model_dir``, ``image``,
        ``device`` and ``use_trt``, parsed from ``sys.argv``.
    """
    import argparse
    from ast import literal_eval

    cli = argparse.ArgumentParser()
    # (flag, keyword options) pairs, registered in this exact order.
    option_specs = (
        ("--det_model_dir", dict(
            default="/training/SAMScore/PP_PicoDet_V2_S_Pedestrian_320x320_infer",
            help="Path of PaddleDetection model directory")),
        ("--image", dict(
            default="/training/datasets/orchard_imgs/1722408726.427306.jpg",
            help="Path of test image file.")),
        ("--device", dict(
            type=str,
            default='gpu',
            help="Type of inference device, support 'kunlunxin', 'cpu' or 'gpu'.")),
        # literal_eval turns the command-line text "True"/"False" into a real bool.
        ("--use_trt", dict(
            type=literal_eval,
            default=False,
            help="Wether to use tensorrt.")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def mkidr(path):
    """Create directory *path*, including any missing parent directories.

    Does nothing if the directory already exists. The function name is kept
    as-is because existing callers use it.
    """
    # exist_ok=True removes the check-then-create race of
    # os.path.exists() + os.mkdir(); makedirs also creates intermediate dirs,
    # which os.mkdir cannot do.
    os.makedirs(path, exist_ok=True)
# 定义一个函数,从文件名中提取数字
def extract_number(filename):
return int(re.search(r'\d+', filename).group())
def write_results_no_score(filename, results):
    """Write tracking results to *filename* as comma-separated text rows.

    Each output row is ``frame,id,x,y,w,h,-1,-1,-1,-1``, i.e. the confidence
    and the three trailing columns are filled with -1. Box values are rounded
    to one decimal place. The file is overwritten.

    Args:
        filename: path of the text file to create.
        results: iterable of ``(frame_id, tlwhs, track_ids)`` tuples, where
            ``tlwhs`` holds ``(x1, y1, w, h)`` boxes paired element-wise with
            ``track_ids``. Pairs with a negative track id are skipped.
    """
    row_template = '{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n'

    def _rows():
        for frame_id, boxes, ids in results:
            for box, tid in zip(boxes, ids):
                if tid < 0:
                    continue
                left, top, width, height = box
                yield row_template.format(
                    frame=frame_id,
                    id=tid,
                    x1=round(left, 1),
                    y1=round(top, 1),
                    w=round(width, 1),
                    h=round(height, 1),
                )

    with open(filename, 'w') as handle:
        handle.writelines(_rows())
def _track_sequence(detector, seq_dir, min_area=100):
    """Run the detector and a fresh OC-SORT tracker over one sequence folder.

    Frames are read from ``<seq_dir>/img1/*.jpg`` in numeric order of their
    file names.

    Args:
        detector: object exposing ``infer_one_img(frame)``; its return value
            is passed straight to ``tracker.update``.
        seq_dir: path of one sequence directory.
        min_area: boxes with ``w * h <= min_area`` are dropped.

    Returns:
        List of ``(frame_id, tlwhs, track_ids)`` tuples, the layout accepted
        by ``write_results_no_score``.
    """
    img_paths = sorted(
        glob.glob(os.path.join(seq_dir, "img1", "*.jpg")),
        key=lambda p: extract_number(os.path.basename(p)),
    )
    # One tracker per sequence so state never leaks between videos
    # (constructed with the library's default parameters).
    tracker = OCSORT()
    results = []
    for img_path in tqdm.tqdm(img_paths):
        frame = cv2.imread(img_path)
        if frame is None:
            # cv2.imread signals unreadable files by returning None.
            raise RuntimeError(f"Could not read image: {img_path}")
        frame_id = extract_number(os.path.basename(img_path))
        detections = detector.infer_one_img(frame)
        # Use the tracker's own per-frame output instead of reading the
        # last stored observation of every active track: the latter also
        # emits stale boxes for tracks that were not matched in this frame.
        # NOTE(review): assumes rows start with x1, y1, x2, y2, id (boxmot
        # OCSORT layout) — confirm against the installed boxmot version.
        online_targets = tracker.update(detections, frame)
        tlwhs, track_ids = [], []
        for x1, y1, x2, y2, track_id, *_ in online_targets:
            w, h = x2 - x1, y2 - y1
            if w * h > min_area:
                tlwhs.append((float(x1), float(y1), float(w), float(h)))
                track_ids.append(int(track_id))
        results.append((frame_id, tlwhs, track_ids))
    return results


if __name__ == "__main__":
    args = parse_arguments()
    output = "boxmot_test"
    # NOTE: test-split ground truth is not public; point this at the val
    # split to evaluate locally with TrackEval.
    fold_root = "/training/datasets/DanceTrack/test"
    mkidr(output)
    od = ObjectDetector(args)
    # Sorted for a deterministic processing order; skip stray files.
    for fold_path in sorted(glob.glob(os.path.join(fold_root, "*"))):
        if not os.path.isdir(fold_path):
            continue
        fold_name = os.path.basename(fold_path)
        print(f"foldName : {fold_name}")
        seq_results = _track_sequence(od, fold_path)
        # Written only after the whole sequence succeeds, via a `with` block
        # inside the helper, so no half-written or leaked files remain.
        write_results_no_score(os.path.join(output, f"{fold_name}.txt"), seq_results)
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.
- 98.
- 99.
- 100.
- 101.
- 102.
- 103.
PS: 请根据自己的代码进行修改
结果目录为
精度评定
python3 TrackEval/scripts/run_mot_challenge.py --SPLIT_TO_EVAL val --METRICS HOTA CLEAR Identity --GT_FOLDER dancetrack/val --SEQMAP_FILE dancetrack/val_seqmap.txt --SKIP_SPLIT_FOL True --TRACKERS_TO_EVAL '' --TRACKER_SUB_FOLDER '' --USE_PARALLEL True --NUM_PARALLEL_CORES 8 --PLOT_CURVES False --TRACKERS_FOLDER val/TRACKER_NAME
- 1.
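PS：需要把 `--GT_FOLDER`（以及 `--SEQMAP_FILE`）改为自己的 GT 目录，把 `--TRACKERS_FOLDER` 改为自己的跟踪结果目录。注意：测试集的标注不公开，本地只能在验证集（val）上评测——上面代码中的数据目录是 `DanceTrack/test`，评测前需先改为 `DanceTrack/val` 并重新生成跟踪结果。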
结果展示




4089

被折叠的 条评论
为什么被折叠?



