yolov5+deepsort实现在跟踪时显示类别信息

问题:deepsort的输出未包含目标框的类别信息

解决办法: 通过修改deepsort源码实现类别的输出

1. deepsort的输入

# deepsort原有的输入
deepsort.update(xyhw, confidence, im0)
# deepsort修改后的输入
deepsort.update(xyhw, confidence, im0, labels)
# 其中labels是有目标检测阶段生成的预测标签, 题主使用是list,每个元素是one hot 格式的标签

2.修改deepsort的update函数

2.1 update函数的定义如下

    def update(self, bbox_xywh, confidences, ori_img):
        self.height, self.width = ori_img.shape[:2]
        # generate detections
        features = self._get_features(bbox_xywh, ori_img)
        bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
        detections = [Detection(bbox_tlwh[i], conf, features[i]) for i, conf in enumerate(
            confidences) if conf > self.min_confidence]

        # run on non-maximum supression
        boxes = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])

        # update tracker
        self.tracker.predict()
        self.tracker.update(detections)

        # output bbox identities
        outputs = []
        for track in self.tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            box = track.to_tlwh()
            x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
            track_id = track.track_id
            outputs.append(np.array([x1, y1, x2, y2, label,track_id], dtype=np.int))
        if len(outputs) > 0:
            outputs = np.stack(outputs, axis=0)
        return outputs

2.1 修改update函数

# 修改后的update函数
    def update(self, bbox_xywh, confidences, ori_img,labels): # 修改处,新增了labels输入
        self.height, self.width = ori_img.shape[:2]
        # generate detections
        features = self._get_features(bbox_xywh, ori_img)
        bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
        detections = [Detection(bbox_tlwh[i], conf, features[i],labels[i]) for i, conf in enumerate(
            confidences) if conf > self.min_confidence] #修改处,对于detections新增了相应目标的label

        # run on non-maximum supression
        boxes = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])

        # update tracker
        self.tracker.predict()
        self.tracker.update(detections)

        # output bbox identities
        outputs = []
        for track in self.tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            box = track.to_tlwh()
            x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
            track_id = track.track_id
            label = track.label# 新增此处,通过track.label取到track的label
            outputs.append(np.array([x1, y1, x2, y2, label,track_id], dtype=np.int)) # 修改此处,使得outputs中包含了label
        if len(outputs) > 0:
            outputs = np.stack(outputs, axis=0)
        return outputs

由于修改了Detection的输入因此也需要对Detection类进行相应修改

2.4 修改Detection类

# detections修改前
class Detection(object):
	    def __init__(self, tlwh, confidence, feature):
	        self.tlwh = np.asarray(tlwh, dtype=np.float)
	        self.confidence = float(confidence)
	        self.feature = np.asarray(feature, dtype=np.float32)
# 修改后
class Detection(object):
	    def __init__(self, tlwh, confidence, feature,label): # 新增label
	        self.tlwh = np.asarray(tlwh, dtype=np.float)
	        self.confidence = float(confidence)
	        self.feature = np.asarray(feature, dtype=np.float32)
	        self.label = label # 新增此行

2.4 修改Tracker类

class Tracker:
	def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3,label = None): # 新增label = None
		……
		self.label = label # 新增此行
		……
	def update(self, detections):
		        for track_idx, detection_idx in matches:
		            self.tracks[track_idx].update(self.kf, detections[detection_idx])
		            self.tracks[track_idx].label = detections[detection_idx].label # 新增此行

3. 修改思路

~~阅读update函数可以知道, 其输出组成由下列代码决定

x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
track_id = track.track_id
outputs.append(np.array([x1, y1, x2, y2, track_id], dtype=np.int))

因此输出类别应该也由类似track.label输出
回溯代码,track是tracker.track的一个元素,而tracker.track是类别Tracker下的一个属性

class Tracker:
	def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3,label):
	……
	self.tracks = []
	……
  • 10
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值