YOLOv8 + SORT Tracking: Implementation Code

YOLOv8 is released by Ultralytics and is well engineered. It already covers object tracking, image classification, instance segmentation and more, and has grown into a general algorithm framework that integrates many other models, such as the YOLO v3-v9 series and RT-DETR.

The native ultralytics framework ships with ByteTrack and BoT-SORT (without ReID). My project needed a comparison against SORT, so I adapted YOLOv8 to work with the SORT tracker; the code is given here. It still needs polish: at the moment only single-video tracking is implemented, and image-sequence and multi-video tracking are left for later (a batch-processing sketch is given after the main script below).

Environment setup

First, download the ultralytics framework and the SORT source code:

# Gitee mirrors, usable from mainland China
git clone https://gitee.com/monkeycc/ultralytics.git
git clone https://gitee.com/funweb/sort.git

After downloading, install ultralytics following its official tutorial (for example `pip install ultralytics`, or `pip install -e .` from inside the cloned repository). SORT itself needs no installation; its dependency libraries just have to be available. SORT also requires no training, but the detector does, so you need to prepare your own dataset and train it with ultralytics.
The training code is as follows:

from ultralytics import YOLO, RTDETR   # RT-DETR can be trained the same way

model = YOLO('yolov8s.yaml')   # build a model from the YAML config (replaced by the line below)
model = YOLO('yolov8s.pt')     # load pretrained weights as the starting point
model.info()
results = model.train(data='coco8.yaml', epochs=100, imgsz=640)
results = model('bus.jpg')     # quick inference check on a sample image
The main tracking script (referred to below as track.py; the name is my choice) is as follows:

import os
import glob
import numpy as np
import cv2
import argparse
import time
from sort import Sort
from ultralytics import YOLO,RTDETR
# helper functions defined separately in utils.py (listed at the end of this post)
from utils import convert, check_source, draw_trackers, write_result




class Tracker_detector():
    def __init__(self, weight, tracker=Sort, parser=None):
        self.model = parser.model
        if self.model == 'yolo':
            self.detector = YOLO(weight)
        if self.model == 'rtdetr':
            self.detector = RTDETR(weight)

        self.config = parser
        self.tracker = tracker(max_age=self.config.max_age,
                               min_hits=self.config.min_hits,
                               iou_threshold=self.config.iou_threshold)
        print('Detector and tracker have been initialised')
        
    
    def get_dets(self, source):
        '''
        Run the detector and convert its results to a numpy.ndarray
        with format [x1, y1, x2, y2, score].
        '''
        results = self.detector(source)
        results = results[0] if isinstance(results, list) else results
        boxes = results.boxes.data.cpu().numpy()   # rows of [x1, y1, x2, y2, conf, class]
        dets = convert(boxes)
        return dets
    

    def get_id(self, dets):
        '''
        Update the tracker with the current detections and return the active tracks.
        '''
        trackers = self.tracker.update(dets)
        return trackers


    def show_result(self, source):
        # Note: show_result duplicates most of __call__ and should be refactored later.
        source_name = os.path.basename(source).split('.')[0]
        save_name = self.model + '-' + self.config.save_path.lower() + '-' + source_name + '.txt'
        save_file = os.path.join('result_output', self.model, save_name)
        os.makedirs(os.path.join('result_output', self.model), exist_ok=True)
        if self.config.source_type == 'video':
            frame_count = 0
            for frame in self._read_frame(source):
                if frame is not None:
                    start_time = time.time()
                    frame_count += 1
                    results = self.detector(frame)
                    results = results[0] if isinstance(results, list) else results
                    boxes = results.boxes.data.cpu().numpy()   # rows of [x1, y1, x2, y2, conf, class]
                    dets = convert(boxes)
                    trackers = self.tracker.update(dets)
                    end_time = time.time()
                    fps = 1 / (end_time - start_time)
                    fps = fps * self.config.sampling_rate   # account for skipped frames
                    frame = draw_trackers(frame, trackers, fps)
                    cv2.imshow('test', frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        cv2.destroyAllWindows()
                        break
                    if self.config.save_path:
                        write_result(trackers, frame_count, save_file)


    def __call__(self, source):
        # build the output directory and result file name
        source_name = os.path.basename(source).split('.')[0]
        save_name = self.model + '-' + self.config.save_path.lower() + '-' + source_name + '.txt'
        os.makedirs(os.path.join('result_output', self.model), exist_ok=True)
        save_file = os.path.join('result_output', self.model, save_name)
        if self.config.source_type == 'video':
            frame_count = 0
            for frame in self._read_frame(source):
                if frame is not None:
                    start_time = time.time()
                    frame_count += 1
                    results = self.detector(frame)
                    results = results[0] if isinstance(results, list) else results
                    boxes = results.boxes.data.cpu().numpy()   # rows of [x1, y1, x2, y2, conf, class]
                    dets = convert(boxes)
                    trackers = self.tracker.update(dets)
                    if self.config.display:
                        end_time = time.time()
                        fps = 1 / (end_time - start_time)
                        fps = fps * self.config.sampling_rate   # account for skipped frames
                        frame = draw_trackers(frame, trackers, fps)
                        cv2.imshow('test', frame)
                        if cv2.waitKey(1) & 0xFF == ord('q'):
                            cv2.destroyAllWindows()
                            break
                    if self.config.save_path:
                        write_result(trackers, frame_count, save_file)
        
    
    def _read_frame(self, source):
        # Read a video frame by frame; sampling_rate controls whether frames are skipped.
        self.cap = cv2.VideoCapture(source)
        counting = 0
        while True:
            ret, frame = self.cap.read()
            if not ret:
                break
            counting += 1
            if counting % self.config.sampling_rate == 0:
                yield frame

        self.cap.release()
        yield None   # sentinel after release; the caller skips it via the `is not None` check
    


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='SORT demo')
    parser.add_argument('--model', help='Detector', type=str, default='rtdetr',
                        choices=['yolo', 'rtdetr'])
    parser.add_argument('--weight', help='Path to the detector weights',
                        default='ultralytics/runs/detect/train2/weights/best.pt')
    parser.add_argument('--display', dest='display', action='store_true',
                        help='Display online tracker output (slow) [False]')
    parser.add_argument('--source_type', help='Source type: image or video', default='video')
    parser.add_argument('--source', help='Path to MOT videos/images.', type=str,
                        default='mot/videos')
    parser.add_argument('--save_path', help='Tag used in the result file name; results are saved when non-empty.',
                        type=str, default='Sort')
    parser.add_argument('--sampling_rate', help='Frequency of reading frames', type=int,
                        default=1, choices=[1, 2, 3])
    parser.add_argument('--max_age',
                        help='Maximum number of frames to keep a track alive without associated detections.',
                        type=int, default=1)
    parser.add_argument('--min_hits',
                        help='Minimum number of associated detections before a track is initialised.',
                        type=int, default=3)
    parser.add_argument('--iou_threshold', help='Minimum IoU for a match.', type=float, default=0.3)
    args = parser.parse_args()
    return args
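
A typical invocation might look like this (assuming the script above is saved as track.py; adjust the weight and source paths to your own setup):

python track.py --model yolo --weight runs/detect/train2/weights/best.pt --source mot/videos --display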






if __name__ == '__main__':
    parser = parse_args()
    paths = check_source(parser.source, parser.source_type)
    TBD = Tracker_detector(parser.weight, parser=parser)
    TBD(paths[14])   # track a single, hard-coded video from the directory for now
    del TBD
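
Multi-video tracking is listed above as future work. As a minimal sketch (my own extension, not part of the original code), the existing class can already batch-process every video returned by check_source:

if __name__ == '__main__':
    parser = parse_args()
    paths = check_source(parser.source, parser.source_type)
    for path in paths:
        # Re-create the wrapper per video so SORT's internal state
        # (track IDs, Kalman filters) does not leak across sequences.
        tbd = Tracker_detector(parser.weight, parser=parser)
        tbd(path)
        del tbd

Re-initialising per video keeps track IDs independent between result files, at the cost of reloading the detector each time; caching the detector and re-creating only the Sort instance would avoid that reload.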

utils.py implements the helper functions: box drawing, result saving (in the MOT15 format), and so on. The code is as follows:

import cv2
import os
import numpy as np
import random

def generate_color():
    return (random.randint(0, 255),random.randint(0, 255),random.randint(0, 255))

colors = {}
for i in range(1000):
    colors[i] = generate_color()


def convert(ori_dets):
    '''
    Convert a numpy.ndarray with format [x1, y1, x2, y2, score, class]
        to a numpy.ndarray with format [x1, y1, x2, y2, score].

    Parameters:
    ori_dets (numpy.ndarray): Original detections array with class information.

    Returns:
    dest_dets (numpy.ndarray): Detections array without the class column.
    '''
    ori_dets = np.array(ori_dets)
    ori_dets = np.expand_dims(ori_dets, 0) if ori_dets.ndim == 1 else ori_dets
    if ori_dets.shape[1] != 6:
        raise ValueError("ori_dets must have 6 columns (x1, y1, x2, y2, score, class)")

    dest_dets = ori_dets[:, :5]
    return dest_dets
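
A quick sanity check of convert with a made-up detection row (values are illustrative):

import numpy as np
from utils import convert

det = np.array([10., 20., 110., 220., 0.9, 0.])   # one detection: [x1, y1, x2, y2, score, class]
print(convert(det))   # [[ 10.  20. 110. 220.   0.9]] -- class column dropped, 2-D shape kept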


def check_source(path, type='video'):
    '''
    path: file path or directory
    type: 'video' or 'image'
    Returns a list of matching file paths.
    '''
    if os.path.isdir(path):
        files = os.listdir(path)
        if type == 'video':
            files = [os.path.join(path, i) for i in files if i.lower().endswith(('.avi', '.mp4'))]
            print(f'{path} is a video directory')
            return files
        if type == 'image':
            files = [os.path.join(path, i) for i in files if i.lower().endswith(('.jpg', '.png'))]
            print(f'{path} is an image directory')
            return files
    if os.path.isfile(path):
        if type == 'video':
            if path.lower().endswith(('.avi', '.mp4')):
                return [path]   # wrapped in a list so callers can always iterate
            print('Video file format is not supported!')
        if type == 'image':
            if path.lower().endswith(('.jpg', '.png')):
                return [path]
            print('Image file format is not supported!')



def draw_trackers(frame, results, fps=None):
    '''
    Parameters:
    frame (numpy.ndarray, dtype uint8): A frame of the source video.
    results (numpy.ndarray): Output of a tracker update, rows of [x1, y1, x2, y2, id].
    fps (int or float): Optional frame rate to overlay.
    Returns:
    frame (numpy.ndarray, dtype uint8): The frame with bboxes and IDs drawn on it.
    '''
    for d in results:
        d = d.astype(np.int32)
        t_size = cv2.getTextSize(str(d[4]), cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        # modulo keeps track IDs >= 1000 inside the pre-generated colour table
        cv2.rectangle(frame, (d[0], d[1]), (d[2], d[3]), colors[d[4] % 1000], thickness=4)
        cv2.putText(frame, str(d[4]), (d[0], d[1] + t_size[1] + 4), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 255, 255), 2)
    if fps is not None:
        # draw the frame rate once per frame, not once per box
        cv2.putText(frame, "fps= %.2f" % fps, (0, 80), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    return frame




def write_result(results, frame_num, filename, data_type='mot'):
    '''
    Append the tracking results of one frame to a txt file.
    '''
    if data_type == 'mot':
        save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
    elif data_type == 'kitti':
        save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
    else:
        raise ValueError(data_type)

    if data_type == 'kitti':
        frame_num -= 1   # KITTI frame indices are 0-based
    results = np.expand_dims(results, 0) if results.ndim == 1 else results
    with open(filename, 'a') as f:
        for d in results:
            x1, y1, x2, y2, id = d[0], d[1], d[2], d[3], d[4]
            line = save_format.format(frame=frame_num, id=id,
                                      x1=x1, y1=y1, x2=x2, y2=y2,
                                      w=x2 - x1, h=y2 - y1)
            f.write(line)
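
As a worked example of the format string, a single track row [x1, y1, x2, y2, id] = [100, 150, 180, 260, 7] written at frame 3 with data_type='mot' produces the MOT15 line

3,7,100,150,80,110,1,-1,-1,-1

i.e. frame, id, bb_left, bb_top, bb_width, bb_height, a fixed confidence of 1, and the unused 3D placeholders.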