YOLO V8+ByteTrack官方代码解析
包括了两个核心类
- class BYTETracker
- class STrack
整个算法流程中最为核心的方法:BYTETracker.update
算法的官方流程:
def update(self, results, img=None):
    """Advance the tracker by one frame and return the activated tracks.

    The detections are split by confidence into a high-score set and a
    low-score set. Existing tracks are first associated with the high-score
    detections, the still-unmatched tracked ones are then associated with
    the low-score detections, and finally unconfirmed tracks are matched
    against the remaining high-score detections. Leftover high-score
    detections may start new tracks.

    Args:
        results: Detections of the current frame. Must expose ``conf``,
            ``xyxy`` and ``cls``; they are indexed with boolean masks, so
            presumably NumPy-backed (the callback in this file passes
            ``boxes.cpu().numpy()``).
        img: Optional current frame. Forwarded to ``init_track`` and, when
            the tracker has a ``gmc`` attribute, used for motion
            compensation of the predicted tracks.

    Returns:
        np.ndarray: ``float32`` array with one row per activated track in
        ``self.tracked_stracks``: ``tlbr`` followed by
        ``[track_id, score, cls, idx]``.
    """
    self.frame_id += 1  # index of the frame being processed
    activated_stracks = []  # tracks updated or newly activated in this frame
    refind_stracks = []  # previously lost tracks matched again in this frame
    lost_stracks = []  # tracks that became lost in this frame
    removed_stracks = []  # tracks removed in this frame

    # Confidence, boxes and classes of the current detections.
    scores = results.conf
    bboxes = results.xyxy
    # Append every detection's row index as an extra last column.
    bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
    cls = results.cls

    # High-score detections (used in the first association). `>=` so that a
    # score exactly equal to the threshold is not dropped from both sets.
    remain_inds = scores >= self.args.track_high_thresh
    inds_low = scores > self.args.track_low_thresh  # above the low threshold
    inds_high = scores < self.args.track_high_thresh  # below the high threshold
    # Low-score detections: strictly between the two thresholds (second association).
    inds_second = np.logical_and(inds_low, inds_high)
    dets_second = bboxes[inds_second]
    dets = bboxes[remain_inds]
    scores_keep = scores[remain_inds]
    scores_second = scores[inds_second]
    cls_keep = cls[remain_inds]
    cls_second = cls[inds_second]

    # Wrap every high-score detection in an STrack.
    detections = self.init_track(dets, scores_keep, cls_keep, img)
    # Split the currently tracked list into activated and not-yet-activated tracks.
    unconfirmed = []
    tracked_stracks = []  # type: list[STrack]
    for track in self.tracked_stracks:
        if not track.is_activated:
            unconfirmed.append(track)
        else:
            tracked_stracks.append(track)

    # Step 2: First association, with high score detection boxes.
    # The candidate pool is the activated tracks plus the lost tracks.
    strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
    # Predict the current location with KF
    self.multi_predict(strack_pool)
    if hasattr(self, 'gmc') and img is not None:
        warp = self.gmc.apply(img, dets)
        STrack.multi_gmc(strack_pool, warp)
        STrack.multi_gmc(unconfirmed, warp)

    # Distance between every pooled track and every high-score detection.
    dists = self.get_dists(strack_pool, detections)
    # matches: (track index, detection index) pairs
    # u_track: indices of tracks left without a detection
    # u_detection: indices of detections left without a track
    # match_thresh is the distance limit passed to the assignment.
    matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)

    # Matched tracks in the Tracked state are updated; any other matched
    # track (e.g. a lost one) is re-activated under its existing id.
    for itracked, idet in matches:
        track = strack_pool[itracked]
        det = detections[idet]
        if track.state == TrackState.Tracked:
            track.update(det, self.frame_id)
            activated_stracks.append(track)
        else:
            track.re_activate(det, self.frame_id, new_id=False)
            refind_stracks.append(track)

    # Step 3: Second association, with low score detection boxes
    # association the untrack to the low score detections
    detections_second = self.init_track(dets_second, scores_second, cls_second, img)
    # Only tracks that were in the Tracked state and failed the first
    # association take part; lost tracks are not retried here.
    r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
    # TODO
    dists = matching.iou_distance(r_tracked_stracks, detections_second)
    # Unmatched low-score detections are discarded (never used afterwards).
    matches, u_track, _ = matching.linear_assignment(dists, thresh=0.5)
    for itracked, idet in matches:
        track = r_tracked_stracks[itracked]
        det = detections_second[idet]
        if track.state == TrackState.Tracked:
            track.update(det, self.frame_id)
            activated_stracks.append(track)
        else:
            track.re_activate(det, self.frame_id, new_id=False)
            refind_stracks.append(track)

    # Tracks that failed both associations are marked lost (kept, not removed).
    for it in u_track:
        track = r_tracked_stracks[it]
        if track.state != TrackState.Lost:
            track.mark_lost()
            lost_stracks.append(track)

    # Deal with unconfirmed tracks, usually tracks with only one beginning frame
    # They compete for the high-score detections left over from the first association.
    detections = [detections[i] for i in u_detection]
    dists = self.get_dists(unconfirmed, detections)
    matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
    for itracked, idet in matches:
        unconfirmed[itracked].update(detections[idet], self.frame_id)
        activated_stracks.append(unconfirmed[itracked])
    # An unconfirmed track without a match is removed immediately.
    for it in u_unconfirmed:
        track = unconfirmed[it]
        track.mark_removed()
        removed_stracks.append(track)

    # Step 4: Init new stracks
    # u_detection now holds the high-score detections matched neither in the
    # first association nor against the unconfirmed tracks.
    for inew in u_detection:
        track = detections[inew]
        if track.score < self.args.new_track_thresh:
            continue
        track.activate(self.kalman_filter, self.frame_id)
        activated_stracks.append(track)

    # Step 5: Update state
    # Lost tracks whose last frame is more than max_time_lost frames ago are removed.
    for track in self.lost_stracks:
        if self.frame_id - track.end_frame > self.max_time_lost:
            track.mark_removed()
            removed_stracks.append(track)

    self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
    self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks)
    self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)
    self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
    self.lost_stracks.extend(lost_stracks)
    # NOTE: self.removed_stracks does not yet contain this frame's removals here;
    # they are appended two statements below.
    self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)
    self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
    self.removed_stracks.extend(removed_stracks)
    if len(self.removed_stracks) > 1000:
        self.removed_stracks = self.removed_stracks[-999:]  # clip remove stracks to 1000 maximum

    # One row per activated track: box (tlbr), track id, score, class, detection index.
    return np.asarray(
        [x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
        dtype=np.float32)
跟踪器通过回调的方式挂接到检测(预测)流程中:预测过程在相应阶段调用下面注册的回调函数
# 预测期间初始化目标跟踪器。它接受一个预测器对象和一个可选的persist 参数,用于控制是否要保持已存在的跟踪器
def on_predict_start(predictor, persist=False):
    """
    Initialize trackers for object tracking during prediction.

    One tracker is created per batch slot (``predictor.dataset.bs``) and the
    list is stored on ``predictor.trackers``.

    Args:
        predictor (object): The predictor object to initialize trackers for.
        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.

    Raises:
        AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
    """
    # Keep the existing trackers untouched when asked to persist them.
    if hasattr(predictor, 'trackers') and persist:
        return
    # Load the tracker configuration named by the predictor arguments.
    tracker_yaml = check_yaml(predictor.args.tracker)
    cfg = IterableSimpleNamespace(**yaml_load(tracker_yaml))
    assert cfg.tracker_type in ['bytetrack', 'botsort'], \
        f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
    # One independent tracker instance per batch slot.
    predictor.trackers = [
        TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) for _ in range(predictor.dataset.bs)
    ]
# 通过该函数将检测和跟踪合并到一起,先进行检测再进行跟踪
def on_predict_postprocess_end(predictor):
    """Postprocess detected boxes and update with object tracking.

    For every batch slot the detected boxes are fed to that slot's tracker.
    When the tracker returns tracks, the slot's result is re-indexed by the
    last column of the track array and its boxes are replaced by the
    remaining columns. Slots with no detections or no returned tracks are
    left unchanged.

    Args:
        predictor (object): Predictor holding ``dataset.bs``, ``batch``,
            ``results`` and the ``trackers`` list.
    """
    bs = predictor.dataset.bs
    im0s = predictor.batch[1]
    for i in range(bs):
        det = predictor.results[i].boxes.cpu().numpy()
        if len(det) == 0:
            continue  # nothing detected for this slot
        # Run the tracker's update step on this slot's detections.
        tracks = predictor.trackers[i].update(det, im0s[i])
        if len(tracks) == 0:
            continue  # no activated track returned; keep the result as is
        # The last column is the index of the detection each track came from.
        idx = tracks[:, -1].astype(int)
        predictor.results[i] = predictor.results[i][idx]
        predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
def register_tracker(model, persist):
    """
    Register tracking callbacks to the model for object tracking during prediction.

    Args:
        model (object): The model object to register tracking callbacks for.
        persist (bool): Whether to persist the trackers if they already exist.
    """
    # Attach the tracker to the detector: both functions are registered as
    # callbacks on the model, with `persist` bound into the start callback.
    model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
    model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
在update方法中使用到的一些补充的方法
# 使用STrack算法,用检测结果和得分初始化目标跟踪
def init_track(self, dets, scores, cls, img=None):
    """Initialize object tracking with detections and scores using STrack algorithm.

    Args:
        dets: Detection rows. In ``update`` these are the ``xyxy`` boxes with
            the detection index appended as an extra column.
        scores: Confidence score of each detection, in the same order as ``dets``.
        cls: Class of each detection, in the same order as ``dets``.
        img: Unused here; accepted so that callers can pass the frame along.

    Returns:
        list: One ``STrack`` per detection, or an empty list when ``dets`` is empty.
    """
    if not len(dets):
        return []
    # zip keeps each box aligned with its own score and class.
    return [STrack(xyxy, s, c) for xyxy, s, c in zip(dets, scores, cls)]
# 计算跟踪和检测之间的距离,并融合得分
def get_dists(self, tracks, detections):
    """Calculates the distance between tracks and detections using IOU and fuses scores.

    Args:
        tracks: Tracks to be associated.
        detections: Detections (as returned by ``init_track``) to associate with.

    Returns:
        The distance matrix returned by ``matching.fuse_score``, which is
        passed to ``matching.linear_assignment`` in ``update``.
    """
    dists = matching.iou_distance(tracks, detections)
    # TODO: mot20
    # if not self.args.mot20:
    # Combine the IoU distance with the detection scores
    # (presumably element-wise; confirm against matching.fuse_score).
    dists = matching.fuse_score(dists, detections)
    return dists
# 通过 STrack.multi_predict 用卡尔曼滤波预测各条轨迹在当前帧的位置(update 中的一个核心步骤,并不经过YOLOv8网络)
def multi_predict(self, tracks):
    """Predict the current state of the given tracks by delegating to ``STrack.multi_predict``.

    ``update`` calls this on the track pool before association (its
    "Predict the current location with KF" step).

    Args:
        tracks: Tracks whose state should be predicted.
    """
    STrack.multi_predict(tracks)
def reset_id(self):
    """Resets the ID counter of STrack by delegating to ``STrack.reset_id``."""
    STrack.reset_id()
# 其目的是将两个跟踪目标列表( tlista 和tlistb )合并成一个单一列表,同时确保没有重复的跟踪目标
@staticmethod
def joint_stracks(tlista, tlistb):
    """Combine two lists of stracks into a single one.

    Every track of ``tlista`` is kept, in order. A track of ``tlistb`` is
    appended only if no track with the same ``track_id`` has been added
    before it.

    Args:
        tlista: First list of tracks; copied to the result as is.
        tlistb: Second list of tracks; filtered by ``track_id``.

    Returns:
        list: The merged list of tracks.
    """
    res = list(tlista)
    seen_ids = {t.track_id for t in tlista}  # ids already present in the result
    for t in tlistb:
        tid = t.track_id
        if tid not in seen_ids:
            seen_ids.add(tid)
            res.append(t)
    return res
# 返回只包含在 tlista 中但不在 tlistb 中的strack
@staticmethod
def sub_stracks(tlista, tlistb):
    """Return the tracks of ``tlista`` whose ``track_id`` does not occur in ``tlistb``.

    Two tracks count as the same track when they share a ``track_id``. This
    replaces the dict-based implementation marked as deprecated in
    https://github.com/ultralytics/ultralytics/pull/1890/.

    Args:
        tlista: Tracks to filter.
        tlistb: Tracks whose ids are removed from ``tlista``.

    Returns:
        list: Tracks of ``tlista`` not present in ``tlistb``, in their original order.
    """
    track_ids_b = {t.track_id for t in tlistb}  # ids to subtract
    return [t for t in tlista if t.track_id not in track_ids_b]
# 去除 stracksa 和 stracksb 之间重复的 stracks
# 即在两个给定的轨迹列表(stracksa 和 stracksb)中删除重复的轨迹。这里的重复指的是两个轨迹之间的 IoU 距离小于某个阈值,即两个轨迹有很高的重叠