【deep_sort_realtime库】tracker.py 中文注释版

本文介绍了基于深度学习的多目标追踪器,它使用NN匹配、卡尔曼滤波和IOU匹配算法来关联和管理目标。追踪器具有参数如最大IOU距离、最大年龄和初始化帧数,以及对跟踪对象的命名策略和状态管理功能。
摘要由CSDN通过智能技术生成

 原链接:deep_sort_realtime/deep_sort_realtime/deep_sort/tracker.py at master · levan92/deep_sort_realtime (github.com)

 

中文注释版:

# vim: expandtab:ts=4:sw=4
from __future__ import absolute_import
from datetime import datetime
import numpy as np
from . import kalman_filter
from . import linear_assignment
from . import iou_matching
from .track import Track

class Tracker:


    def __init__(  # 构造函数,用于初始化 Tracker 类
            self,
            metric,  # 传入的距离度量对象,用于测量与跟踪之间的关联
            max_iou_distance=0.7,  # 最大IoU(交并比)距离,用于关联检测框和跟踪对象
            max_age=30,  # 最大年龄,跟踪在丢失多少次后才被删除
            n_init=3,  # 初始化帧数,新跟踪对象在成为确认跟踪前需要经过的帧数
            override_track_class=None,  # 用于覆盖默认的跟踪类
            today=None,  # 当前日期,用于跟踪对象的命名
            gating_only_position=False,  # 在门限过程中是否仅使用位置信息进行比较
    ):
        """多目标跟踪器-参数说明

----------
metric : nn_matching.NearestNeighborDistanceMetric
    用于测量与跟踪关联的距离度量指标。
max_age : int
    在跟踪被删除前允许错过的最大帧数。
n_init : int
    在确认跟踪之前需要连续检测的帧数。如果在前 `n_init` 帧内发生错过,则跟踪状态将被设置为 `Deleted`。
today: Optional[datetime.date]
    提供今天的日期,用于跟踪的命名

属性
----------
metric : nn_matching.NearestNeighborDistanceMetric
    用于测量到跟踪关联的距离度量指标。
max_age : int
    在跟踪被删除前允许错过的最大帧数。
n_init : int
    跟踪在初始化阶段保持的帧数。
kf : kalman_filter.KalmanFilter
    一个卡尔曼滤波器,用于在图像空间内过滤目标轨迹。
tracks : List[Track]
    当前时间步中活跃的跟踪列表。
gating_only_position : Optional[bool]
    在门限过程中使用,比较 KF 预测状态和测量状态。如果为 True,则在门限过程中只考虑状态分布的 x, y 位置。默认为 False,将考虑 x,y, 宽高比和高度。
    """
        # 初始化函数体

        self.today = today  # 设置当前日期
        self.metric = metric  # 设置传入的距离度量对象
        self.max_iou_distance = max_iou_distance  # 设置最大IoU距离
        self.max_age = max_age  # 设置最大年龄
        self.n_init = n_init  # 设置初始化帧数
        self.gating_only_position = gating_only_position  # 设置是否仅在门限过程中使用位置信息

        # 初始化卡尔曼滤波器,用于预测跟踪对象的状态
        self.kf = kalman_filter.KalmanFilter()
        # 初始化活跃跟踪列表
        self.tracks = []
        # 初始化待删除跟踪ID列表
        self.del_tracks_ids = []
        # 初始化下一个可用的跟踪ID
        self._next_id = 1
        # 如果提供了覆盖跟踪类,则使用提供的类,否则使用默认的 Track 类
        if override_track_class:
            self.track_class = override_track_class
        else:
            self.track_class = Track

    def predict(self):
        """
        推进跟踪状态分布一个时间步长向前。

        这个函数应在每个时间步调用一次,在 `update` 函数之前。
        """
        for track in self.tracks:
            track.predict(self.kf)
        # 对每个跟踪对象调用其 predict 方法,使用卡尔曼滤波器预测下一状态

    def update(self, detections, today=None):
        """
        执行测量更新和跟踪管理。

        参数
        ----------
        detections : List[deep_sort.detection.Detection]
            当前时间步的检测列表。
        today: Optional[datetime.date]
            提供今天的日期,用于跟踪的命名
        """
        if self.today:
            if today is None:
                today = datetime.now().date()
                # 如果提供了今天的日期,则使用该日期,否则获取当前日期
            # 检查是否为新的一天,如果是,则刷新跟踪索引
            if today != self.today:
                self.today = today
                self._next_id = 1

        # 运行匹配级联
        matches, unmatched_tracks, unmatched_detections = self._match(detections)

        # 更新跟踪集合
        for track_idx, detection_idx in matches:
            self.tracks[track_idx].update(self.kf, detections[detection_idx])
            # 对匹配的跟踪更新其状态
        for track_idx in unmatched_tracks:
            self.tracks[track_idx].mark_missed()
            # 标记未匹配的跟踪为错过
        for detection_idx in unmatched_detections:
            self._initiate_track(detections[detection_idx])
            # 为未匹配的检测初始化新跟踪

        # 准备更新跟踪列表和删除跟踪ID列表
        new_tracks = []
        self.del_tracks_ids = []
        for t in self.tracks:
            if not t.is_deleted():
                new_tracks.append(t)
            else:
                self.del_tracks_ids.append(t.track_id)
        # 清除已删除的跟踪,更新跟踪列表
        self.tracks = new_tracks

        # 更新距离度量
        active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
        # 准备特征和目标列表用于更新距离度量
        features, targets = [], []
        for track in self.tracks:
            if not track.is_confirmed():
                continue
            features += track.features
            targets += [track.track_id for _ in track.features]
            # 更新跟踪的特征列表,只保留最后一个特征
            track.features = [track.features[-1]]
        self.metric.partial_fit(
            np.asarray(features), np.asarray(targets), active_targets
        )
        # 使用特征和目标列表部分拟合距离度量,更新其内部状态

    def _match(self, detections):
        """
        执行匹配算法,将检测结果与现有跟踪进行关联。

        参数
        ----------
        detections : List[deep_sort.detection.Detection]
            当前时间步的检测列表。

        返回
        ------
        matches : List[Tuple[int, int]]
            匹配的跟踪索引和检测索引的元组列表。
        unmatched_tracks : List[int]
            未匹配的跟踪索引列表。
        unmatched_detections : List[int]
            未匹配的检测索引列表。
        """

        def gated_metric(tracks, dets, track_indices, detection_indices):
            """
            计算门限化的匹配成本矩阵。

            参数
            ----------
            tracks : List[Track]
                跟踪列表。
            dets : List[Detection]
                检测列表。
            track_indices : List[int]
                跟踪的索引列表。
            detection_indices : List[int]
                检测的索引列表。

            返回
            ------
            cost_matrix : numpy.ndarray
                计算得到的门限化匹配成本矩阵。
            """
            features = np.array([dets[i].feature for i in detection_indices])
            targets = np.array([tracks[i].track_id for i in track_indices])
            # 计算特征和目标之间的距离成本矩阵
            cost_matrix = self.metric.distance(features, targets)
            # 使用卡尔曼滤波器预测和检测对象的空间位置,生成门限化的成本矩阵
            cost_matrix = linear_assignment.gate_cost_matrix(
                self.kf, cost_matrix, tracks, dets, track_indices, detection_indices,
                only_position=self.gating_only_position
            )

            return cost_matrix

        # 将跟踪集合分为已确认和未确认的跟踪。
        confirmed_tracks = [i for i, t in enumerate(self.tracks) if t.is_confirmed()]
        unconfirmed_tracks = [
            i for i, t in enumerate(self.tracks) if not t.is_confirmed()
        ]

        # 使用外观特征关联已确认的跟踪。
        (
            matches_a,
            unmatched_tracks_a,
            unmatched_detections,
        ) = linear_assignment.matching_cascade(
            gated_metric,
            self.metric.matching_threshold,
            self.max_age,
            self.tracks,
            detections,
            confirmed_tracks,
        )

        # 使用IOU关联剩余的跟踪和未确认的跟踪。
        iou_track_candidates = unconfirmed_tracks + [
            k for k in unmatched_tracks_a if self.tracks[k].time_since_update == 1
        ]
        unmatched_tracks_a = [
            k for k in unmatched_tracks_a if self.tracks[k].time_since_update != 1
        ]
        (
            matches_b,
            unmatched_tracks_b,
            unmatched_detections,
        ) = linear_assignment.min_cost_matching(
            iou_matching.iou_cost,
            self.max_iou_distance,
            self.tracks,
            detections,
            iou_track_candidates,
            unmatched_detections,
        )

        matches = matches_a + matches_b
        # 合并两个匹配列表
        unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
        # 合并两个未匹配跟踪的列表
        return matches, unmatched_tracks, unmatched_detections
        # 返回匹配结果和未匹配的跟踪及检测列表

    def _initiate_track(self, detection):
        """
        初始化一个新的跟踪对象。

        参数
        ----------
        detection : deep_sort.detection.Detection
            当前帧中的一个检测对象。

        返回
        ------
        无
        """
        # 使用卡尔曼滤波器初始化新跟踪的状态和协方差
        mean, covariance = self.kf.initiate(detection.to_xyah())

        # 如果提供了今天的日期,则在跟踪ID中包含日期信息
        if self.today:
            track_id = "{}_{}".format(self.today, self._next_id)
        else:
            # 否则,只使用跟踪ID计数器
            track_id = "{}".format(self._next_id)

        # 创建并添加新的跟踪对象到跟踪列表中
        self.tracks.append(
            self.track_class(
                mean,  # 初始状态的均值向量
                covariance,  # 初始状态的协方差矩阵
                track_id,  # 新跟踪的唯一ID
                self.n_init,  # 初始化帧数
                self.max_age,  # 最大年龄
                # 以下参数是跟踪对象的属性
                feature=detection.feature,  # 检测对象的特征
                original_ltwh=detection.get_ltwh(),  # 检测对象的原始边界框
                det_class=detection.class_name,  # 检测对象的类别名称
                det_conf=detection.confidence,  # 检测对象的置信度
                instance_mask=detection.instance_mask,  # 检测对象的实例掩码
                others=detection.others,  # 检测对象的其他信息
            )
        )
        # 递增跟踪ID计数器
        self._next_id += 1

    def delete_all_tracks(self):
        """
        删除所有跟踪对象。

        返回
        ------
        无
        """
        # 清空跟踪列表
        self.tracks = []
        # 重置跟踪ID计数器
        self._next_id = 1

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值