YOLOv8 + Sort跟踪算法实现代码
YOLOv8代码由Ultralytics公司发布,在工程化方面做得不错。现在已经实现了目标跟踪、图像分类、实例分割等任务,成为了一个算法框架,并且集成了很多其他的一些算法,如YOLO系列的v3-v9系列,RT-DETR等算法。
ultralytics原生框架实现了bytetrack和botsort-without-Reid,因项目中需要与Sort算法进行比较。因此实现了YOLOv8与Sort算法的适配,这里给出实现了的代码。本文给出的代码还需要进一步完善,目前仅实现了视频的跟踪,后期还需要实现图片序列和多个视频的跟踪。
环境配置
首先需要下载ultralytics框架和Sort算法的代码
下载ultralytics代码
# 国内可用这个
git clone https://gitee.com/monkeycc/ultralytics.git # 国内可用这个
git clone https://gitee.com/funweb/sort.git # 国内可用这个
下载好代码后,按照ultralytics教程进行安装。Sort不需要安装,只需要满足依赖库即可。Sort不需要额外训练,检测器需要训练。因此,需要配置自己的数据集,在ultralytics中训练。
训练代码如下:
from ultralytics import YOLO,RTDETR # RT-DETR can be trained the same way
model = YOLO('yolov8s.yaml')  # build a model from the architecture YAML
model = YOLO('yolov8s.pt')  # NOTE(review): overwrites the line above — keep one of the two
model.info()  # print a model summary
results = model.train(data='coco8.yaml', epochs=100, imgsz=640)  # train the detector
results = model('bus.jpg')  # quick inference sanity check
import os
import glob
import numpy as np
import cv2
import argparse
import time
from sort import Sort
from ultralytics import YOLO,RTDETR
# Additional helper functions defined separately (see utils.py below)
from utils import convert, check_source, draw_trackers, write_result
class Tracker_detector():
    '''Run an Ultralytics detector (YOLO or RT-DETR) over a video source
    frame by frame and feed its detections to the SORT tracker.

    Parameters:
        weight (str): path to the detector weight file.
        tracker: kept for interface compatibility; SORT is always used.
        parser (argparse.Namespace): runtime options (model, max_age,
            min_hits, iou_threshold, sampling_rate, save_path, ...).
    '''
    def __init__(self, weight, tracker=Sort, parser=None):
        self.model = parser.model
        if self.model == 'yolo':
            self.detector = YOLO(weight)
        if self.model == 'rtdetr':
            self.detector = RTDETR(weight)
        self.config = parser
        # NOTE(review): the `tracker` argument is ignored; Sort is hard-wired.
        self.tracker = Sort(max_age=self.config.max_age,
                            min_hits=self.config.min_hits,
                            iou_threshold=self.config.iou_threshold)
        # plain string: was an f-string with no placeholders
        print(' Detector and Tracker have inited')

    def get_dets(self, source):
        '''
        Get detection results and convert them to a numpy.ndarray
        with format [x1,y1,x2,y2,score].
        '''
        results = self.detector(source)
        results = results[0] if isinstance(results, list) else results
        dets = convert(results)
        return dets

    def get_id(self, dets):
        '''
        Get tracker results for one frame of detections.
        '''
        trackers = self.tracker.update(dets)
        return trackers

    def _result_file(self, source):
        # Build "result_output/<model>/<model>-<save_path>-<source>.txt"
        # and make sure the output directory exists.
        source_name = os.path.basename(source).split('.')[0]
        save_name = self.model + '-' + self.config.save_path.lower() + '-' + source_name + '.txt'
        out_dir = os.path.join('result_output', self.model)
        os.makedirs(out_dir, exist_ok=True)
        return os.path.join(out_dir, save_name)

    def _track_video(self, source, display, save_file):
        # Shared per-frame pipeline (previously duplicated in show_result
        # and __call__): detect -> strip class column -> SORT update,
        # then optionally display and append results in MOT format.
        frame_count = 0
        for frame in self._read_frame(source):
            if frame is None:
                continue
            frame_count += 1
            results = self.detector(frame)
            results = results[0] if isinstance(results, list) else results
            # rows of boxes: [x1, y1, x2, y2, conf, class]
            boxes = results.boxes.data.cpu().numpy()
            dets = convert(boxes)
            trackers = self.tracker.update(dets)
            if display:
                frame = draw_trackers(frame, trackers)
                cv2.imshow('test', frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    # 'q' aborts before saving the current frame's results,
                    # matching the original control flow.
                    cv2.destroyAllWindows()
                    break
            if self.config.save_path:
                write_result(trackers, frame_count, save_file)

    def show_result(self, source):
        # Track `source` with the preview window always shown.
        save_file = self._result_file(source)
        if self.config.source_type == 'video':
            self._track_video(source, True, save_file)

    def __call__(self, source):
        # Track `source`; the preview window is shown only when --display is set.
        save_file = self._result_file(source)
        if self.config.source_type == 'video':
            self._track_video(source, self.config.display, save_file)

    def _read_frame(self, source):
        # Read the video, yielding every `sampling_rate`-th frame; a final
        # None is yielded after release so consumers can detect exhaustion.
        self.cap = cv2.VideoCapture(source)
        counting = 0
        while True:
            ret, frame = self.cap.read()
            if not ret:
                break
            counting += 1
            if counting % self.config.sampling_rate == 0:
                yield frame
        self.cap.release()
        yield None
def parse_args():
    """Parse command-line arguments for the SORT tracking demo."""
    parser = argparse.ArgumentParser(description='SORT demo')
    parser.add_argument('--model', help='Detector', type=str, default='rtdetr',
                        choices=['yolo', 'rtdetr'])
    parser.add_argument('--weight', help='Path for detector weight',
                        default='ultralytics/runs/detect/train2/weights/best.pt')
    # Bug fix: without action='store_true' the old flag required a value and
    # any passed value (even "False") was a truthy non-empty string.
    parser.add_argument('--display', dest='display', action='store_true', default=False,
                        help='Display online tracker output (slow) [False]')
    parser.add_argument('--source_type', help='source type, image or video', default='video')
    parser.add_argument("--source", help="Path for mot video,images.", type=str,
                        default='mot/videos')
    parser.add_argument("--save_path", help="Path to save results.", type=str, default='Sort')
    parser.add_argument('--sampling_rate', help='frequency of reading frames', type=int,
                        default=1, choices=[1, 2, 3])
    parser.add_argument("--max_age",
                        help="Maximum number of frames to keep alive a track without associated detections.",
                        type=int, default=1)
    parser.add_argument("--min_hits",
                        help="Minimum number of associated detections before track is initialised.",
                        type=int, default=3)
    parser.add_argument("--iou_threshold", help="Minimum IOU for match.", type=float, default=0.3)
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    args = parse_args()
    paths = check_source(args.source, args.source_type)
    tracker_detector = Tracker_detector(args.weight, parser=args)
    # NOTE(review): only the 15th entry of the source directory is tracked;
    # loop over `paths` once multi-video tracking is implemented.
    tracker_detector(paths[14])
    del tracker_detector
utils.py的代码主要实现一些附加功能,如边框绘制,结果保存(结果保存为MOT15格式)等。代码如下:
import cv2
import os
import numpy as np
import random
def generate_color():
    """Return a random BGR color as a 3-tuple of ints in [0, 255]."""
    return tuple(random.randint(0, 255) for _ in range(3))

# Pre-draw a fixed palette so each track id maps to a stable color.
colors = {idx: generate_color() for idx in range(1000)}
def convert(ori_dets):
    '''
    Convert detections with format [x1,y1,x2,y2,score,class] to
    [x1,y1,x2,y2,score] (the layout SORT expects).

    Parameters:
        ori_dets (array-like): detection rows with a trailing class column;
            a single 1-D detection row is also accepted.
    Returns:
        numpy.ndarray: shape (N, 5) array without the class column.
    Raises:
        ValueError: if the input does not have 6 columns.
    '''
    ori_dets = np.array(ori_dets)
    # Bug fix: np.expand_dims was called without the required `axis`
    # argument, raising TypeError on any 1-D input.
    ori_dets = np.expand_dims(ori_dets, axis=0) if ori_dets.ndim == 1 else ori_dets
    if ori_dets.shape[1] != 6:
        raise ValueError("ori_dets must have 6 columns (x1, y1, x2, y2, score, class)")
    dest_dets = ori_dets[:, :5]
    return dest_dets
def check_source(path, type='video'):
    '''
    Resolve `path` into usable media paths.

    Parameters:
        path (str): a file path or a directory path.
        type (str): 'video' or 'image'.
    Returns:
        list[str] of matching file paths when `path` is a directory,
        `path` itself when it is a file of the right format, otherwise None.
    '''
    # Bug fix: extensions must include the dot — 'mp4'/'png' would also
    # match any file name merely ending in those letters (e.g. 'xmp4').
    video_exts = ('.avi', '.mp4')
    image_exts = ('.jpg', '.png')
    if os.path.isdir(path):
        entries = os.listdir(path)
        if type == 'video':
            files = [os.path.join(path, f) for f in entries
                     if f.lower().endswith(video_exts)]
            print(f'{path} is a videos dir path')
            return files
        if type == 'image':
            files = [os.path.join(path, f) for f in entries
                     if f.lower().endswith(image_exts)]
            print(f'{path} is a images dir path')
            return files
    if os.path.isfile(path):
        if type == 'video':
            if path.lower().endswith(video_exts):
                return path
            print('Video file format is not right!')
        if type == 'image':
            if path.lower().endswith(image_exts):
                return path
            print('Image file format is not right!')
def draw_trackers(frame, results, fps=None):
    '''
    Draw tracker bounding boxes and track ids on a frame.

    Parameters:
        frame (numpy.ndarray, uint8): a frame of the source video.
        results (numpy.ndarray): rows [x1, y1, x2, y2, track_id] from a
            tracker update.
        fps (int or float, optional): when given, drawn in the top-left corner.
    Return:
        frame (numpy.ndarray, uint8): the frame with boxes and ids drawn on it.
    '''
    for d in results:
        d = d.astype(np.int32)
        # Bug fix: text size was measured on str(d[0]) although str(d[4])
        # is the string actually drawn.
        t_size = cv2.getTextSize(str(d[4]), cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        # Wrap the id so tracks beyond the 1000 pre-generated palette
        # entries reuse colors instead of raising KeyError.
        color = colors[d[4] % len(colors)]
        cv2.rectangle(frame, (d[0], d[1]), (d[2], d[3]), color, thickness=4)
        cv2.putText(frame, str(d[4]), (d[0], d[1] + t_size[1] + 4),
                    cv2.FONT_HERSHEY_COMPLEX, 1, (0, 255, 255), 2)
    if fps is not None:
        cv2.putText(frame, "fps= %.2f" % (fps), (0, 80),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    return frame
def write_result(results, frame_num, filename, data_type='mot'):
    '''
    Append tracker results for one frame to a text file.

    Parameters:
        results (numpy.ndarray): rows [x1, y1, x2, y2, id]; a single 1-D
            row is also accepted.
        frame_num (int): 1-based frame index.
        filename (str): output file, opened in append mode.
        data_type (str): 'mot' (MOT15 layout) or 'kitti'.
    Raises:
        ValueError: for an unknown data_type.
    '''
    if data_type == 'mot':
        save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
    elif data_type == 'kitti':
        save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
    else:
        raise ValueError(data_type)
    results = np.expand_dims(results, 0) if results.ndim == 1 else results
    # Bug fix: the original decremented an undefined name (frame_id) inside
    # the row loop for kitti output, raising NameError; KITTI frame indices
    # are 0-based, so shift once here.
    frame = frame_num - 1 if data_type == 'kitti' else frame_num
    with open(filename, 'a') as f:
        for d in results:
            x1, y1, x2, y2, track_id = d[0], d[1], d[2], d[3], d[4]
            # Bug fix: the kitti template needs x2/y2, which were never
            # passed to format() (KeyError). str.format ignores keyword
            # arguments a template does not use, so pass both layouts' fields.
            line = save_format.format(frame=frame, id=track_id,
                                      x1=x1, y1=y1, x2=x2, y2=y2,
                                      w=x2 - x1, h=y2 - y1)
            f.write(line)