YOLOv10 with StrongSORT+OSNet for Object Tracking and Re-Identification [Code Included]


Preface

This post shows how to combine YOLOv10 with StrongSORT+OSNet for object tracking, re-identification, and trajectory drawing, and explains the role of every argument and function in the code.


Feature Overview

1. Select the path of the video to track
2. Load the YOLOv10 model and run object detection
3. Track and re-identify targets with StrongSORT+OSNet
4. Draw the tracking trajectories from the tracking results


Required Environment

  1. Set up the YOLOv10 environment; see the earlier post for details:
    Link: https://blog.csdn.net/Dora_blank/article/details/139302363?spm=1001.2014.3001.5502
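
After setting up the environment, a quick sanity check (a minimal sketch, assuming the YOLOv10 build of ultralytics from the linked post is installed) confirms that the imports work and whether CUDA will be used:

import torch
from ultralytics import YOLOv10  # provided by the YOLOv10 build of ultralytics

print(torch.__version__, torch.cuda.is_available())  # True means the GPU branch below will be taken
model = YOLOv10("yolov10n.pt")  # the .pt weights file must be available locally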

I. Code Structure

1. Argument Definitions

parser = argparse.ArgumentParser()
# Detection arguments
parser.add_argument('--weights', default=r"yolov10n.pt", type=str, help='weights path')
parser.add_argument('--source', default=r"video_1.mp4", type=str, help='video(.mp4)path')
parser.add_argument('--save', default=r"./save", type=str, help='save img or video path')
parser.add_argument('--vis', default=True, action='store_true', help='visualize image')
parser.add_argument('--conf_thre', type=float, default=0.5, help='conf_thre')
parser.add_argument('--iou_thre', type=float, default=0.5, help='iou_thre')

# Tracking arguments
parser.add_argument('--track_model', default=r"./track_models/osnet_x0_25_msmt17.pt", type=str, help='track model')
parser.add_argument('--max_dist', type=float, default=0.2, help='max dist')
parser.add_argument('--max_iou_distance', type=float, default=0.7, help='max_iou_distance')
parser.add_argument('--max_age', type=int, default=30, help='max_age')
parser.add_argument('--n_init', type=int, default=3, help='n_init')
parser.add_argument('--nn_budget', type=int, default=100, help='nn_budget')
# Parse the arguments
opt = parser.parse_args()

The arguments are as follows:
Object detection:
--weights: path to the YOLOv10 weights
--source: path to the input video
--save: directory where results are saved
--vis: visualize the tracking results
--conf_thre: confidence threshold
--iou_thre: IoU threshold

Object tracking:
--track_model: path to the OSNet re-identification weights
--max_dist: maximum appearance-feature distance for associating a detection with a track; a pair is considered a match only when the distance is below this threshold
--max_iou_distance: maximum IoU-based matching distance (1 - IoU); a detection and a track can be matched only when this distance is below the threshold
--max_age: maximum number of frames a track is kept without being matched; tracks older than this are deleted
--n_init: number of consecutive frames a new track must be matched before it is treated as a confirmed target
--nn_budget: maximum number of appearance feature vectors stored per track, used to compute appearance similarity
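
All of these defaults can be overridden. parse_args also accepts an explicit argument list, which is convenient for quick experiments (the values below are hypothetical):

# Hypothetical overrides, e.g. when driving the parser from a notebook or another script
opt = parser.parse_args(['--source', 'video_1.mp4', '--conf_thre', '0.4', '--max_age', '60'])
print(opt.conf_thre, opt.max_age)  # 0.4 60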

2. Device Selection

CUDA is used if a GPU is available, otherwise the code falls back to the CPU.

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

3. Defining the Detector Class

A Detector class encapsulates the detection and tracking logic and initializes the model parameters. It also keeps a deque of recent center points per track ID (self.trajectories, at most max_trajectory_length points) for drawing trails.

class Detector(object):
    def __init__(self, weight_path, conf_threshold=0.2, iou_threshold=0.5):
        self.device = device
        self.model = YOLOv10(weight_path).to(self.device)
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold

        self.tracker = StrongSORT(
            opt.track_model,
            self.device,
            max_dist=opt.max_dist,
            max_iou_distance=opt.max_iou_distance,
            max_age=opt.max_age,
            n_init=opt.n_init,
            nn_budget=opt.nn_budget,
        )

        self.trajectories = {}
        self.max_trajectory_length = 5

4. Getting the Box Color

When drawing boxes and trajectories we want visually distinct colors. This helper maps an integer index to a fixed color; in detect_image the index is derived from the class ID (class_pred + 4), so objects of the same class share a color.

@staticmethod
def get_color(idx):
    idx = idx * 3
    color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
    return color
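
The color is fully determined by the index, so the same class is drawn in the same color in every frame. A quick worked example for class 1, using the index detect_image passes in:

# get_color(1 + 4): idx = 5 * 3 = 15, so the color is
# ((37 * 15) % 255, (17 * 15) % 255, (29 * 15) % 255) = (45, 0, 180)
print(Detector.get_color(5))  # (45, 0, 180)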

5. Detection and Tracking

This method tracks the objects detected in a frame and draws the tracked boxes and their trajectories.

def detect_image(self, img_bgr):
    image = img_bgr.copy()
    results = self.model(img_bgr, conf=self.conf_threshold, iou=self.iou_threshold)
    boxes = results[0].boxes.xyxy.cpu().numpy()  # xyxy format
    confidences = results[0].boxes.conf
    class_preds = results[0].boxes.cls

    confidences_expanded = confidences.unsqueeze(1)
    class_preds_expanded = class_preds.unsqueeze(1)
    boxes_tensor = torch.from_numpy(boxes).to(class_preds_expanded.device)
    xywhs = xyxy2xywh(boxes_tensor)
    online_targets = self.tracker.update(xywhs.cpu(), confidences_expanded.cpu(), class_preds_expanded.cpu(), image)

    for t in online_targets:
        tlwh = [t[0], t[1], t[2] - t[0], t[3] - t[1]]
        tid = t[4]
        cls = t[5]
        xmin, ymin, xmax, ymax = tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3]
        class_pred = int(cls)
        color = self.get_color(class_pred + 4)
        center = (int(xmin+xmax)//2, int(ymin + ymax)//2)
        bbox_label = results[0].names[class_pred]
        cv2.rectangle(img_bgr, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)

        if tid not in self.trajectories:
            self.trajectories[tid] = deque(maxlen=self.max_trajectory_length)
        self.trajectories[tid].appendleft(center)

        # No manual truncation is needed here: deque(maxlen=...) already caps the trajectory length

        for i in range(1, len(self.trajectories[tid])):

            if self.trajectories[tid][i - 1] is None or self.trajectories[tid][i] is None:
                continue

            thickness = int(np.sqrt(64 / float(i + 1)))
            cv2.line(img_bgr, self.trajectories[tid][i-1],
                     self.trajectories[tid][i], color,
                     thickness)

        # Draw the class name and track ID
        cv2.putText(img_bgr, f"{bbox_label} {int(tid)}",
                    (int(xmin), int(ymin-5)), cv2.FONT_HERSHEY_COMPLEX, 0.6, color, 2)

    return img_bgr
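
Because a new track is only confirmed after n_init consecutive matches, the simplest way to see IDs and trails appear is to feed detect_image a few consecutive frames. A minimal sketch, with clip.mp4 as a hypothetical short video:

detector = Detector(weight_path="yolov10n.pt")  # default conf_threshold=0.2, iou_threshold=0.5
cap = cv2.VideoCapture("clip.mp4")              # hypothetical short clip
vis = None
for _ in range(10):                             # a handful of frames is enough for tracks to confirm
    ok, frame = cap.read()
    if not ok:
        break
    vis = detector.detect_image(frame)          # draws boxes, IDs and trails on the frame
cap.release()
if vis is not None:
    cv2.imwrite("clip_tracked.jpg", vis)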

6. Trajectory Drawing in Detail

This loop walks over the stored center points of each track and connects consecutive points with OpenCV's cv2.line. The line is drawn thinner for older segments, and because each trajectory is a deque with maxlen=self.max_trajectory_length, only the most recent points are kept, which keeps the trails short and easy to read.

for j in range(1, len(self.trajectories[tid])):
    if self.trajectories[tid][j - 1] is None or self.trajectories[tid][j] is None:
        continue

    thickness = int(np.sqrt(64 / float(j + 1)))
    cv2.line(img_bgr, self.trajectories[tid][j - 1],
             self.trajectories[tid][j], color,
             thickness)
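
Concretely, the thickness shrinks as the segment index grows, so the newest segment is drawn thickest. A quick check of the formula:

import numpy as np

for j in range(1, 5):  # with max_trajectory_length = 5, j runs from 1 to 4
    print(j, int(np.sqrt(64 / float(j + 1))))
# 1 5
# 2 4
# 3 4
# 4 3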

7. Box Format Conversion in Detail

This block converts the detection results into the format StrongSORT expects.

1. Add a dimension to the confidence and class-prediction tensors

confidences_expanded = confidences.unsqueeze(1)
class_preds_expanded = class_preds.unsqueeze(1)

2. Convert the boxes to a tensor and move it to the same device as the predictions

boxes_tensor = torch.from_numpy(boxes).to(class_preds_expanded.device)

3. Convert the box format to center x, center y, width, height

xywhs = xyxy2xywh(boxes_tensor)

4. Call StrongSORT's update function to get the tracked boxes

online_targets = self.tracker.update(xywhs.cpu(), confidences_expanded.cpu(), class_preds_expanded.cpu(), image)
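
As a quick numeric check of the conversion, using the xyxy2xywh helper defined in the complete code below:

import torch

boxes_xyxy = torch.tensor([[100., 50., 200., 150.]])  # x1, y1, x2, y2
print(xyxy2xywh(boxes_xyxy))                           # tensor([[150., 100., 100., 100.]])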

II. Complete Code

The complete code is as follows:

import cv2
import torch
from ultralytics import YOLOv10
import os
import argparse
from strong_sort.strong_sort import StrongSORT
from collections import deque
import numpy as np


parser = argparse.ArgumentParser()
# Detection arguments
parser.add_argument('--weights', default=r"yolov10n.pt", type=str, help='weights path')
parser.add_argument('--source', default=r"video_1.mp4", type=str, help='video(.mp4)path')
parser.add_argument('--save', default=r"./save", type=str, help='save img or video path')
parser.add_argument('--vis', default=True, action='store_true', help='visualize image')
parser.add_argument('--conf_thre', type=float, default=0.5, help='conf_thre')
parser.add_argument('--iou_thre', type=float, default=0.5, help='iou_thre')

# Tracking arguments
parser.add_argument('--track_model', default=r"./track_models/osnet_x0_25_msmt17.pt", type=str, help='track model')
parser.add_argument('--max_dist', type=float, default=0.2, help='max dist')
parser.add_argument('--max_iou_distance', type=float, default=0.7, help='max_iou_distance')
parser.add_argument('--max_age', type=int, default=30, help='max_age')
parser.add_argument('--n_init', type=int, default=3, help='n_init')
parser.add_argument('--nn_budget', type=int, default=100, help='nn_budget')
# Parse the arguments
opt = parser.parse_args()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y

class Detector(object):
    def __init__(self, weight_path, conf_threshold=0.2, iou_threshold=0.5):
        self.device = device
        self.model = YOLOv10(weight_path).to(self.device)
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold

        self.tracker = StrongSORT(
            opt.track_model,
            self.device,
            max_dist=opt.max_dist,
            max_iou_distance=opt.max_iou_distance,
            max_age=opt.max_age,
            n_init=opt.n_init,
            nn_budget=opt.nn_budget,
        )

        self.trajectories = {}
        self.max_trajectory_length = 5

    @staticmethod
    def get_color(idx):
        idx = idx * 3
        color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
        return color

    def detect_image(self, img_bgr):
        image = img_bgr.copy()
        results = self.model(img_bgr, conf=self.conf_threshold, iou=self.iou_threshold)
        boxes = results[0].boxes.xyxy.cpu().numpy()  # xyxy format
        confidences = results[0].boxes.conf
        class_preds = results[0].boxes.cls

        confidences_expanded = confidences.unsqueeze(1)
        class_preds_expanded = class_preds.unsqueeze(1)
        boxes_tensor = torch.from_numpy(boxes).to(class_preds_expanded.device)
        xywhs = xyxy2xywh(boxes_tensor)
        online_targets = self.tracker.update(xywhs.cpu(), confidences_expanded.cpu(), class_preds_expanded.cpu(), image)

        for t in online_targets:
            tlwh = [t[0], t[1], t[2] - t[0], t[3] - t[1]]
            tid = t[4]
            cls = t[5]
            xmin, ymin, xmax, ymax = tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3]
            class_pred = int(cls)
            color = self.get_color(class_pred + 4)
            center = (int(xmin+xmax)//2, int(ymin + ymax)//2)
            bbox_label = results[0].names[class_pred]
            cv2.rectangle(img_bgr, (int(xmin), int(ymin)), (int(xmax), int(ymax)), color, 2)

            if tid not in self.trajectories:
                self.trajectories[tid] = deque(maxlen=self.max_trajectory_length)
            self.trajectories[tid].appendleft(center)

            # No manual truncation is needed here: deque(maxlen=...) already caps the trajectory length

            for i in range(1, len(self.trajectories[tid])):

                if self.trajectories[tid][i - 1] is None or self.trajectories[tid][i] is None:
                    continue

                thickness = int(np.sqrt(64 / float(i + 1)))
                cv2.line(img_bgr, self.trajectories[tid][i-1],
                         self.trajectories[tid][i], color,
                         thickness)

            # Draw the class name and track ID
            cv2.putText(img_bgr, f"{bbox_label} {int(tid)}",
                        (int(xmin), int(ymin-5)), cv2.FONT_HERSHEY_COMPLEX, 0.6, color, 2)

        return img_bgr


# Example usage
if __name__ == '__main__':
    model = Detector(weight_path=opt.weights, conf_threshold=opt.conf_thre, iou_threshold=opt.iou_thre)

    capture = cv2.VideoCapture(opt.source)
    fps = capture.get(cv2.CAP_PROP_FPS)
    size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
    os.makedirs(opt.save, exist_ok=True)  # make sure the save directory exists before writing
    outVideo = cv2.VideoWriter(os.path.join(opt.save, os.path.basename(opt.source).split('.')[-2] + "_out.mp4"), fourcc,
                               fps, size)

    while True:
        ret, frame = capture.read()
        if not ret:
            break
        img_vis = model.detect_image(frame)
        outVideo.write(img_vis)

        img_vis = cv2.resize(img_vis, None, fx=1, fy=1, interpolation=cv2.INTER_NEAREST)
        cv2.imshow('track', img_vis)
        cv2.waitKey(30)

    capture.release()
    outVideo.release()
    cv2.destroyAllWindows()
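
With everything in place, the script can be run from the command line, for example python track.py --weights yolov10n.pt --source video_1.mp4 --save ./save (track.py being whatever filename you saved the script under). The annotated video is written to the save directory and also shown in a window while processing.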

III. Results

StrongSORT tracking results



Summary

That's all for this post. If you found it helpful, feel free to follow. Thanks!

I've also been posting object-detection videos on Bilibili recently; check them out if interested: https://b23.tv/1upjbcG

Study and discussion QQ group: 995760755
