ByteTrack算法详解

https://github.com/Sthyao/ByteTrack_simplified

已将相关代码上传至github。原参考blog:

部署并应用ByteTrack实现目标跟踪_yolov8集成bytetrack-CSDN博客

前言:

当你使用https://github.com/ifzhang/ByteTrack原作者给出的demo时,参见在windows平台上部署bytetrack_bytetrack 安装 python setup.py-CSDN博客

笔者的需求是对ByteTrack进行部署,使用其0shot能力来实现某个项目的一部分。因此源代码中的评估、训练等模块对笔者来说是没有用的。所以本文记录了从源码中拆解“多目标追踪任务”的纯任务代码。以达到

1、减少复杂的环境依赖,比如yolox;使用ultralytics来代替。删除lap依赖,Cpython依赖。

2、告诉读者核心代码实现逻辑和功能逻辑

源码:

在源代码的运行脚本中:

python tools/demo_track.py video -f exps/example/mot/yolox_nano_mix_det.py -c pretrained/bytetrack_nano_mot17.pth.tar --fuse --save_result

很明显注意到入口是tools/demo_track.py,现在我们打开demo_track.py看到头文件引用部分:

import argparse
import os
import os.path as osp
import time
import cv2
import torch

from loguru import logger

from yolox.data.data_augment import preproc
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer

其中,yolox模块是调用yolox的主入口,不必关注具体实现和功能。有兴趣的读者可以关注get_exp,是yolox模型的推理类。很显然,我们使用新版的yolo时,都可以使用ultralytics集成的yolo推理类。

而`from yolox.tracker.byte_tracker import BYTETracker`则是我们关注的核心:请看调用部分:

tracker = BYTETracker(args, frame_rate=30)
    timer = Timer()
    frame_id = 0
    results = []
    while True:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame, timer)
            if outputs[0] is not None:
                online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], exp.test_size)
                online_tlwhs = []
                online_ids = []
                online_scores = []
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    vertical = tlwh[2] / tlwh[3] > args.aspect_ratio_thresh
                    if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                        online_tlwhs.append(tlwh)
                        online_ids.append(tid)
                        online_scores.append(t.score)
                        results.append(
                            f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                        )
                timer.toc()
                online_im = plot_tracking(
                    img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1, fps=1. / timer.average_time
                )
            else:
                timer.toc()
                online_im = img_info['raw_img']
            if args.save_result:
                vid_writer.write(online_im)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break
        frame_id += 1

首先是tracker = BYTETracker(args, frame_rate=30)实例化BYTETracker。然后实现“多目标追踪任务”的关键就是

online_targets = tracker.update(outputs[0], [img_info['height'], img_info['width']], exp.test_size)
在一个实例化的BYTETracker中,使用update方法接受yolo模型的输出,返回一个online_targets的list,这个list中就是保存了一系列实例化的类(list[STrack]),每一个实例化的类都是一个追踪目标。在

tlwh = t.tlwh
tid = t.track_id

则可以很明显看出,这个实例化的类(STrack)中的核心是两个属性:track_id和tlwh,他们代表的就是追踪对象的id和追踪对象的位置tlwh(和原始尺寸一致)。

现在我们的思路就非常明确:ByteTrack的功能实现是:loop 【视频帧->yolo推理结果->ByteTrack.updata->每一帧的追踪对象】,如果使用ultralytics来进行推理,则只需要分离ByteTrack源码中对于ByteTrack类的定义即可。

依赖改写:

在源github中,ByteTrack的源码位置是:yolox/tracker/下的

basetrack.py

byte_tracker.py

kalman_filter.py

matching.py

我们可以自上而下,来分别看这几个文件。首先是ByteTrack的定义类byte_tracker.py和父类basetrack.py

#basetrack.py

class TrackState(object):
    New = 0
    Tracked = 1
    Lost = 2
    Removed = 3

class BaseTrack(object):

##省略##

#byte_tracker.py

class STrack(BaseTrack):
##省略##

class BYTETracker(object):
    def __init__(self, args, frame_rate=30):
        self.tracked_stracks = []  # type: list[STrack]
        self.lost_stracks = []  # type: list[STrack]
        self.removed_stracks = []  # type: list[STrack]

        self.frame_id = 0
        self.args = args
        #self.det_thresh = args.track_thresh
        self.det_thresh = args.track_thresh + 0.1
        self.buffer_size = int(frame_rate / 30.0 * args.track_buffer)
        self.max_time_lost = self.buffer_size
        self.kalman_filter = KalmanFilter()
##省略##


第一是整体代码框架:在BYTETracker类中,维护的数据类型是list[STrack],上文中已经说过,每一个STrack代表的是一个被捕获的对象,STrack对象则是继承了BaseTrack,而同时TrackState则规定了STrack对象的基本状态(active 或者 Lost)。

第二是核心方法:BYTETracker.update。上文中已经说过BYTETracker.update负责将追踪对象更新位置,判断是否丢失(删除),最终return回来的就是当前画面中的追踪对象,详细代码实现比较复杂,好在源代码给了充分的注释:(太长可以不看)

def update(self, output_results, img_info, img_size):
        self.frame_id += 1
        activated_starcks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []

        ##省略##

        ''' Add newly detected tracklets to tracked_stracks'''
        ##省略##

        ''' Step 2: First association, with high score detection boxes'''
        ##省略##

        ''' Step 3: Second association, with low score detection boxes'''
        # association the untrack to the low score detections
        ##省略##

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        ##省略##

        """ Step 4: Init new stracks"""
        ##省略##
        """ Step 5: Update state"""
        ##省略##

        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.removed_stracks.extend(removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
        # get scores of lost tracks
        output_stracks = [track for track in self.tracked_stracks if track.is_activated]

        return output_stracks

大概的处理逻辑就是,捕捉新的yolo对象,先判断是否置信度足够高,将这些对象和原有对象关联(如果关联不上就新建),然后回收低置信度对象,和原有对象关联(低置信度对象不会生成新对象),将那些丢失对象状态重置。最后,根据前面的判断逻辑下重置的state,更新list【STrack】。

(原作者的代码相对很好懂)

我们进入自上而下的第二部分:现在我们已经梳理了BYTETracker类的实现逻辑。源代码中最麻烦的部分是作者使用的lap和cpython依赖。现在我们可以定位到:

#matching.py

import cv2
import numpy as np
import scipy
import lap
from scipy.spatial.distance import cdist

from cython_bbox import bbox_overlaps as bbox_ious
from yolox.tracker import kalman_filter
import time

##省略##
def linear_assignment(cost_matrix, thresh):
    if cost_matrix.size == 0:
        return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
    matches, unmatched_a, unmatched_b = [], [], []
    cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
    for ix, mx in enumerate(x):
        if mx >= 0:
            matches.append([ix, mx])
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    matches = np.asarray(matches)
    return matches, unmatched_a, unmatched_b
##省略##
def ious(atlbrs, btlbrs):
    """
    Compute cost based on IoU
    :type atlbrs: list[tlbr] | np.ndarray
    :type atlbrs: list[tlbr] | np.ndarray

    :rtype ious np.ndarray
    """
    ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float)
    if ious.size == 0:
        return ious

    ious = bbox_ious(
        np.ascontiguousarray(atlbrs, dtype=np.float),
        np.ascontiguousarray(btlbrs, dtype=np.float)
    )

    return ious

cython.bbox_ious和lap.lapjv分别是cpython和lap的唯一调用。可见BYTETrack和cpython,lap的耦合度其实是非常低的!
cython.bbox_ious:一个很简单的计算ious的逻辑,我们可以很方便改写

lap.lapjv:匈牙利算法解矩阵最短路径线性规划:scipy.optimize.linear_sum_assignment也可以丝滑解决,并且,scipy.optimize.linear_sum_assignment仅仅是对lapjv做的api封装。

我们仅需重新编写两个依赖,即可实现一样的效果:

import numpy as np
from scipy.optimize import linear_sum_assignment

def lapjv(cost_matrix, extend_cost=True, cost_limit=np.inf):
    """
    使用 SciPy 的 linear_sum_assignment 实现类似于 lap.lapjv 的功能。
    
    参数:
    - cost_matrix (2D array): 成本矩阵。
    - extend_cost (bool): 是否扩展成本矩阵以允许未匹配的分配。
    - cost_limit (float): 每个分配的成本上限,超过该值的分配将被视为未匹配。
    
    返回:
    - total_cost (float): 符合阈值的总匹配成本。
    - row_ind (1D array): 匹配的行索引。
    - col_ind (1D array): 匹配的列索引。
    """
    num_rows, num_cols = cost_matrix.shape
    orig_num_rows, orig_num_cols = num_rows, num_cols

    if extend_cost:
        # 扩展成本矩阵到方阵
        size = max(num_rows, num_cols)
        if num_rows != num_cols:
            extended_cost_matrix = np.full((size, size), fill_value=cost_limit, dtype=cost_matrix.dtype)
            extended_cost_matrix[:num_rows, :num_cols] = cost_matrix
        else:
            extended_cost_matrix = cost_matrix.copy()
    else:
        extended_cost_matrix = cost_matrix.copy()

    # 使用匈牙利算法进行匹配
    row_ind, col_ind = linear_sum_assignment(extended_cost_matrix)

    # 如果扩展了成本矩阵,过滤掉虚拟匹配
    if extend_cost:
        valid_indices = (row_ind < orig_num_rows) & (col_ind < orig_num_cols)
        row_ind = row_ind[valid_indices]
        col_ind = col_ind[valid_indices]

    # 应用成本阈值
    mask = cost_matrix[row_ind, col_ind] <= cost_limit
    filtered_row_ind = row_ind[mask]
    filtered_col_ind = col_ind[mask]
    total_cost = cost_matrix[filtered_row_ind, filtered_col_ind].sum()

    # 创建完整的匹配数组,未匹配的用 -1 表示
    x = -1 * np.ones(orig_num_rows, dtype=int)
    y = -1 * np.ones(orig_num_cols, dtype=int)
    x[filtered_row_ind] = filtered_col_ind
    y[filtered_col_ind] = filtered_row_ind

    return total_cost, x, y

def bbox_overlaps_python(boxes, query_boxes):
    """
    Parameters
    ----------
    boxes: (N, 4) 格式的边界框数组 [x1, y1, x2, y2]
    query_boxes: (K, 4) 格式的查询框数组 [x1, y1, x2, y2]
    
    Returns
    -------
    overlaps: (N, K) IoU矩阵
    """
    N = boxes.shape[0]
    K = query_boxes.shape[0]
    overlaps = np.zeros((N, K), dtype=np.float32)
    
    for k in range(K):
        query_box_area = (
            (query_boxes[k, 2] - query_boxes[k, 0]) *
            (query_boxes[k, 3] - query_boxes[k, 1])
        )
        for n in range(N):
            box_area = (
                (boxes[n, 2] - boxes[n, 0]) *
                (boxes[n, 3] - boxes[n, 1])
            )
            
            # 计算交集区域
            ix1 = max(boxes[n, 0], query_boxes[k, 0])
            iy1 = max(boxes[n, 1], query_boxes[k, 1])
            ix2 = min(boxes[n, 2], query_boxes[k, 2])
            iy2 = min(boxes[n, 3], query_boxes[k, 3])
            
            # 如果没有重叠区域
            if ix2 < ix1 or iy2 < iy1:
                continue
                
            # 计算交集面积
            intersection = (ix2 - ix1) * (iy2 - iy1)
            
            # 计算IoU
            union = box_area + query_box_area - intersection
            overlaps[n, k] = intersection / union
            
    return overlaps

总结与代码示例:

如果你嫌本文太长,直接拉到了最后,请直接clone GitHub - Sthyao/ByteTrack_simplified: You know that bytetrack's source code is a hassle to install, but you can actually rewrite parts of it to achieve exactly the same functionality. This project rewrites complex dependencies in numpy and scipy to achieve consistent functionality

来亲自做个尝试,把demo跑通之后再看看代码。本文所有的代码已经上传到该github仓库。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值