前言
上次把yolov5和deepsort的跟踪算法搞定了,并加入了轨迹线显示,优化了轨迹线显示一段和目标消失时轨迹线不显示。但是他的检测版本有点老,我们实际训练和检测版本较高,所以这一节介绍如何更改为高版本6.1。
一、直接上手
把你的yolo拖到跟踪的路径下。我命名为nano。
之后替换为nano路径。关于time_synchronized函数,我这个6.1版本中换成了简写的time_sync,所以,ctrl+f查找替换为time_sync就行。
因为我们检测的视频是flv格式的,所以在datasets里面加入flv格式
至此大功告成了。
二、摄像头检测
上述改进测试本地视频是没有问题了,但是测试本地摄像头是有问题的,提示cap没有属性。更改一下datasets文件里面的loadsteams类。主要增加一个cap。总共应该有两处,下面中文注释处。
class LoadStreams:
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
self.mode = 'stream'
self.img_size = img_size
self.stride = stride
self.cap = [] #增加第一处
if os.path.isfile(sources):
with open(sources) as f:
sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
else:
sources = [sources]
n = len(sources)
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
self.sources = [clean_str(x) for x in sources] # clean source names for later
self.auto = auto
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
st = f'{i + 1}/{n}: {s}... '
if 'youtube.com/' in s or 'youtu.be/' in s: # if source is YouTube video
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
import pafy
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
cap = cv2.VideoCapture(s)
assert cap.isOpened(), f'{st}Failed to open {s}'
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
_, self.imgs[i] = cap.read() # guarantee first frame
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
self.threads[i].start()
self.cap.append(cap) #增加第二处
LOGGER.info('') # newline
# check for common shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect:
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
def update(self, i, cap, stream):
# Read stream `i` frames in daemon thread
n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
while cap.isOpened() and n < f:
n += 1
# _, self.imgs[index] = cap.read()
cap.grab()
if n % read == 0:
success, im = cap.retrieve()
if success:
self.imgs[i] = im
else:
LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
self.imgs[i] = np.zeros_like(self.imgs[i])
cap.open(stream) # re-open stream if signal was lost
time.sleep(1 / self.fps[i]) # wait time
def __iter__(self):
self.count = -1
return self
def __next__(self):
self.count += 1
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
cv2.destroyAllWindows()
raise StopIteration
# Letterbox
img0 = self.imgs.copy()
img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
# Stack
img = np.stack(img, 0)
# Convert
img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
img = np.ascontiguousarray(img)
return self.sources, img, img0, self.cap, ''
def __len__(self):
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
此外主检测文件也要改。
# for frame_idx, (path, img, im0s, vid_cap) in enumerate(dataset):
for frame_idx, (path, img, im0s, vid_cap, _) in enumerate(dataset):
效果如下
三、输出介绍
目前的输出文件会在推理的路径下生成。
里面的results.txt文件内容是这样的,是MOT格式的。
(frame_idx, identity, bbox_left, bbox_top, bbox_w, bbox_h,-1, -1, -1, -1)
MOT(Multiple Object Tracking)格式是一种用于保存多目标跟踪结果的文件格式,通常用于比赛、评估和研究。它的基本结构是在每一行中写入有关单个目标的信息,包括帧号、目标ID、边界框的坐标和其他属性。
含义介绍:
帧号:当前帧的索引,也就是图像序列中的第几帧。
目标ID:一个唯一的标识符,用于区分不同的目标。
边界框左上角x坐标:目标边界框的左上角x坐标。
边界框左上角y坐标:目标边界框的左上角y坐标。(用这两个主要是为了统一老算法的输出,用于目标定位于追踪)
边界框宽度:目标边界框的宽度。
边界框高度:目标边界框的高度。(可能还有其他字段,这取决于特定的MOT格式和需求)
问题
按下q键无法释放摄像头,无法结束程序,就修改按键部分如下。这部分只能结束小循环,关闭摄像头,但是大的循环还是无法结束。有没有大佬知道怎么结束程序呢。我在大循环中也加了跳出循环,也无法结束进程。
# Stream results
key = cv2.waitKey(1)
if view_img:
cv2.imshow(p, im0)
if key== ord('q') or key == 27:
vid_cap[0].release()
cv2.destroyAllWindows()
# exit_flag = True
break
总结
成功更改适配6.1版本。至此检测与跟踪自成一体。