hybrid_sort modified in the boxmot style

Table of Contents

hybrid_sort

"""
    This script is adapted from the SORT script by Alex Bewley alex@bewley.ai
"""
from __future__ import print_function

import numpy as np
from .association import *
import argparse
  


def k_previous_obs(observations, cur_age, k):
    if len(observations) == 0:
        return [-1, -1, -1, -1, -1]
    for i in range(k):
        dt = k - i
        if cur_age - dt in observations:
            return observations[cur_age-dt]
    max_age = max(observations.keys())
    return observations[max_age]
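
# A minimal sanity check with hypothetical values: observations are keyed by the
# track's age, and k_previous_obs returns the box seen k frames before cur_age,
# falling back to the newest stored box when that exact age is missing.
#   obs = {7: [0, 0, 10, 10, 0.9], 9: [2, 2, 12, 12, 0.8]}
#   k_previous_obs(obs, cur_age=10, k=3)   # age 10-3=7 is stored -> obs[7]
#   k_previous_obs(obs, cur_age=20, k=3)   # ages 17..19 missing  -> obs[9]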


def convert_bbox_to_z(bbox):
    """
    Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
      [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
      the aspect ratio
    """
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    x = bbox[0] + w/2.
    y = bbox[1] + h/2.
    s = w * h  # scale is just area
    r = w / float(h+1e-6)
    score = bbox[4]
    if score:
        return np.array([x, y, s, score, r]).reshape((5, 1))
    else:
        return np.array([x, y, s, r]).reshape((4, 1))


def convert_x_to_bbox(x, score=None):
    """
    Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
      [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
    """
    w = np.sqrt(x[2] * x[4])
    h = x[2] / w
    score = x[3]
    if(score == None):
      return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2.]).reshape((1, 4))
    else:
      return np.array([x[0]-w/2., x[1]-h/2., x[0]+w/2., x[1]+h/2., score]).reshape((1, 5))
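
# Round-trip sketch with a hypothetical box: convert_bbox_to_z maps
# [x1, y1, x2, y2, score] -> [cx, cy, area, score, aspect_ratio] and
# convert_x_to_bbox inverts it from the first five entries of the state.
#   z = convert_bbox_to_z(np.array([10., 20., 50., 100., 0.9]))
#   # z ~ [[30.], [60.], [3200.], [0.9], [0.5]]      (shape (5, 1))
#   convert_x_to_bbox(z)
#   # ~ [[10., 20., 50., 100., 0.9]]                 (shape (1, 5))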


def speed_direction(bbox1, bbox2):
    cx1, cy1 = (bbox1[0]+bbox1[2]) / 2.0, (bbox1[1]+bbox1[3])/2.0
    cx2, cy2 = (bbox2[0]+bbox2[2]) / 2.0, (bbox2[1]+bbox2[3])/2.0
    speed = np.array([cy2-cy1, cx2-cx1])
    norm = np.sqrt((cy2-cy1)**2 + (cx2-cx1)**2) + 1e-6
    return speed / norm

def speed_direction_lt(bbox1, bbox2):
    cx1, cy1 = bbox1[0], bbox1[1]
    cx2, cy2 = bbox2[0], bbox2[1]
    speed = np.array([cy2-cy1, cx2-cx1])
    norm = np.sqrt((cy2-cy1)**2 + (cx2-cx1)**2) + 1e-6
    return speed / norm

def speed_direction_rt(bbox1, bbox2):
    cx1, cy1 = bbox1[0], bbox1[3]
    cx2, cy2 = bbox2[0], bbox2[3]
    speed = np.array([cy2-cy1, cx2-cx1])
    norm = np.sqrt((cy2-cy1)**2 + (cx2-cx1)**2) + 1e-6
    return speed / norm

def speed_direction_lb(bbox1, bbox2):
    cx1, cy1 = bbox1[2], bbox1[1]
    cx2, cy2 = bbox2[2], bbox2[1]
    speed = np.array([cy2-cy1, cx2-cx1])
    norm = np.sqrt((cy2-cy1)**2 + (cx2-cx1)**2) + 1e-6
    return speed / norm

def speed_direction_rb(bbox1, bbox2):
    cx1, cy1 = bbox1[2], bbox1[3]
    cx2, cy2 = bbox2[2], bbox2[3]
    speed = np.array([cy2-cy1, cx2-cx1])
    norm = np.sqrt((cy2-cy1)**2 + (cx2-cx1)**2) + 1e-6
    return speed / norm
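
# Direction sketch with hypothetical boxes: each speed_direction_* variant
# tracks a fixed corner of the box and returns a unit vector in (dy, dx) order.
#   prev = np.array([0., 0., 10., 10.])
#   cur = np.array([3., 4., 13., 14.])
#   speed_direction_lt(prev, cur)   # ~ [0.8, 0.6]  (dy=4, dx=3, norm=5)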

class KalmanBoxTracker(object):
    """
    This class represents the internal state of individual tracked objects observed as bbox.
    """
    count = 0

    def __init__(self, bbox, delta_t=3, orig=False, args=None):
        """
        Initialises a tracker using initial bounding box.

        """
        # define constant velocity model
        # if not orig and not args.kalman_GPR:
        if not orig:
          # from .kalmanfilter import KalmanFilterNew as KalmanFilter
          from .kalmanfilter_score_new import KalmanFilterNew_score_new as KalmanFilter_score_new
          self.kf = KalmanFilter_score_new(dim_x=9, dim_z=5)
          # self.kf_score = KalmanFilter_score(dim_x=2, dim_z=1)
        else:
          from filterpy.kalman import KalmanFilter
          self.kf = KalmanFilter(dim_x=7, dim_z=4)
        # u, v, s, c, r, ~u, ~v, ~s, ~c
        self.kf.F = np.array([[1, 0, 0, 0, 0, 1, 0, 0, 0],
                              [0, 1, 0, 0, 0, 0, 1, 0, 0],
                              [0, 0, 1, 0, 0, 0, 0, 1, 0],
                              [0, 0, 0, 1, 0, 0, 0, 0, 1],
                              [0, 0, 0, 0, 1, 0, 0, 0, 0],
                              [0, 0, 0, 0, 0, 1, 0, 0, 0],
                              [0, 0, 0, 0, 0, 0, 1, 0, 0],
                              [0, 0, 0, 0, 0, 0, 0, 1, 0],
                              [0, 0, 0, 0, 0, 0, 0, 0, 1]])
        self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0, 0, 0],
                              [0, 1, 0, 0, 0, 0, 0, 0, 0],
                              [0, 0, 1, 0, 0, 0, 0, 0, 0],
                              [0, 0, 0, 1, 0, 0, 0, 0, 0],
                              [0, 0, 0, 0, 1, 0, 0, 0, 0]])
        # self.kf_score.F = np.array([[1, 1],
        #                             [0, 1]])
        # self.kf_score.H = np.array([[1, 0]])

        self.kf.R[2:, 2:] *= 10.
        self.kf.P[5:, 5:] *= 1000.  # give high uncertainty to the unobservable initial velocities
        self.kf.P *= 10.
        self.kf.Q[-1, -1] *= 0.01
        self.kf.Q[-2, -2] *= 0.01
        self.kf.Q[5:, 5:] *= 0.01

        self.kf.x[:5] = convert_bbox_to_z(bbox)


        # self.kf_score.R[0:, 0:] *= 10.
        # self.kf_score.P[1:, 1:] *= 1000.  # give high uncertainty to the unobservable initial velocities
        # self.kf_score.P *= 10.
        # self.kf_score.Q[-1, -1] *= 0.01
        # self.kf_score.Q[1:, 1:] *= 0.01
        # self.kf_score.x[:1] = bbox[-1]


        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []
        self.hits = 0
        self.hit_streak = 0
        self.age = 0
        self.age_recover_for_cbiou = 0
        """
        NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of 
        function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a 
        fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]), let's bear it for now.
        """
        self.last_observation = np.array([-1, -1, -1, -1, -1])  # placeholder
        self.last_observation_save = np.array([-1, -1, -1, -1, -1])
        self.observations = dict()
        self.history_observations = []
        self.velocity = None  # kept initialised: update_public() still reads trk.velocity
        self.velocity_lt = None
        self.velocity_rt = None
        self.velocity_lb = None
        self.velocity_rb = None
        self.delta_t = delta_t
        self.confidence_pre = None
        self.confidence = bbox[-1]
        self.args = args
        self.kf.args = args
        # self.kf_score.args = args

    def camera_update(self, warp_matrix):
        """
        update 'self.mean' of current tracklet with ecc results.
        Parameters
        ----------
        warp_matrix: warp matrix computed by ECC.
        """
        x1, y1, x2, y2, s = convert_x_to_bbox(self.kf.x)[0]
        x1_, y1_ = warp_matrix @ np.array([x1, y1, 1]).T
        x2_, y2_ = warp_matrix @ np.array([x2, y2, 1]).T
        # w, h = x2_ - x1_, y2_ - y1_
        # cx, cy = x1_ + w / 2, y1_ + h / 2
        self.kf.x[:5] = convert_bbox_to_z([x1_, y1_, x2_, y2_, s])
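
    # camera_update expects a 2x3 affine warp matrix (e.g. the warp estimated by
    # cv2.findTransformECC); each box corner is mapped as warp_matrix @ [x, y, 1].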

    def update(self, bbox):
        """
        Updates the state vector with observed bbox.
        """
        # velocity = None
        velocity_lt = None
        velocity_rt = None
        velocity_lb = None
        velocity_rb = None
        if bbox is not None:
            if self.last_observation.sum() >= 0:  # a previous observation exists
                previous_box = None
                for i in range(self.delta_t):
                    # dt = self.delta_t - i
                    if self.age - i - 1 in self.observations:
                        previous_box = self.observations[self.age - i - 1]
                        if velocity_lt is not None:
                            # velocity += speed_direction(previous_box, bbox)
                            velocity_lt += speed_direction_lt(previous_box, bbox)
                            velocity_rt += speed_direction_rt(previous_box, bbox)
                            velocity_lb += speed_direction_lb(previous_box, bbox)
                            velocity_rb += speed_direction_rb(previous_box, bbox)
                        else:
                            # velocity = speed_direction(previous_box, bbox)
                            velocity_lt = speed_direction_lt(previous_box, bbox)
                            velocity_rt = speed_direction_rt(previous_box, bbox)
                            velocity_lb = speed_direction_lb(previous_box, bbox)
                            velocity_rb = speed_direction_rb(previous_box, bbox)
                        # break
                if previous_box is None:
                    previous_box = self.last_observation
                    # self.velocity = speed_direction(previous_box, bbox)
                    # self.velocity = norm_vel(self.velocity)
                    self.velocity_lt = speed_direction_lt(previous_box, bbox)
                    self.velocity_rt = speed_direction_rt(previous_box, bbox)
                    self.velocity_lb = speed_direction_lb(previous_box, bbox)
                    self.velocity_rb = speed_direction_rb(previous_box, bbox)
                else:
                    # self.velocity = velocity
                    # self.velocity = norm_vel(self.velocity)
                    self.velocity_lt = velocity_lt
                    self.velocity_rt = velocity_rt
                    self.velocity_lb = velocity_lb
                    self.velocity_rb = velocity_rb
            """
              Insert new observations. This is a ugly way to maintain both self.observations
              and self.history_observations. Bear it for the moment.
            """
            self.last_observation = bbox
            self.last_observation_save = bbox
            self.observations[self.age] = bbox
            self.history_observations.append(bbox)

            self.time_since_update = 0
            self.history = []
            self.hits += 1
            self.hit_streak += 1
            self.kf.update(convert_bbox_to_z(bbox))
            # self.kf_score.update(bbox[-1])
            self.confidence_pre = self.confidence
            self.confidence = bbox[-1]
            self.age_recover_for_cbiou = self.age
        else:
            self.kf.update(bbox)
            # self.kf_score.update(bbox)
            self.confidence_pre = None

    def predict(self):
        """
        Advances the state vector and returns the predicted bounding box estimate.
        """
        if((self.kf.x[7]+self.kf.x[2]) <= 0):
            self.kf.x[7] *= 0.0

        self.kf.predict()
        # self.kf_score.predict()
        self.age += 1
        if(self.time_since_update > 0):
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(convert_x_to_bbox(self.kf.x))
        if not self.confidence_pre:
            return self.history[-1], np.clip(self.kf.x[3], self.args.track_thresh, 1.0), np.clip(self.confidence, 0.1, self.args.track_thresh)
        else:
            return self.history[-1], np.clip(self.kf.x[3], self.args.track_thresh, 1.0), np.clip(self.confidence - (self.confidence_pre - self.confidence), 0.1, self.args.track_thresh)

    def get_state(self):
        """
        Returns the current bounding box estimate.
        """
        return convert_x_to_bbox(self.kf.x)


"""
    We support multiple ways for association cost calculation, by default
    we use IoU. GIoU may have better performance in some situations. We note 
    that we hardly normalize the cost by all methods to (0,1) which may not be 
    the best practice.
"""
ASSO_FUNCS = {
    "iou": iou_batch,
    "giou": giou_batch,
    "ciou": ciou_batch,
    "diou": diou_batch,
    "ct_dist": ct_dist,
    "Height_Modulated_IoU": hmiou,
}
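
# For reference, a minimal sketch of what an association function such as
# iou_batch computes (the real implementations are imported from .association
# above; this hypothetical inline version exists only for illustration):
def _iou_batch_sketch(dets, trks):
    """(N, 4+) dets vs (M, 4+) trks -> (N, M) IoU matrix via broadcasting."""
    dets = np.expand_dims(dets, 1)   # (N, 1, 4+)
    trks = np.expand_dims(trks, 0)   # (1, M, 4+)
    xx1 = np.maximum(dets[..., 0], trks[..., 0])
    yy1 = np.maximum(dets[..., 1], trks[..., 1])
    xx2 = np.minimum(dets[..., 2], trks[..., 2])
    yy2 = np.minimum(dets[..., 3], trks[..., 3])
    inter = np.maximum(0., xx2 - xx1) * np.maximum(0., yy2 - yy1)
    area_d = (dets[..., 2] - dets[..., 0]) * (dets[..., 3] - dets[..., 1])
    area_t = (trks[..., 2] - trks[..., 0]) * (trks[..., 3] - trks[..., 1])
    return inter / (area_d + area_t - inter + 1e-6)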


class Hybird_Sort(object):
    def __init__(self, max_age=30, min_hits=3,
        iou_threshold=0.3, delta_t=3, asso_func="iou", inertia=0.2, use_byte=False):
        """
        Sets key parameters for SORT
        """
        self.args = self.make_parser().parse_args()
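        # NOTE: parse_args() consumes sys.argv, so embedding this tracker inside
        # another command-line program may require filtering argv or calling
        # parse_args([]) with explicit option overrides instead.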
        # print(f"args : {self.args}")
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.frame_count = 0
        self.det_thresh = self.args.track_thresh
        self.delta_t = delta_t
        self.asso_func = ASSO_FUNCS[asso_func]
        self.inertia = inertia
        self.use_byte = use_byte
        KalmanBoxTracker.count = 0
       

    def make_parser(self,):
        parser = argparse.ArgumentParser("OC-SORT parameters")
        parser.add_argument("--expn", type=str, default=None)
        parser.add_argument("-n", "--name", type=str, default=None, help="model name")

        # distributed
        parser.add_argument( "--dist-backend", default="nccl", type=str, help="distributed backend")
        parser.add_argument("--output_dir", type=str, default="evaldata/trackers/mot_challenge")
        parser.add_argument("--dist-url", default=None, type=str, help="url used to set up distributed training")
        parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
        parser.add_argument("-d", "--devices", default=None, type=int, help="device for training")

        parser.add_argument("--local_rank", default=0, type=int, help="local rank for dist training")
        parser.add_argument( "--num_machines", default=1, type=int, help="num of node for training")
        parser.add_argument("--machine_rank", default=0, type=int, help="node rank for multi-node training")

        parser.add_argument(
            "-f", "--exp_file",
            default=None,
            type=str,
            help="pls input your expriment description file",
        )
        parser.add_argument(
            "--fp16", dest="fp16",
            default=False,
            action="store_true",
            help="Adopting mix precision evaluating.",
        )
        parser.add_argument("--fuse", dest="fuse", default=False, action="store_true", help="Fuse conv and bn for testing.",)
        parser.add_argument("--trt", dest="trt", default=False, action="store_true", help="Using TensorRT model for testing.",)
        parser.add_argument("--test", dest="test", default=False, action="store_true", help="Evaluating on test-dev set.",)
        parser.add_argument("--speed", dest="speed", default=False, action="store_true", help="speed test only.",)
        parser.add_argument("opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER,)
        
        # det args
        parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
        parser.add_argument("--conf", default=0.1, type=float, help="test conf")
        parser.add_argument("--nms", default=0.7, type=float, help="test nms threshold")
        parser.add_argument("--tsize", default=None, type=int, help="test img size")
        parser.add_argument("--seed", default=None, type=int, help="eval seed")

        # tracking args
        parser.add_argument("--track_thresh", type=float, default=0.6, help="detection confidence threshold")
        parser.add_argument("--iou_thresh", type=float, default=0.3, help="the iou threshold in Sort for matching")
        parser.add_argument("--min_hits", type=int, default=3, help="min hits to create track in SORT")
        parser.add_argument("--inertia", type=float, default=0.2, help="the weight of VDC term in cost matrix")
        parser.add_argument("--deltat", type=int, default=3, help="time step difference to estimate direction")
        parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
        parser.add_argument("--match_thresh", type=float, default=0.9, help="matching threshold for tracking")
        parser.add_argument('--min-box-area', type=float, default=100, help='filter out tiny boxes')
        parser.add_argument("--gt-type", type=str, default="_val_half", help="suffix to find the gt annotation")
        parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
        parser.add_argument("--public", action="store_true", help="use public detection")
        parser.add_argument('--asso', default="iou", help="similarity function: iou/giou/diou/ciou/ct_dist/Height_Modulated_IoU")
        parser.add_argument("--use_byte", dest="use_byte", default=False, action="store_true", help="use byte in tracking.")

        parser.add_argument("--TCM_first_step", default=False, action="store_true", help="use TCM in first step.")
        parser.add_argument("--TCM_byte_step", default=False, action="store_true", help="use TCM in byte step.")
        parser.add_argument("--TCM_first_step_weight", type=float, default=1.0, help="TCM first step weight")
        parser.add_argument("--TCM_byte_step_weight", type=float, default=1.0, help="TCM second step weight")
        parser.add_argument("--hybird_sort_with_reid", default=False, action="store_true", help="use ReID model for Hybird SORT.")

        # for fast reid
        parser.add_argument("--EG_weight_high_score", default=0.0, type=float, help="weight of appearance cost matrix when using EG")
        parser.add_argument("--EG_weight_low_score", default=0.0, type=float, help="weight of appearance cost matrix when using EG")
        parser.add_argument("--low_thresh", default=0.1, type=float, help="threshold of low score detections for BYTE")
        parser.add_argument("--high_score_matching_thresh", default=0.8, type=float, help="matching threshold for detections with high score")
        parser.add_argument("--low_score_matching_thresh", default=0.5, type=float, help="matching threshold for detections with low score")
        parser.add_argument("--alpha", default=0.8, type=float, help="momentum of embedding update")
        parser.add_argument("--with_fastreid", dest="with_fastreid", default=False, action="store_true", help="use FastReID flag.")
        parser.add_argument("--fast_reid_config", dest="fast_reid_config", default=r"fast_reid/configs/CUHKSYSU_DanceTrack/sbs_S50.yml", type=str, help="reid config file path")
        parser.add_argument("--fast_reid_weights", dest="fast_reid_weights", default=r"fast_reid/logs/CUHKSYSU_DanceTrack/sbs_S50/model_final.pth", type=str, help="reid weight path")
        parser.add_argument("--with_longterm_reid", dest="with_longterm_reid", default=False, action="store_true", help="use long-term reid features for association.")
        parser.add_argument("--longterm_reid_weight", default=0.0, type=float, help="weight of appearance cost matrix when using long term reid features in 1st stage association")
        parser.add_argument("--longterm_reid_weight_low", default=0.0, type=float, help="weight of appearance cost matrix when using long term reid features in 2nd stage association")
        parser.add_argument("--with_longterm_reid_correction", dest="with_longterm_reid_correction", default=False, action="store_true", help="use long-term reid features for association correction.")
        parser.add_argument("--longterm_reid_correction_thresh", default=1.0, type=float, help="threshold of correction when using long term reid features in 1st stage association")
        parser.add_argument("--longterm_reid_correction_thresh_low", default=1.0, type=float, help="threshold of correction when using long term reid features in 2nd stage association")
        parser.add_argument("--longterm_bank_length", type=int, default=30, help="max length of reid feat bank")
        parser.add_argument("--adapfs", dest="adapfs", default=False, action="store_true", help="Adaptive Feature Smoothing.")
        # ECC for CMC
        parser.add_argument("--ECC", dest="ECC", default=False, action="store_true", help="use ECC for CMC.")

        # for kitti/bdd100k inference with public detections
        parser.add_argument('--raw_results_path', type=str, default="exps/permatrack_kitti_test/",
            help="path to the raw tracking results from other trackers")
        parser.add_argument('--out_path', type=str, help="path to save output results")
        parser.add_argument("--dataset", type=str, default="mot17", help="kitti or bdd")
        parser.add_argument("--hp", action="store_true", help="use head padding to add the missing objects during \
                initializing the tracks (offline).")

        # for demo video
        parser.add_argument("--demo_type", default="image", help="demo type, eg. image, video and webcam")
        parser.add_argument( "--path", default="./videos/demo.mp4", help="path to images or video")
        parser.add_argument("--demo_dancetrack", default=False, action="store_true",
                            help="only for dancetrack demo, replace timestamp with dancetrack sequence name.")
        parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
        parser.add_argument(
            "--save_result",
            action="store_true",
            help="whether to save the inference result of image/video",
        )
        parser.add_argument(
            "--aspect_ratio_thresh", type=float, default=1.6,
            help="threshold for filtering out boxes of which aspect ratio are above the given value."
        )
        parser.add_argument('--min_box_area', type=float, default=10, help='filter out tiny boxes')
        parser.add_argument(
            "--device",
            default="gpu",
            type=str,
            help="device to run our model, can either be cpu or gpu",
        )
        return parser


    def camera_update(self, trackers, warp_matrix):
        for tracker in trackers:
            tracker.camera_update(warp_matrix)

    def update(self, output_results, img):
        """
        Params:
          dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
        Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections).
        Returns a similar array, where the last column is the object ID.
        NOTE: The number of objects returned may differ from the number of detections provided.
        """
        if output_results is None:
            return np.empty((0, 5))

        self.frame_count += 1
        # post_process detections
        scores = output_results[:, 4]
        bboxes = output_results[:, :4]

        dets = np.concatenate((bboxes, np.expand_dims(scores, axis=-1)), axis=1)
        inds_low = scores > 0.1
        inds_high = scores < self.det_thresh
        inds_second = np.logical_and(inds_low, inds_high)  # self.det_thresh > score > 0.1, for second matching
        dets_second = dets[inds_second]  # detections for second matching
        remain_inds = scores > self.det_thresh
        dets = dets[remain_inds]

        # get predicted locations from existing trackers.
        trks = np.zeros((len(self.trackers), 8))
        to_del = []
        ret = []
        for t, trk in enumerate(trks):
            pos, kalman_score, simple_score = self.trackers[t].predict()
            try:
                trk[:6] = [pos[0][0], pos[0][1], pos[0][2], pos[0][3], kalman_score, simple_score[0]]
            except (IndexError, TypeError):
                # kalman_score / simple_score may each be a scalar or a 1-element array
                trk[:6] = [pos[0][0], pos[0][1], pos[0][2], pos[0][3], kalman_score[0], simple_score]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            self.trackers.pop(t)

        # velocities = np.array(
        #     [trk.velocity if trk.velocity is not None else np.array((0, 0)) for trk in self.trackers])
        velocities_lt = np.array(
            [trk.velocity_lt if trk.velocity_lt is not None else np.array((0, 0)) for trk in self.trackers])
        velocities_rt = np.array(
            [trk.velocity_rt if trk.velocity_rt is not None else np.array((0, 0)) for trk in self.trackers])
        velocities_lb = np.array(
            [trk.velocity_lb if trk.velocity_lb is not None else np.array((0, 0)) for trk in self.trackers])
        velocities_rb = np.array(
            [trk.velocity_rb if trk.velocity_rb is not None else np.array((0, 0)) for trk in self.trackers])
        last_boxes = np.array([trk.last_observation for trk in self.trackers])
        k_observations = np.array(
            [k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers])

        """
            First round of association
        """
        if self.args.TCM_first_step:
            matched, unmatched_dets, unmatched_trks = associate_4_points_with_score(
                dets, trks, self.iou_threshold, velocities_lt, velocities_rt, velocities_lb, velocities_rb,
                k_observations, self.inertia, self.asso_func, self.args)
        else:
            matched, unmatched_dets, unmatched_trks = associate_4_points(
                dets, trks, self.iou_threshold, velocities_lt, velocities_rt, velocities_lb, velocities_rb, k_observations, self.inertia, self.asso_func, self.args)

        for m in matched:
            self.trackers[m[1]].update(dets[m[0], :])

        """
            Second round of association by OCR
        """
        # BYTE association
        if self.use_byte and len(dets_second) > 0 and unmatched_trks.shape[0] > 0:
            u_trks = trks[unmatched_trks]
            iou_left = self.asso_func(dets_second, u_trks)
            iou_left = np.array(iou_left)
            if iou_left.max() > self.iou_threshold:
                """
                    NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
                    get a higher performance especially on MOT17/MOT20 datasets. But we keep it
                    uniform here for simplicity
                """
                if self.args.TCM_byte_step:
                    iou_left -= np.array(cal_score_dif_batch_two_score(dets_second, u_trks) * self.args.TCM_byte_step_weight)
                matched_indices = linear_assignment(-iou_left)
                to_remove_trk_indices = []
                for m in matched_indices:
                    det_ind, trk_ind = m[0], unmatched_trks[m[1]]
                    if iou_left[m[0], m[1]] < self.iou_threshold:
                        continue
                    self.trackers[trk_ind].update(dets_second[det_ind, :])
                    to_remove_trk_indices.append(trk_ind)
                unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))

        if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
            left_dets = dets[unmatched_dets]
            left_trks = last_boxes[unmatched_trks]
            iou_left = self.asso_func(left_dets, left_trks)
            iou_left = np.array(iou_left)

            if iou_left.max() > self.iou_threshold:
                """
                    NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
                    get a higher performance especially on MOT17/MOT20 datasets. But we keep it
                    uniform here for simplicity
                """
                rematched_indices = linear_assignment(-iou_left)
                to_remove_det_indices = []
                to_remove_trk_indices = []
                for m in rematched_indices:
                    det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
                    if iou_left[m[0], m[1]] < self.iou_threshold:
                        continue
                    self.trackers[trk_ind].update(dets[det_ind, :])
                    to_remove_det_indices.append(det_ind)
                    to_remove_trk_indices.append(trk_ind)
                unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
                unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))

        for m in unmatched_trks:
            self.trackers[m].update(None)

        # create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t, args=self.args)
            self.trackers.append(trk)
        i = len(self.trackers)
        for trk in reversed(self.trackers):
            if trk.last_observation.sum() < 0:
                d = trk.get_state()[0][:4]
            else:
                """
                    this is optional to use the recent observation or the kalman filter prediction,
                    we didn't notice significant difference here
                """
                d = trk.last_observation[:4]
            if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                # +1 as MOT benchmark requires positive
                ret.append(np.concatenate((d, [trk.id+1])).reshape(1, -1))
            i -= 1
            # remove dead tracklet
            if(trk.time_since_update > self.max_age):
                self.trackers.pop(i)
        if(len(ret) > 0):
            return np.concatenate(ret)
        return np.empty((0, 5))
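
    # Summary of the update() cascade above: (1) Kalman predict for every live
    # track; (2) first association of high-score detections using the four
    # corner velocity directions (plus the TCM confidence term when enabled);
    # (3) optional BYTE association of low-score detections; (4) OCR-style
    # re-association of the leftovers against each track's last observation;
    # (5) new trackers for unmatched detections, and stale trackers removed.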

    def update_public(self, dets, cates, scores):
        self.frame_count += 1

        det_scores = np.ones((dets.shape[0], 1))
        dets = np.concatenate((dets, det_scores), axis=1)

        remain_inds = scores > self.det_thresh
        
        cates = cates[remain_inds]
        dets = dets[remain_inds]

        trks = np.zeros((len(self.trackers), 5))
        to_del = []
        ret = []
        for t, trk in enumerate(trks):
            pos = self.trackers[t].predict()[0][0]  # predict() returns (bbox_history, kalman_score, simple_score)
            cat = self.trackers[t].cate
            trk[:] = [pos[0], pos[1], pos[2], pos[3], cat]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            self.trackers.pop(t)

        velocities = np.array([trk.velocity if trk.velocity is not None else np.array((0,0)) for trk in self.trackers])
        last_boxes = np.array([trk.last_observation for trk in self.trackers])
        k_observations = np.array([k_previous_obs(trk.observations, trk.age, self.delta_t) for trk in self.trackers])

        matched, unmatched_dets, unmatched_trks = associate_kitti(
            dets, trks, cates, self.iou_threshold, velocities, k_observations, self.inertia)
          
        for m in matched:
            self.trackers[m[1]].update(dets[m[0], :])
          
        if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
            """
                The re-association stage by OCR.
                NOTE: at this stage, adding other strategy might be able to continue improve
                the performance, such as BYTE association by ByteTrack. 
            """
            left_dets = dets[unmatched_dets]
            left_trks = last_boxes[unmatched_trks]
            left_dets_c = left_dets.copy()
            left_trks_c = left_trks.copy()

            iou_left = self.asso_func(left_dets_c, left_trks_c)
            iou_left = np.array(iou_left)
            det_cates_left = cates[unmatched_dets]
            trk_cates_left = trks[unmatched_trks][:,4]
            num_dets = unmatched_dets.shape[0]
            num_trks = unmatched_trks.shape[0]
            cate_matrix = np.zeros((num_dets, num_trks))
            for i in range(num_dets):
                for j in range(num_trks):
                    if det_cates_left[i] != trk_cates_left[j]:
                        """
                            For some datasets, such as KITTI, there are different categories,
                            and we have to avoid associating them with each other.
                        """
                        cate_matrix[i][j] = -1e6
            iou_left = iou_left + cate_matrix
            if iou_left.max() > self.iou_threshold - 0.1:
                rematched_indices = linear_assignment(-iou_left)
                to_remove_det_indices = []
                to_remove_trk_indices = []
                for m in rematched_indices:
                    det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
                    if iou_left[m[0], m[1]] < self.iou_threshold - 0.1:
                        continue
                    self.trackers[trk_ind].update(dets[det_ind, :])
                    to_remove_det_indices.append(det_ind)
                    to_remove_trk_indices.append(trk_ind) 
                unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
                unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))

        for i in unmatched_dets:
            trk = KalmanBoxTracker(dets[i, :], delta_t=self.delta_t, args=self.args)
            trk.cate = cates[i]
            self.trackers.append(trk)
        i = len(self.trackers)

        for trk in reversed(self.trackers):
            if trk.last_observation.sum() > 0:
                d = trk.last_observation[:4]
            else:
                d = trk.get_state()[0]
            if (trk.time_since_update < 1):
                if (self.frame_count <= self.min_hits) or (trk.hit_streak >= self.min_hits):
                    # id+1 as MOT benchmark requires positive
                    ret.append(np.concatenate((d, [trk.id+1], [trk.cate], [0])).reshape(1,-1)) 
                if trk.hit_streak == self.min_hits:
                    # Head Padding (HP): recover the lost steps during initializing the track
                    for prev_i in range(self.min_hits - 1):
                        prev_observation = trk.history_observations[-(prev_i+2)]
                        ret.append((np.concatenate((prev_observation[:4], [trk.id+1], [trk.cate], 
                            [-(prev_i+1)]))).reshape(1,-1))
            i -= 1 
            if (trk.time_since_update > self.max_age):
                self.trackers.pop(i)
        
        if(len(ret)>0):
            return np.concatenate(ret)
        return np.empty((0, 7))



Example

tracker = Hybird_Sort()

# update the object tracker once per frame; `detections` is an (N, 5) array of
# [x1, y1, x2, y2, score] rows and `frame` is the current image
online_targets = tracker.update(detections, frame)
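
A fuller, self-contained sketch (the detection values and image size below are
made up for illustration; update() accepts the frame but does not inspect it in
this version, and the association helpers from .association must be importable):

import numpy as np

tracker = Hybird_Sort(max_age=30, min_hits=3, iou_threshold=0.3, asso_func="iou")

for frame_id in range(5):
    # dummy detections: one [x1, y1, x2, y2, score] row per object
    detections = np.array([[100. + 2 * frame_id, 100., 150. + 2 * frame_id, 200., 0.9],
                           [300., 120., 360., 260., 0.8]])
    frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # placeholder image
    online_targets = tracker.update(detections, frame)  # rows: [x1, y1, x2, y2, track_id]
    for x1, y1, x2, y2, track_id in online_targets:
        print(frame_id, int(track_id), x1, y1, x2, y2)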