目标跟踪算法:ByteTrack、卡尔曼滤波、匈牙利算法、C++代码逐行解读

目录

 1  ByteTrack特点        

2 ByteTrack和SORT区别----个人通俗理解

3 ByteTrack算法原理

4 ByteTrack整体流程图

5 C++代码github备份

6  BYTETracker::update源码

7 BYTETracker::update代码解读

7.1 对检测结果做初步处理

7.2 将tracked_stracks分为unconfirmed和激活的

7.3 第一次匹配:tracked_stracks + lost_stracks 和高得分检测框进行匹配

7.4 第二次匹配:前面未匹配上的tracked_stracks  和低得分检测框进行匹配

7.5 第三次匹配:unconfirmed track和前面没匹配上的高得分检测框进行匹配

7.6 还没有匹配上的高得分检测框:认为是个新出现的目标,新建一个跟踪链

8 参考文献:


        上一篇博客我复习了下SORT跟踪算法,这一篇博客我再复习下ByteTrack跟踪算法,ByteTrack里面也是用了卡尔曼滤波和匈牙利算法,关于卡尔曼滤波和匈牙利算法可以看我的上一篇博客:目标跟踪算法:SORT、卡尔曼滤波、匈牙利算法-CSDN博客

 1  ByteTrack特点        

    多目标追踪算法一般在完成当前帧的目标检测后只会保留置信度比较大的检测框用于进行目标跟踪,而在ByteTrack中,作者保留了所有的检测框并且通过阈值将它们分成了高置信度检测框和低置信度检测框。ByteTrack 可以有效解决一些遮挡,且能够保持较低的 ID Switch。因为目标会因为被遮挡检测置信度有所降低,当重新出现时,置信度会有所升高。算法特点在于:

  • 当目标逐渐被遮挡时,跟踪目标与低置信度检测目标匹配。
  • 当目标遮挡逐渐重现时,跟踪目标与高置信度检测目标匹配。

2 ByteTrack和SORT区别----个人通俗理解

        那其实bytetrack和sort相比,bytetrack也是用到了卡尔曼滤波和匈牙利算法,不同的就是bytetrack他利用了检测得到的高得分框和低得分框,然后他的匹配逻辑更复杂一点,而sort只用了最高得分的检测框去和历史轨迹做匹配,他的匹配逻辑简单点,但是不管是bytetrack还是sort他们都是用了匈牙利算法做匹配,只不过匹配的逻辑不太一样,并且他们两者都是用了卡尔曼滤波做预测以及更新最优值。

        为什么ByteTrack对于遮挡目标效果更好:如果目标被遮挡那么得分会低,对于sort算法,由于只保留大于阈值的检测框,所以可能根本就没有这个检测框了,因为可能这个检测框得分就小于阈值,而对于bytetrack他会要求检测算法保留所有的检测框,这样bytetrack会利用低得分的去和预测框进行匹配,所以这样bytetrack就能解决一些遮挡问题,所以要优于sort算法。

3 ByteTrack算法原理

追踪算法的详细步骤:

  1. 在开始追踪之前给每一目标创建追踪轨迹;
  2. 通过卡尔曼滤波预测每一个追踪轨迹的下一帧边界框;
  3. 通过检测器获得目标的检测框,根据置信度将检测框分为高分框和低分框;
  4. 首先针对高分框,计算高分框和预测框的IOU ,使用匈牙利算法匹配IOU,获得3个结果:已匹配的轨迹与高分框,未成功匹配的轨迹,未成功匹配的高分框。匹配成功后,通过卡尔曼滤波算法利用高分检测框和预测框计算得到最优框,然后将追踪轨迹中的框更新为计算得到的最优框;
  5. 然后针对低分框,计算低分框和上一步未匹配上的预测框的IOU,使用匈牙利算法匹配IOU,获得3个结果:已匹配的轨迹与低分框,未成功匹配的轨迹,未成功匹配的低分框。匹配成功后将通过卡尔曼滤波算法利用低分检测框和预测框计算得到最优框,然后将追踪轨迹中的框更新为计算得到的最优框; 
  6. 最后针对未匹配上的高分检测框,将其和状态未激活的轨迹匹配,获得3个结果:匹配、未匹配轨迹、未匹配检测框。对于匹配更新状态,对于未匹配轨迹标记为删除,对于未匹配检测框,置信度大于高阈值+0.1新建一个跟踪轨迹,小于则丢弃。

4 ByteTrack整体流程图

整体的流程图如下

5 C++代码github备份

代码上传到了github上。 

https://github.com/cumtchw/ByteTrack

6  BYTETracker::update源码

std::vector<Track> BYTETracker::update(std::vector<turing::Detect> const& dets_in, int img_width, int img_height, uint64_t time_stamp, ObjectTrackTable const& obj_trks) {
        vector<Track> output_stracks;

        auto update_motionless_time = [&obj_trks] (STrack * track) {
            //����Ŀ��ľ�ֹʱ��
            auto iter = obj_trks.find(track->track_id);
            if (iter != obj_trks.end()){
                auto const& history_trks = iter->second;
                if (!history_trks.empty()){
                    track->motionless_time = history_trks.back().motionless_time;
                }
            }
        };

        // Step 1: Get detections //
        this->frame_id++;
        vector<STrack> activated_stracks;
        vector<STrack> refind_stracks;
        vector<STrack> removed_stracks;
        vector<STrack> lost_stracks;
        vector<STrack> detections;
        vector<STrack> detections_low;

        vector<STrack> detections_cp;
        vector<STrack> tracked_stracks_swap;
        vector<STrack> resa, resb;

        vector<STrack*> unconfirmed;
        vector<STrack*> tracked_stracks;
        vector<STrack*> strack_pool;
        vector<STrack*> r_tracked_stracks;

        if (dets_in.size() > 0)
        {
            for (int i = 0; i < dets_in.size(); i++)
            {
                vector<float> tlbr_;
                tlbr_.resize(4);
                tlbr_[0] = dets_in[i].box.x;
                tlbr_[1] = dets_in[i].box.y;
                tlbr_[2] = dets_in[i].box.x + dets_in[i].box.width;
                tlbr_[3] = dets_in[i].box.y + dets_in[i].box.height;

                int clas = dets_in[i].clas;
                int score = dets_in[i].score * 100;
                int motionless_time = 0;

                STrack strack(STrack::tlbr_to_tlwh(tlbr_), clas, score, motionless_time, time_stamp);
                strack.rbox = dets_in[i].rbox;
                //static int track_thresh = Config::instance().int_conf("TRACK_THRESH", 50);
                static int track_thresh = 50;
                if (score >= track_thresh)
                {
                    detections.push_back(strack);
                }
                else
                {
                    detections_low.push_back(strack);
                }
            }
        }

        // Add newly detected tracklets to tracked_stracks
        for (int i = 0; i < this->tracked_stracks.size(); i++)
        {
            if (!this->tracked_stracks[i].is_activated)
                unconfirmed.push_back(&this->tracked_stracks[i]);
            else
                tracked_stracks.push_back(&this->tracked_stracks[i]);
        }

        // Step 2: First association, with IoU //
        strack_pool = joint_stracks(tracked_stracks, this->lost_stracks);
        STrack::multi_predict(strack_pool, this->kalman_filter);

        vector<vector<float> > dists;
        int dist_size = 0, dist_size_size = 0;
        dists = iou_distance(strack_pool, detections, dist_size, dist_size_size);

        vector<vector<int> > matches;
        vector<int> u_track, u_detection;
        linear_assignment(dists, dist_size, dist_size_size, HIGH_MATCH_THRESH, matches, u_track, u_detection);

        for (int i = 0; i < matches.size(); i++)
        {
            STrack *track = strack_pool[matches[i][0]];
            STrack *det = &detections[matches[i][1]];
            if (track->state == TrackState::Tracked)
            {
                track->update(*det, this->frame_id);
                activated_stracks.push_back(*track);
            }
            else
            {
                track->re_activate(*det, this->frame_id, false);
                refind_stracks.push_back(*track);
            }
            update_motionless_time(track);
        }

        // Step 3: Second association, using low score dets //
        for (int i = 0; i < u_detection.size(); i++)
        {
            detections_cp.push_back(detections[u_detection[i]]);
        }
        detections.clear();
        detections.assign(detections_low.begin(), detections_low.end());

        for (int i = 0; i < u_track.size(); i++)
        {
            if (strack_pool[u_track[i]]->state == TrackState::Tracked)
            {
                r_tracked_stracks.push_back(strack_pool[u_track[i]]);
            }
        }

        dists.clear();
        dists = iou_distance(r_tracked_stracks, detections, dist_size, dist_size_size);

        matches.clear();
        u_track.clear();
        u_detection.clear();
        linear_assignment(dists, dist_size, dist_size_size, LOW_MATCH_THRESH, matches, u_track, u_detection);

        for (int i = 0; i < matches.size(); i++)
        {
            STrack *track = r_tracked_stracks[matches[i][0]];
            STrack *det = &detections[matches[i][1]];
            if (track->state == TrackState::Tracked)
            {
                track->update(*det, this->frame_id);
                activated_stracks.push_back(*track);
            }
            else
            {
                track->re_activate(*det, this->frame_id, false);
                refind_stracks.push_back(*track);
            }
            update_motionless_time(track);
        }

        for (int i = 0; i < u_track.size(); i++)
        {
            STrack *track = r_tracked_stracks[u_track[i]];
            if (track->state != TrackState::Lost)
            {
                track->mark_lost();
                lost_stracks.push_back(*track);
            }
        }

        // Deal with unconfirmed tracks, usually tracks with only one beginning frame
        detections.clear();
        detections.assign(detections_cp.begin(), detections_cp.end());

        dists.clear();
        dists = iou_distance(unconfirmed, detections, dist_size, dist_size_size);

        matches.clear();
        vector<int> u_unconfirmed;
        u_detection.clear();
        linear_assignment(dists, dist_size, dist_size_size, UNCONFIRMED_MATCH_THRESH, matches, u_unconfirmed, u_detection);

        for (int i = 0; i < matches.size(); i++)
        {
            unconfirmed[matches[i][0]]->update(detections[matches[i][1]], this->frame_id);
            activated_stracks.push_back(*unconfirmed[matches[i][0]]);
        }

        for (int i = 0; i < u_unconfirmed.size(); i++)
        {
            STrack *track = unconfirmed[u_unconfirmed[i]];
            track->mark_removed();
            removed_stracks.push_back(*track);
        }

        // Step 4: Init new stracks //
        for (int i = 0; i < u_detection.size(); i++)
        {
            STrack *track = &detections[u_detection[i]];
            //static int high_thresh = Config::instance().int_conf("HIGH_THRESH", 60);
            static int high_thresh = 60;
            if (track->score < high_thresh)
                continue;
            track->activate(this->kalman_filter, this->frame_id);
            activated_stracks.push_back(*track);
        }

        // Step 5: Update state //
        for (int i = 0; i < this->lost_stracks.size(); i++)
        {
            //计算track移除时间
            // static int default_lost_time = Config::instance().int_conf("DEFAULT_LOST_TIME", 30); //默认的移除track的时间
            // static int max_lost_time = Config::instance().int_conf("MAX_LOST_TIME", 250); //移除track的最大时间
            static int default_lost_time = 25; //目标连续未匹配到的默认帧数,超过此帧数则不再跟踪, 默认30帧,默认的移除track时间,
            static int max_lost_time = 250; //目标连续未匹配到可设置的最大帧数, 默认可设置的最大值为250帧,移除track的最大时间。

            int lost_time = default_lost_time;
            if (this->lost_stracks[i].motionless_time > 0){
                如果目标当前已经静止大于等于1s, 则默认100帧, 并且使用下列规则随着静止时间增加进行相应的增加
                lost_time = 100 + 10 * log10f(this->lost_stracks[i].motionless_time);
                lost_time = std::min(lost_time, max_lost_time);
            }

            if (this->frame_id - this->lost_stracks[i].end_frame() > lost_time)
            {
                this->lost_stracks[i].mark_removed();
                removed_stracks.push_back(this->lost_stracks[i]);
            }
        }

        for (int i = 0; i < this->tracked_stracks.size(); i++)
        {
            if (this->tracked_stracks[i].state == TrackState::Tracked)
            {
                tracked_stracks_swap.push_back(this->tracked_stracks[i]);
            }
        }
        this->tracked_stracks.clear();
        this->tracked_stracks.assign(tracked_stracks_swap.begin(), tracked_stracks_swap.end());

        this->tracked_stracks = joint_stracks(this->tracked_stracks, activated_stracks);
        this->tracked_stracks = joint_stracks(this->tracked_stracks, refind_stracks);

        //std::cout << activated_stracks.size() << std::endl;

        this->lost_stracks = sub_stracks(this->lost_stracks, this->tracked_stracks);
        for (int i = 0; i < lost_stracks.size(); i++)
        {
            this->lost_stracks.push_back(lost_stracks[i]);
        }

        this->lost_stracks = sub_stracks(this->lost_stracks, this->removed_stracks);
        for (int i = 0; i < removed_stracks.size(); i++)
        {
            this->removed_stracks.push_back(removed_stracks[i]);
        }

        this->removed_stracks.erase(std::remove_if(this->removed_stracks.begin(), this->removed_stracks.end(), [this](STrack& strk){
            //static int max_lost_time = Config::instance().int_conf("MAX_LOST_TIME", 250); //�Ƴ�track�����ʱ��
            static int max_lost_time = 250; //�Ƴ�track�����ʱ��
            if (this->frame_id - strk.end_frame() > max_lost_time) {
                return true;
            }      
            return false;
        }), this->removed_stracks.end());
 
        remove_duplicate_stracks(resa, resb, this->tracked_stracks, this->lost_stracks);

        this->tracked_stracks.clear();
        this->tracked_stracks.assign(resa.begin(), resa.end());
        this->lost_stracks.clear();
        this->lost_stracks.assign(resb.begin(), resb.end());

        for (int i = 0; i < this->tracked_stracks.size(); i++)
        {
            STrack& t = this->tracked_stracks[i];
            Rect2f bbox(t.tlwh[0], t.tlwh[1], t.tlwh[2], t.tlwh[3]); //跟踪的框,

            if (this->tracked_stracks[i].is_activated)
            {
                Rect2f box(t._tlwh[0], t._tlwh[1], t._tlwh[2], t._tlwh[3]); //检测的框,

                int none_update_lasting_time = (time_stamp - t.det_time_stamp)/25;//modified by chw
                auto area_rate = bbox.area() / box.area();
                if (none_update_lasting_time > 1){
                    //对于检测框多于1帧未更新时:
                    //当预测框的面积与最近一次更新的检测框的面积比小于0.5或者已经连续10帧未更新, 则不输出当前的目标信息
                    //否则使用预测框
                    if (none_update_lasting_time > 10){
                        continue;
                    }
                    box = bbox;
                    //rectify_rect(ori_img.size().width, ori_img.size().height, box);
                    rectify_rect(img_width, img_height, box);
                }

                Point2f cent;
                if (!use_rbox_) {
                    cent = Point2f(box.x + box.width*0.5, box.y + box.height*0.5);
                }
                else {
                    cent = Point2f((t.rbox.points[0].x + t.rbox.points[2].x)*0.5, (t.rbox.points[0].y + t.rbox.points[2].y)*0.5);
                }

                TrackAuxState track_state{ cent, 0, 0, 0, 0, 0, Point2f(), 0};
                Track res{ t.track_id, t.clas, t.score, box, t.rbox, 0, 0, SS_MOVE, 0, -1, 0, 0, track_state};
                output_stracks.push_back(res);
            }
        }
        return output_stracks;
    }

7 BYTETracker::update代码解读

接下来详细看一下src/bytetrack/BYTETracker.cpp中的BYTETracker::update这个函数。

7.1 对检测结果做初步处理

if (dets_in.size() > 0)
        {
            for (int i = 0; i < dets_in.size(); i++)
            {
                vector<float> tlbr_;
                tlbr_.resize(4);
                tlbr_[0] = dets_in[i].box.x;
                tlbr_[1] = dets_in[i].box.y;
                tlbr_[2] = dets_in[i].box.x + dets_in[i].box.width;
                tlbr_[3] = dets_in[i].box.y + dets_in[i].box.height;

                int clas = dets_in[i].clas;
                int score = dets_in[i].score * 100;
                int motionless_time = 0;

                STrack strack(STrack::tlbr_to_tlwh(tlbr_), clas, score, motionless_time, time_stamp);
                strack.rbox = dets_in[i].rbox;
                //static int track_thresh = Config::instance().int_conf("TRACK_THRESH", 50);
                static int track_thresh = 50;
                if (score >= track_thresh)
                {
                    detections.push_back(strack);
                }
                else
                {
                    detections_low.push_back(strack);
                }
            }
        }

这里只是对检测结果做初步处理,然后根据得分分为高得分检测结果和低得分检测结果。

7.2 将tracked_stracks分为unconfirmed和激活的

        // Add newly detected tracklets to tracked_stracks
        for (int i = 0; i < this->tracked_stracks.size(); i++)
        {
            if (!this->tracked_stracks[i].is_activated)
                unconfirmed.push_back(&this->tracked_stracks[i]);
            else
                tracked_stracks.push_back(&this->tracked_stracks[i]);
        }

这里是把上一帧的跟踪信息分为unconfirmed和激活的,这里unconfirmed未确认的意思其实就是只有一帧结果的,然后is_activated就是正常的激活的跟踪链。

7.3 第一次匹配:tracked_stracks + lost_stracks 和高得分检测框进行匹配

// Step 2: First association, with IoU //
        strack_pool = joint_stracks(tracked_stracks, this->lost_stracks);
        STrack::multi_predict(strack_pool, this->kalman_filter);

        vector<vector<float> > dists;
        int dist_size = 0, dist_size_size = 0;
        dists = iou_distance(strack_pool, detections, dist_size, dist_size_size);

        vector<vector<int> > matches;
        vector<int> u_track, u_detection;
        linear_assignment(dists, dist_size, dist_size_size, HIGH_MATCH_THRESH, matches, u_track, u_detection);

        for (int i = 0; i < matches.size(); i++)
        {
            STrack *track = strack_pool[matches[i][0]];
            STrack *det = &detections[matches[i][1]];
            if (track->state == TrackState::Tracked)
            {
                track->update(*det, this->frame_id);
                activated_stracks.push_back(*track);
            }
            else
            {
                track->re_activate(*det, this->frame_id, false);
                refind_stracks.push_back(*track);
            }
            update_motionless_time(track);
        }
  • strack_pool = joint_stracks(tracked_stracks, this->lost_stracks);这个函数的作用就是把两个列表合成一个列表,也就是把tracked_stracks和lost_stracks都放到了strack_pool里面。
  • STrack::multi_predict(strack_pool, this->kalman_filter);这个就是卡尔曼滤波了,就是更新每个跟踪对象的均值向量mean (xCenter,yCenter,w/h,h,Vx,Vy,Vr,Vh)和协方差矩阵covariance,这就是所谓的预测.
  • dists = iou_distance(strack_pool, detections, dist_size, dist_size_size);这是是求解了IOU距离矩阵,利用预测框和检测框,求解距离矩阵,然后dist_size, dist_size_size是矩阵的行列数。
  • linear_assignment(dists, dist_size, dist_size_size, HIGH_MATCH_THRESH, matches, u_track, u_detection);这个是利用匈牙利算法做匹配,注意matches, u_track, u_detection这三个参数分别表示匹配上的,还有没匹配上的跟踪框,没匹配上的检测框。u_track, u_detection这个在后面还会用到。
  • track->update(*det, this->frame_id);这里就是用卡尔曼滤波,然后根据检测值去纠正下之前的预测值,也就是说最终结果是用卡尔曼滤波然后根据检测值和预测值得到最优值。
  • track->re_activate(*det, this->frame_id, false);这个分支说之前lost_track这次匹配上了,那么说明是重新找到了,比如可能是遮挡的重新出现了。

7.4 第二次匹配:前面未匹配上的tracked_stracks  和低得分检测框进行匹配

// Step 3: Second association, using low score dets //
        for (int i = 0; i < u_detection.size(); i++)
        {
            detections_cp.push_back(detections[u_detection[i]]);
        }
        detections.clear();
        detections.assign(detections_low.begin(), detections_low.end());

        for (int i = 0; i < u_track.size(); i++)
        {
            if (strack_pool[u_track[i]]->state == TrackState::Tracked)
            {
                r_tracked_stracks.push_back(strack_pool[u_track[i]]);
            }
        }

        dists.clear();
        dists = iou_distance(r_tracked_stracks, detections, dist_size, dist_size_size);

        matches.clear();
        u_track.clear();
        u_detection.clear();
        linear_assignment(dists, dist_size, dist_size_size, LOW_MATCH_THRESH, matches, u_track, u_detection);

        for (int i = 0; i < matches.size(); i++)
        {
            STrack *track = r_tracked_stracks[matches[i][0]];
            STrack *det = &detections[matches[i][1]];
            if (track->state == TrackState::Tracked)
            {
                track->update(*det, this->frame_id);
                activated_stracks.push_back(*track);
            }
            else
            {
                track->re_activate(*det, this->frame_id, false);
                refind_stracks.push_back(*track);
            }
            update_motionless_time(track);
        }

        for (int i = 0; i < u_track.size(); i++)
        {
            STrack *track = r_tracked_stracks[u_track[i]];
            if (track->state != TrackState::Lost)
            {
                track->mark_lost();
                lost_stracks.push_back(*track);
            }
        }

这写代码和前面的类似,区别就在于,这次是用的跟踪链是第一次没匹配上的跟踪链,然后这次用的检测框是低得分检测框,剩下的基本上一样,也是先求iou矩阵,然后做匹配。

7.5 第三次匹配:unconfirmed track和前面没匹配上的高得分检测框进行匹配

// Deal with unconfirmed tracks, usually tracks with only one beginning frame
        detections.clear();
        detections.assign(detections_cp.begin(), detections_cp.end());

        dists.clear();
        dists = iou_distance(unconfirmed, detections, dist_size, dist_size_size);

        matches.clear();
        vector<int> u_unconfirmed;
        u_detection.clear();
        linear_assignment(dists, dist_size, dist_size_size, UNCONFIRMED_MATCH_THRESH, matches, u_unconfirmed, u_detection);

        for (int i = 0; i < matches.size(); i++)
        {
            unconfirmed[matches[i][0]]->update(detections[matches[i][1]], this->frame_id);
            activated_stracks.push_back(*unconfirmed[matches[i][0]]);
        }

        for (int i = 0; i < u_unconfirmed.size(); i++)
        {
            STrack *track = unconfirmed[u_unconfirmed[i]];
            track->mark_removed();
            removed_stracks.push_back(*track);
        }

匹配方式也是一样,还是求iou矩阵,然后匈牙利算法进行匹配。

7.6 还没有匹配上的高得分检测框:认为是个新出现的目标,新建一个跟踪链

      // Step 4: Init new stracks //
        for (int i = 0; i < u_detection.size(); i++)
        {
            STrack *track = &detections[u_detection[i]];
            //static int high_thresh = Config::instance().int_conf("HIGH_THRESH", 60);
            static int high_thresh = 60;
            if (track->score < high_thresh)
                continue;
            track->activate(this->kalman_filter, this->frame_id);
            activated_stracks.push_back(*track);
        }

这里就是针对前面都没有匹配成功的高得分检测框,新建个跟踪链。

8 参考文献:

目标追踪 ByteTrack 算法详细流程分析 - 金色旭光 - 博客园

ByteTrack流程剖析(C++版本)_bytetrack c++-CSDN博客

 实时目标追踪:ByteTrack算法步骤详解和代码逐行解析_bytetrack 源码分析-CSDN博客

ultralytics框架实现ByteTrack目标追踪算法_51CTO博客_目标检测追踪

【目标跟踪】ByteTrack详解与代码细节-CSDN博客

 【目标跟踪】ByteTrack详解与代码细节_目标跟踪_神仙罗辑-开放原子开发者工作坊

【目标跟踪】ByteTrack详解与代码细节-腾讯云开发者社区-腾讯云

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

陈 洪 伟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值