https://github.com/Sthyao/ByteTrack_simplified
The related code has been uploaded to GitHub. Original reference blog:
部署并应用ByteTrack实现目标跟踪_yolov8集成bytetrack-CSDN博客
Preface:
If you want to run the demo provided by the original author at https://github.com/ifzhang/ByteTrack, see 在windows平台上部署bytetrack_bytetrack 安装 python setup.py-CSDN博客 for deploying it on Windows.
My goal was to deploy ByteTrack and use its zero-shot tracking ability as part of a project, so the evaluation and training modules in the source code were of no use to me. This article therefore records how to strip the pure "multi-object tracking" code out of the source repo, in order to:
1. Reduce the heavy environment dependencies: replace yolox with ultralytics, and remove the lap and Cython (cython_bbox) dependencies.
2. Explain the implementation logic and functionality of the core code.
Source code:
In the original repo, the run command is:
python tools/demo_track.py video -f exps/example/mot/yolox_nano_mix_det.py -c pretrained/bytetrack_nano_mot17.pth.tar --fuse --save_result
The entry point is clearly tools/demo_track.py. Opening demo_track.py, the import section looks like this:
import argparse
import os
import os.path as osp
import time
import cv2
import torch
from loguru import logger
from yolox.data.data_augment import preproc
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer
Here, the yolox imports are the main entry into YOLOX; we do not need to care about their concrete implementation. Interested readers can look at get_exp, which builds the YOLOX experiment/model used for inference. Obviously, when using a newer YOLO we can rely on the inference classes integrated in ultralytics instead.
`from yolox.tracker.byte_tracker import BYTETracker` is the core we care about. Look at the call site:
tracker = BYTETracker(args, frame_rate=30)
timer = Timer()
frame_id = 0
results = []
while True:
    if frame_id % 20 == 0:
        logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
    ret_val, frame = cap.read()
    if ret_val:
        outputs, img_info = predictor.inference(frame, timer)
        if outputs[0] is not None:
            online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], exp.test_size)
            online_tlwhs = []
            online_ids = []
            online_scores = []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
                    online_scores.append(t.score)
                    results.append(
                        f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                    )
            timer.toc()
            online_im = plot_tracking(
                img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1, fps=1. / timer.average_time
            )
        else:
            timer.toc()
            online_im = img_info['raw_img']
        if args.save_result:
            vid_writer.write(online_im)
        ch = cv2.waitKey(1)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break
    else:
        break
    frame_id += 1
First, tracker = BYTETracker(args, frame_rate=30) instantiates BYTETracker. The key step of the multi-object tracking task is then:
online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], exp.test_size)
On a BYTETracker instance, the update method takes the YOLO model's output and returns online_targets, a list (list[STrack]) in which every element is an instantiated track, i.e. one tracked object. From
tlwh = t.tlwh
tid = t.track_id
it is easy to see that the core of each STrack instance is two attributes, track_id and tlwh: the tracked object's ID and its position in tlwh format (top-left x, top-left y, width, height, in the original image scale).
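For example, a tlwh box is easy to convert to corner coordinates and draw without plot_tracking. A minimal sketch (the helper name is mine, not from the repo):

import cv2
import numpy as np

def draw_track(frame: np.ndarray, tlwh, track_id: int) -> None:
    """Draw one tracked box (tlwh = top-left x, top-left y, width, height) and its id."""
    x, y, w, h = map(int, tlwh)
    cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
    cv2.putText(frame, str(track_id), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

# inside the loop above: draw_track(img_info['raw_img'], t.tlwh, t.track_id)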
Now the plan is clear. ByteTrack's pipeline is the loop [video frame -> YOLO inference result -> BYTETracker.update -> tracked objects of this frame]. If we use ultralytics for inference, all we need is to extract the definition of the BYTETracker class from the ByteTrack source code.
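To make this concrete, here is a minimal sketch of that loop with ultralytics doing the detection. Assumptions: the extracted BYTETracker lives in a local module (hypothetically named byte_tracker here), its update() accepts an (N, 5) array of [x1, y1, x2, y2, score], and passing the frame size as both img_info and img_size makes the internal rescaling a no-op:

import cv2
import numpy as np
from types import SimpleNamespace
from ultralytics import YOLO

from byte_tracker import BYTETracker  # the class extracted from the ByteTrack source

# The original demo reads these from argparse; a namespace with the same field names works too.
args = SimpleNamespace(track_thresh=0.5, track_buffer=30, match_thresh=0.8, mot20=False)
tracker = BYTETracker(args, frame_rate=30)
model = YOLO("yolov8n.pt")

cap = cv2.VideoCapture("input.mp4")
while True:
    ret_val, frame = cap.read()
    if not ret_val:
        break
    result = model(frame, verbose=False)[0]
    # Stack detections into one (N, 5) array: [x1, y1, x2, y2, score]
    dets = np.hstack([
        result.boxes.xyxy.cpu().numpy(),
        result.boxes.conf.cpu().numpy()[:, None],
    ])
    h, w = frame.shape[:2]
    online_targets = tracker.update(dets, [h, w], (h, w))
    for t in online_targets:
        print(t.track_id, t.tlwh)
cap.release()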
Rewriting the dependencies:
In the original GitHub repo, the ByteTrack source lives under yolox/tracker/ in:
basetrack.py
byte_tracker.py
kalman_filter.py
matching.py
We can go through these files top-down. First, the BYTETracker class defined in byte_tracker.py and its base class in basetrack.py:
# basetrack.py
class TrackState(object):
    New = 0
    Tracked = 1
    Lost = 2
    Removed = 3

class BaseTrack(object):
    ## omitted ##

# byte_tracker.py
class STrack(BaseTrack):
    ## omitted ##

class BYTETracker(object):
    def __init__(self, args, frame_rate=30):
        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.args = args
        # self.det_thresh = args.track_thresh
        self.det_thresh = args.track_thresh + 0.1
        self.buffer_size = int(frame_rate / 30.0 * args.track_buffer)
        self.max_time_lost = self.buffer_size
        self.kalman_filter = KalmanFilter()
        ## omitted ##
First, the overall code structure: the BYTETracker class maintains lists of type list[STrack]. As mentioned above, each STrack represents one captured object; STrack inherits from BaseTrack, and TrackState defines the basic lifecycle states an STrack can be in (New, Tracked, Lost, Removed).
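A tiny illustration of how these pieces relate, assuming tracker is the BYTETracker instance from the demo loop above (illustration only, not repo code):

# Every element the tracker maintains is an STrack (which is a BaseTrack),
# and its lifecycle is one of the TrackState constants.
for t in tracker.tracked_stracks:              # list[STrack]
    assert isinstance(t, BaseTrack)
    alive = t.state in (TrackState.Tracked, TrackState.Lost)
    print(t.track_id, t.state, alive)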
Second, the core method BYTETracker.update. As mentioned above, it updates the positions of the tracked objects, decides which tracks are lost (or removed), and finally returns the tracks currently present in the frame. The detailed implementation is fairly involved, but fortunately the source code is well commented (feel free to skim):
def update(self, output_results, img_info, img_size):
    self.frame_id += 1
    activated_starcks = []
    refind_stracks = []
    lost_stracks = []
    removed_stracks = []
    ## omitted ##
    ''' Add newly detected tracklets to tracked_stracks'''
    ## omitted ##
    ''' Step 2: First association, with high score detection boxes'''
    ## omitted ##
    ''' Step 3: Second association, with low score detection boxes'''
    # association the untrack to the low score detections
    ## omitted ##
    '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
    ## omitted ##
    """ Step 4: Init new stracks"""
    ## omitted ##
    """ Step 5: Update state"""
    ## omitted ##

    self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
    self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
    self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
    self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
    self.lost_stracks.extend(lost_stracks)
    self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
    self.removed_stracks.extend(removed_stracks)
    self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
    # get scores of lost tracks
    output_stracks = [track for track in self.tracked_stracks if track.is_activated]

    return output_stracks
Roughly, the processing logic is: take the new YOLO detections and split them by confidence; first associate the high-confidence detections with the existing tracks (detections that cannot be matched create new tracks); then associate the low-confidence detections with the tracks that are still unmatched (low-confidence detections never create new tracks); mark the tracks that remain unmatched as lost. Finally, based on the states set along the way, update the list[STrack] containers.
(The original author's code is actually quite readable.)
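The confidence split that drives this two-stage association looks roughly like the following (a paraphrase of the omitted part of update(), shown on a toy score array; the exact indexing in the source differs slightly):

import numpy as np

track_thresh = 0.5                 # plays the role of args.track_thresh
scores = np.array([0.9, 0.65, 0.3, 0.15, 0.05])

# First association: high-confidence detections, allowed to spawn new tracks
high_mask = scores > track_thresh
# Second association: low-confidence detections, only used to rescue unmatched tracks
low_mask = (scores > 0.1) & (scores < track_thresh)

print(np.where(high_mask)[0])      # detections 0 and 1
print(np.where(low_mask)[0])       # detections 2 and 3 (0.05 is discarded entirely)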
Now for the second part of the top-down walkthrough: we have covered the implementation logic of the BYTETracker class. The most troublesome part of the source code is its dependence on lap and cython_bbox (Cython). We can locate it here:
#matching.py
import cv2
import numpy as np
import scipy
import lap
from scipy.spatial.distance import cdist
from cython_bbox import bbox_overlaps as bbox_ious
from yolox.tracker import kalman_filter
import time
## omitted ##

def linear_assignment(cost_matrix, thresh):
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
    matches, unmatched_a, unmatched_b = [], [], []
    cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
    for ix, mx in enumerate(x):
        if mx >= 0:
            matches.append([ix, mx])
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    matches = np.asarray(matches)
    return matches, unmatched_a, unmatched_b
## omitted ##

def ious(atlbrs, btlbrs):
    """
    Compute cost based on IoU
    :type atlbrs: list[tlbr] | np.ndarray
    :type btlbrs: list[tlbr] | np.ndarray

    :rtype ious np.ndarray
    """
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float)
    if ious.size == 0:
        return ious

    ious = bbox_ious(
        np.ascontiguousarray(atlbrs, dtype=np.float),
        np.ascontiguousarray(btlbrs, dtype=np.float)
    )

    return ious
bbox_ious (from cython_bbox) and lap.lapjv are the only calls into cython_bbox and lap. ByteTrack's coupling to these two packages is therefore very low!
bbox_ious: a straightforward IoU computation that we can easily rewrite ourselves.
lap.lapjv: solves the linear assignment problem with a Jonker-Volgenant (Hungarian-style) algorithm. scipy.optimize.linear_sum_assignment solves exactly the same problem, and SciPy's implementation is itself based on a modified Jonker-Volgenant algorithm, so it is a natural replacement.
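As a quick sanity check that scipy.optimize.linear_sum_assignment solves the same assignment problem, a throwaway example (not part of the repo):

import numpy as np
from scipy.optimize import linear_sum_assignment

# A 2x2 cost matrix: the optimal assignment pairs row 0 with column 0 and row 1 with column 1.
cost = np.array([[0.2, 0.9],
                 [0.8, 0.1]])
rows, cols = linear_sum_assignment(cost)
print(rows, cols)              # [0 1] [0 1] -> row 0 with col 0, row 1 with col 1
print(cost[rows, cols].sum())  # minimal total cost: 0.2 + 0.1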
We only need to rewrite these two functions to get the same behaviour:
import numpy as np
from scipy.optimize import linear_sum_assignment

def lapjv(cost_matrix, extend_cost=True, cost_limit=np.inf):
    """
    Mimic lap.lapjv using SciPy's linear_sum_assignment.

    Parameters:
    - cost_matrix (2D array): cost matrix.
    - extend_cost (bool): whether to pad the cost matrix so that unmatched assignments are allowed.
    - cost_limit (float): upper bound on the cost of a single assignment; assignments above it count as unmatched.

    Returns:
    - total_cost (float): total cost of the assignments within the threshold.
    - x (1D array): for each row, the index of the matched column (-1 if unmatched).
    - y (1D array): for each column, the index of the matched row (-1 if unmatched).
    """
    num_rows, num_cols = cost_matrix.shape
    orig_num_rows, orig_num_cols = num_rows, num_cols

    if extend_cost:
        # Pad the cost matrix to a square matrix
        size = max(num_rows, num_cols)
        if num_rows != num_cols:
            extended_cost_matrix = np.full((size, size), fill_value=cost_limit, dtype=cost_matrix.dtype)
            extended_cost_matrix[:num_rows, :num_cols] = cost_matrix
        else:
            extended_cost_matrix = cost_matrix.copy()
    else:
        extended_cost_matrix = cost_matrix.copy()

    # Solve the assignment with the Hungarian-style algorithm
    row_ind, col_ind = linear_sum_assignment(extended_cost_matrix)

    # If the matrix was padded, drop the dummy matches
    if extend_cost:
        valid_indices = (row_ind < orig_num_rows) & (col_ind < orig_num_cols)
        row_ind = row_ind[valid_indices]
        col_ind = col_ind[valid_indices]

    # Apply the cost threshold
    mask = cost_matrix[row_ind, col_ind] <= cost_limit
    filtered_row_ind = row_ind[mask]
    filtered_col_ind = col_ind[mask]
    total_cost = cost_matrix[filtered_row_ind, filtered_col_ind].sum()

    # Build the full match arrays; unmatched entries are -1
    x = -1 * np.ones(orig_num_rows, dtype=int)
    y = -1 * np.ones(orig_num_cols, dtype=int)
    x[filtered_row_ind] = filtered_col_ind
    y[filtered_col_ind] = filtered_row_ind

    return total_cost, x, y
def bbox_overlaps_python(boxes, query_boxes):
    """
    Parameters
    ----------
    boxes: (N, 4) array of boxes in [x1, y1, x2, y2] format
    query_boxes: (K, 4) array of query boxes in [x1, y1, x2, y2] format

    Returns
    -------
    overlaps: (N, K) IoU matrix
    """
    N = boxes.shape[0]
    K = query_boxes.shape[0]
    overlaps = np.zeros((N, K), dtype=np.float32)

    for k in range(K):
        query_box_area = (
            (query_boxes[k, 2] - query_boxes[k, 0]) *
            (query_boxes[k, 3] - query_boxes[k, 1])
        )
        for n in range(N):
            box_area = (
                (boxes[n, 2] - boxes[n, 0]) *
                (boxes[n, 3] - boxes[n, 1])
            )
            # Intersection rectangle
            ix1 = max(boxes[n, 0], query_boxes[k, 0])
            iy1 = max(boxes[n, 1], query_boxes[k, 1])
            ix2 = min(boxes[n, 2], query_boxes[k, 2])
            iy2 = min(boxes[n, 3], query_boxes[k, 3])

            # No overlap
            if ix2 < ix1 or iy2 < iy1:
                continue

            # Intersection area
            intersection = (ix2 - ix1) * (iy2 - iy1)

            # IoU = intersection / union
            union = box_area + query_box_area - intersection
            overlaps[n, k] = intersection / union

    return overlaps
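With these two replacements in place, matching.py only needs its imports and call sites adjusted; lap and cython_bbox can be dropped entirely. A sketch, assuming both functions above live in the same module as the code below:

import numpy as np

def linear_assignment(cost_matrix, thresh):
    # Same body as the original, but calling the pure-SciPy lapjv defined above instead of lap.lapjv.
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
    matches = []
    cost, x, y = lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
    for ix, mx in enumerate(x):
        if mx >= 0:
            matches.append([ix, mx])
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    return np.asarray(matches), unmatched_a, unmatched_b

def ious(atlbrs, btlbrs):
    # Same as the original, with cython_bbox replaced by bbox_overlaps_python.
    # Note: np.float was removed in NumPy 1.24+, so np.float64 is used here.
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float64)
    if ious.size == 0:
        return ious
    return bbox_overlaps_python(
        np.ascontiguousarray(atlbrs, dtype=np.float64),
        np.ascontiguousarray(btlbrs, dtype=np.float64),
    )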
Summary and code example:
Try it yourself: get the demo running first, then read through the code. All the code from this article has been uploaded to the GitHub repository linked at the top.