Implementing a yolov5-master + DeepSORT wrapper with region intrusion detection (mouse-drawn region)

This article only presents the code approach; it does not provide the complete source.


Environment setup:

Download the yolov5-master source: https://github.com/ultralytics/yolov5

First get yolov5-master running end to end; for details see: https://blog.csdn.net/qq_36756866/article/details/109111065

Get the DeepSORT module: https://github.com/ZQPei/deep_sort_pytorch

(Note: only the deepsort module itself is needed, but some downloaded copies of it are incompatible. If you hit errors, message me and I can share code for study.)

Result:

To follow the code below you need to understand yolov5-master's detect file. Recommended walkthrough:

https://www.bilibili.com/video/BV1Dt4y1x7Fz?p=2&vd_source=20c2fbe464a4779579b5dcdc2b3dd4ed

Approach:

Take the detection results, feed them into the tracking module, and register a mouse callback for drawing the target region. Then test each detected target for intersection with the drawn region.
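At its core the intrusion test just checks whether a box's center falls inside the mask built from the drawn region. A minimal standalone sketch of that idea (the frame size, region and box are made-up values; the actual implementation below ANDs a small window around the center with the mask instead of testing a single pixel):

import cv2
import numpy as np

h, w = 480, 640                                        # hypothetical frame size
mask = np.zeros((h, w), np.uint8)                      # blank mask
cv2.rectangle(mask, (100, 100), (300, 300), 255, -1)   # the drawn region, filled

x1, y1, x2, y2 = 180, 150, 220, 260                    # a hypothetical detection box
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2                # its center point
print(mask[cy, cx] != 0)                               # True -> the center is inside the region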

Required modules:

Rewrite the detect file and the tracker file, and write a main script.

Rewriting detect:

Get the yolov5-master predictions; see the inline comments for details (this part of the approach comes from the web).

import torch
import numpy as np
from utils.general import non_max_suppression, scale_boxes   # non-maximum suppression, box rescaling
from utils.augmentations import letterbox   # letterbox padding
from utils.torch_utils import select_device  # device selection
from models.experimental import attempt_load   # model loading


class Detector():       # wrapped in a class so predictions are easy to fetch
    def __init__(self):     # initialization
        super(Detector, self).__init__()
        self.imgSize = 640
        self.threshold = 0.3
        self.stride = 1
        self.weights = 'weights/yolov5s.pt'
        self.device = '0' if torch.cuda.is_available() else 'cpu'
        self.device = select_device(self.device)
        model = attempt_load(self.weights, device=self.device)
        model.to(self.device).eval()
        model.float()
        self.model = model
        self.names = model.module.names if hasattr(
            model, 'module') else model.names        # used to map class indices to label names

    def preprocess(self, img):
        img0 = img.copy()   # keep a copy of the original image
        img = letterbox(img, new_shape=self.imgSize)[0]
        img = img[:, :, ::-1].transpose(2, 0, 1)   # cv2 reads BGR, convert to RGB, then HWC -> CHW (channels, height, width)
        img = np.ascontiguousarray(img)   # make the array contiguous in memory
        img = torch.from_numpy(img).to(self.device)    # convert to a torch tensor and move it to the device
        img = img.float()   # cast to float
        img /= 255.0    # normalize to [0, 1]
        if img.ndimension() == 3:   # if the tensor is 3-D
            img = img.unsqueeze(0)   # add a batch dimension, so the model receives a 4-D tensor

        return img0, img

    def detect(self, im):
        img0, img = self.preprocess(im)    # preprocess the image as above
        with torch.no_grad():   # inference only, no gradients needed
            pred = self.model(img, augment=False)[0]   # run the model to get raw predictions
        pred = pred.float()
        pred = non_max_suppression(pred, self.threshold, 0.4)   # apply non-maximum suppression
        pred_boxes = []   # collect the boxes we keep
        for det in pred:    # iterate over the NMS results
            if det is not None and len(det):    # skip empty results
                det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], img0.shape).round()   # map coordinates back to the original image

                for *x, conf, cls_id in det:
                    lbl = self.names[int(cls_id)]     # map the class index to its label name
                    if lbl not in ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck']:  # keep only the labels we care about (using the stock 80-class yolov5s weights)
                        continue
                    x1, y1 = int(x[0]), int(x[1])
                    x2, y2 = int(x[2]), int(x[3])
                    pred_boxes.append((x1, y1, x2, y2, lbl, conf))   # store each box as a tuple

        return im, pred_boxes    # return the original image and the kept boxes
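With the class in place, getting boxes for a single image is straightforward. A minimal usage sketch (the image path is a placeholder; weights/yolov5s.pt must exist as configured above):

import cv2
from detect import Detector

det = Detector()
frame = cv2.imread('demo.jpg')          # placeholder image path
frame, boxes = det.detect(frame)
for x1, y1, x2, y2, lbl, conf in boxes:
    print(lbl, float(conf), (x1, y1, x2, y2))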

Modifying tracker:

Start from the original tracker and modify it in place.

The modified parts are: the drawing code, the intrusion test, and the count of people inside the region (I removed the parts that were verbose and added little).

Idea: take the tracked boxes and use a small region around each box's center point as the test area. If it intersects the mask (the drawn region), the target counts as an intrusion and its box is drawn in red instead of green. To report the number of intruders, count the red boxes drawn for the current frame and render the total at the end.

def draw_bboxes(im, bboxes, mask):       # draws tracking boxes and center markers and runs the intrusion test; args: (frame, tracked boxes, mask drawn with the mouse)
    count = 0   # number of intruders inside the drawn region
    for (x1, y1, x2, y2, cls_id, pos_id) in bboxes:     # iterate over all tracked boxes
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)     # cast to int for drawing
        list_pts = []   # corner points of the small center marker (arguably unnecessary and verbose, but harmless, so it stays)
        check_point_x = int((x2 + x1) / 2)      # center x
        point_radius = 2        # marker radius
        check_point_y = int((y2 + y1) / 2)  # center y
        c1, c2 = (int(x1), int(y1)), (int(x2), int(y2))  # top-left and bottom-right corners

        # AND a small window around the box center with the same window of the mask;
        # a nonzero result means the center lies inside the drawn region (note: the row index y comes first when slicing)
        overlap = cv2.bitwise_and(mask[check_point_y - 3:check_point_y + 3, check_point_x - 3:check_point_x + 3],
                                  im[check_point_y - 3:check_point_y + 3, check_point_x - 3:check_point_x + 3])

        if np.sum(overlap) != 0 and cls_id == 'person':     # nonzero overlap with label 'person' counts as an intrusion
            count += 1      # one more intruder
            cv2.rectangle(im, c1, c2, (0, 0, 255), 2, cv2.LINE_AA)  # draw the intruding box in red
            allname = cls_id + '-ID--' + str(pos_id)    # concatenate label name and track ID
            t_size = cv2.getTextSize(allname, 0, fontScale=0.5, thickness=2)[0]  # measure the text so the label box fits it
            c2 = (c1[0] + t_size[0], c1[1] - t_size[1] - 3)     # bottom-right corner of the label box
            cv2.rectangle(im, c1, c2, (0, 0, 255), -1)  # filled label box

        else:       # boxes outside the drawn region are drawn in green
            cv2.rectangle(im, c1, c2, (0, 255, 0), 2, cv2.LINE_AA)
            allname = cls_id + '-ID--' + str(pos_id)
            t_size = cv2.getTextSize(allname, 0, fontScale=0.5, thickness=2)[0]
            c2 = (c1[0] + t_size[0], c1[1] - t_size[1] - 3)
            cv2.rectangle(im, c1, c2, (0, 255, 0), -1)

        list_pts.append([check_point_x - point_radius, check_point_y - point_radius])       # corners of the center marker
        list_pts.append([check_point_x - point_radius, check_point_y + point_radius])
        list_pts.append([check_point_x + point_radius, check_point_y + point_radius])
        list_pts.append([check_point_x + point_radius, check_point_y - point_radius])

        ndarray_pts = np.array(list_pts, np.int32)
        cv2.fillPoly(im, [ndarray_pts], color=(0, 0, 255))      # draw the center marker
        list_pts.clear()
        cv2.putText(im, allname, (c1[0], c1[1] - 2), 0, 0.5,        # write the label text into the label box
                    [225, 255, 255], thickness=2, lineType=cv2.LINE_AA)
    if count != 0:      # if anyone is inside the region, draw a warning with the headcount
        cv2.putText(im, f'Warning:-{count}-people', (10, 50), 0, 2, (0, 0, 255), 3, cv2.LINE_AA)

    return im   # return the annotated frame


Drawing the tracking trajectory:

Trajectories are drawn in the update function (each trajectory keeps at most its last 50 points; if a target disappears, its points are retained for another 10 frames and deleted if it has not reappeared by then).

Idea:

Save each track ID (key) together with its center points (value, stored as a list of coordinates) in a dictionary. For each key, check whether its list of centers has grown past 50; if so, delete the first coordinate, so the list length stays capped at 50. A second dictionary records how many frames each ID has been missing.

How to detect a disappeared target:

Check whether each ID stored from earlier frames still appears in the current frame's results. If it does not, increment its counter; once the counter exceeds ten frames, delete every entry keyed by that ID.

# Two extra dictionaries are passed in (dict_box, dic_id): 1. dict_box stores each track's center points, 2. dic_id counts the frames a track has been missing, so stale centers can be deleted
def update(bboxes, im, frame_cnt, dict_box, dic_id):
    bbox_xywh = []
    confs = []
    bboxes2draw = []

    if len(bboxes) > 0:
        for x1, y1, x2, y2, lbl, conf in bboxes:
            obj = [
                int((x1 + x2) * 0.5), int((y1 + y2) * 0.5),
                x2 - x1, y2 - y1
            ]
            bbox_xywh.append(obj)
            confs.append(conf)

        xywhs = torch.Tensor(bbox_xywh)
        confss = torch.Tensor(confs)

        outputs = deepsort.update(xywhs, confss, im)

        if len(outputs) > 0:
            bbox_xyxy = outputs[:, :4]  # first four columns (coordinates)
            identities = outputs[:, -1]  # last column (track IDs)
            for i in list(dict_box.keys()):     # iterate over the stored track IDs
                if i not in identities:     # the ID is absent from the current results, i.e. it disappeared this frame
                    dic_id[i] += 1      # increment its missing-frame counter
                    if dic_id[i] > 10:  # missing for more than ten frames: drop its centers and its counter
                        dict_box.pop(i)
                        dic_id.pop(i)

            box_xywh = xyxy2tlwh(bbox_xyxy)

            for j in range(len(box_xywh)):     # iterate over the tracked boxes
                x_center = box_xywh[j][0] + box_xywh[j][2] / 2  # box center x
                y_center = box_xywh[j][1] + box_xywh[j][3] / 2  # box center y
                id = outputs[j][-1]
                center = [x_center, y_center]

                dict_box.setdefault(id, []).append(center)  # bind the center point to its track ID
                dic_id[id] = 1      # reset the missing-frame counter while the ID is present
                if len(dict_box[id]) > 50:  # keep only the last 50 centers: drop the oldest once the list grows past 50
                    dict_box[id].pop(0)

            if frame_cnt > 2:  # a single point cannot be connected, so drawing starts after the first couple of frames; frame_cnt is the current frame number
                for key, value in dict_box.items():
                    for a in range(len(value) - 1):
                        index_start = a
                        index_end = index_start + 1
                        cv2.line(im, tuple(map(int, value[index_start])), tuple(map(int, value[index_end])),
                                 # map(int, ...) casts the stored float coordinates to ints for cv2.line
                                 (0, 0, 255), thickness=2, lineType=8)

        for x1, y1, x2, y2, track_id in list(outputs):
            center_x = (x1 + x2) * 0.5
            center_y = (y1 + y2) * 0.5

            label = search_label(center_x=center_x, center_y=center_y,
                                 bboxes_xyxy=bboxes, max_dist_threshold=20.0)

            bboxes2draw.append((x1, y1, x2, y2, label, track_id))

    return bboxes2draw

Complete tracker:

import cv2
import torch
import numpy as np

from deep_sort.utils.parser import get_config
from deep_sort.deep_sort import DeepSort

cfg = get_config()
cfg.merge_from_file("./deep_sort/configs/deep_sort.yaml")
deepsort = DeepSort(cfg.DEEPSORT.REID_CKPT,
                    max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
                    nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
                    max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET,
                    use_cuda=True)

def xyxy2tlwh(x):
    '''
    (top left x, top left y,width, height)
    '''
    y = torch.zeros_like(x) if isinstance(x,
                                          torch.Tensor) else np.zeros_like(x)
    y[:, 0] = x[:, 0]
    y[:, 1] = x[:, 1]
    y[:, 2] = x[:, 2] - x[:, 0]
    y[:, 3] = x[:, 3] - x[:, 1]
    return y
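# For reference, a quick worked example of the conversion above (illustrative values):
#   xyxy2tlwh(np.array([[10., 20., 50., 80.]]))  ->  array([[10., 20., 40., 60.]])
# i.e. (x1, y1, x2, y2) becomes (top-left x, top-left y, width, height).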

def draw_bboxes(im, bboxes, mask):       # draws tracking boxes and center markers and runs the intrusion test; args: (frame, tracked boxes, mask drawn with the mouse)
    count = 0   # number of intruders inside the drawn region
    for (x1, y1, x2, y2, cls_id, pos_id) in bboxes:     # iterate over all tracked boxes
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)     # cast to int for drawing
        list_pts = []   # corner points of the small center marker (arguably unnecessary and verbose, but harmless, so it stays)
        check_point_x = int((x2 + x1) / 2)      # center x
        point_radius = 2        # marker radius
        check_point_y = int((y2 + y1) / 2)  # center y
        c1, c2 = (int(x1), int(y1)), (int(x2), int(y2))  # top-left and bottom-right corners

        # AND a small window around the box center with the same window of the mask;
        # a nonzero result means the center lies inside the drawn region (note: the row index y comes first when slicing)
        overlap = cv2.bitwise_and(mask[check_point_y - 3:check_point_y + 3, check_point_x - 3:check_point_x + 3],
                                  im[check_point_y - 3:check_point_y + 3, check_point_x - 3:check_point_x + 3])

        if np.sum(overlap) != 0 and cls_id == 'person':     # nonzero overlap with label 'person' counts as an intrusion
            count += 1      # one more intruder
            cv2.rectangle(im, c1, c2, (0, 0, 255), 2, cv2.LINE_AA)  # draw the intruding box in red
            allname = cls_id + '-ID--' + str(pos_id)    # concatenate label name and track ID
            t_size = cv2.getTextSize(allname, 0, fontScale=0.5, thickness=2)[0]  # measure the text so the label box fits it
            c2 = (c1[0] + t_size[0], c1[1] - t_size[1] - 3)     # bottom-right corner of the label box
            cv2.rectangle(im, c1, c2, (0, 0, 255), -1)  # filled label box

        else:       # boxes outside the drawn region are drawn in green
            cv2.rectangle(im, c1, c2, (0, 255, 0), 2, cv2.LINE_AA)
            allname = cls_id + '-ID--' + str(pos_id)
            t_size = cv2.getTextSize(allname, 0, fontScale=0.5, thickness=2)[0]
            c2 = (c1[0] + t_size[0], c1[1] - t_size[1] - 3)
            cv2.rectangle(im, c1, c2, (0, 255, 0), -1)

        list_pts.append([check_point_x - point_radius, check_point_y - point_radius])       # corners of the center marker
        list_pts.append([check_point_x - point_radius, check_point_y + point_radius])
        list_pts.append([check_point_x + point_radius, check_point_y + point_radius])
        list_pts.append([check_point_x + point_radius, check_point_y - point_radius])

        ndarray_pts = np.array(list_pts, np.int32)
        cv2.fillPoly(im, [ndarray_pts], color=(0, 0, 255))      # draw the center marker
        list_pts.clear()
        cv2.putText(im, allname, (c1[0], c1[1] - 2), 0, 0.5,        # write the label text into the label box
                    [225, 255, 255], thickness=2, lineType=cv2.LINE_AA)
    if count != 0:      # if anyone is inside the region, draw a warning with the headcount
        cv2.putText(im, f'Warning:-{count}-people', (10, 50), 0, 2, (0, 0, 255), 3, cv2.LINE_AA)

    return im   # return the annotated frame

# Two extra dictionaries are passed in (dict_box, dic_id): 1. dict_box stores each track's center points, 2. dic_id counts the frames a track has been missing, so stale centers can be deleted
def update(bboxes, im, frame_cnt, dict_box, dic_id):
    bbox_xywh = []
    confs = []
    bboxes2draw = []

    if len(bboxes) > 0:
        for x1, y1, x2, y2, lbl, conf in bboxes:
            obj = [
                int((x1 + x2) * 0.5), int((y1 + y2) * 0.5),
                x2 - x1, y2 - y1
            ]
            bbox_xywh.append(obj)
            confs.append(conf)

        xywhs = torch.Tensor(bbox_xywh)
        confss = torch.Tensor(confs)

        outputs = deepsort.update(xywhs, confss, im)

        if len(outputs) > 0:
            bbox_xyxy = outputs[:, :4]  # first four columns (coordinates)
            identities = outputs[:, -1]  # last column (track IDs)
            for i in list(dict_box.keys()):     # iterate over the stored track IDs
                if i not in identities:     # the ID is absent from the current results, i.e. it disappeared this frame
                    dic_id[i] += 1      # increment its missing-frame counter
                    if dic_id[i] > 10:  # missing for more than ten frames: drop its centers and its counter
                        dict_box.pop(i)
                        dic_id.pop(i)

            box_xywh = xyxy2tlwh(bbox_xyxy)

            for j in range(len(box_xywh)):     # iterate over the tracked boxes
                x_center = box_xywh[j][0] + box_xywh[j][2] / 2  # box center x
                y_center = box_xywh[j][1] + box_xywh[j][3] / 2  # box center y
                id = outputs[j][-1]
                center = [x_center, y_center]

                dict_box.setdefault(id, []).append(center)  # bind the center point to its track ID
                dic_id[id] = 1      # reset the missing-frame counter while the ID is present
                if len(dict_box[id]) > 50:  # keep only the last 50 centers: drop the oldest once the list grows past 50
                    dict_box[id].pop(0)

            if frame_cnt > 2:  # a single point cannot be connected, so drawing starts after the first couple of frames; frame_cnt is the current frame number
                for key, value in dict_box.items():
                    for a in range(len(value) - 1):
                        index_start = a
                        index_end = index_start + 1
                        cv2.line(im, tuple(map(int, value[index_start])), tuple(map(int, value[index_end])),
                                 # map(int, ...) casts the stored float coordinates to ints for cv2.line
                                 (0, 0, 255), thickness=2, lineType=8)

        for x1, y1, x2, y2, track_id in list(outputs):
            center_x = (x1 + x2) * 0.5
            center_y = (y1 + y2) * 0.5

            label = search_label(center_x=center_x, center_y=center_y,
                                 bboxes_xyxy=bboxes, max_dist_threshold=20.0)

            bboxes2draw.append((x1, y1, x2, y2, label, track_id))

    return bboxes2draw


def search_label(center_x, center_y, bboxes_xyxy, max_dist_threshold):
    """
    Search the yolov5 bboxes for the one whose center is closest to the given point and return its label.
    :param center_x:
    :param center_y:
    :param bboxes_xyxy:
    :param max_dist_threshold:
    :return: label string
    """
    label = ''
    min_dist = -1.0

    for x1, y1, x2, y2, lbl, conf in bboxes_xyxy:
        center_x2 = (x1 + x2) * 0.5
        center_y2 = (y1 + y2) * 0.5

        # both the x and y distances must be below max_dist_threshold
        min_x = abs(center_x2 - center_x)
        min_y = abs(center_y2 - center_y)

        if min_x < max_dist_threshold and min_y < max_dist_threshold:
            # within tolerance: average the x and y distances
            avg_dist = (min_x + min_y) * 0.5
            if min_dist == -1.0:
                # first candidate
                min_dist = avg_dist
                label = lbl
            else:
                # otherwise keep the closer candidate
                if avg_dist < min_dist:
                    min_dist = avg_dist
                    label = lbl

    return label
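search_label exists only to recover the class label for a tracked box, by matching the track's center to the nearest detection center. A quick illustration with a made-up detection:

boxes = [(100, 100, 200, 200, 'person', 0.9)]       # one hypothetical detection, center (150, 150)
print(search_label(center_x=152.0, center_y=148.0,
                   bboxes_xyxy=boxes, max_dist_threshold=20.0))   # -> 'person'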

Main function:

Wire up the files above, and add the mask-drawing function and the mouse callback.

Then read the video frame by frame, call the wrapped functions above on each frame, and display the result.

import cv2
import numpy as np
from detect import Detector     # the Detector class from the detect file
import tracker      # the modified tracker file

point1 = []     # empty lists that will hold the two mouse-drawn corner points
point2 = []

def draw_mask(im, point1, point2):        # draw the mouse-selected region onto the displayed image
    pts = np.array([[point1[0], point1[1]],
                    [point2[0], point1[1]],
                    [point2[0], point2[1]],
                    [point1[0], point2[1]]], np.int32)
    cv2.polylines(im, [pts], True, (255, 255, 0), 3)
    return im

def draw_rectangle(event, x, y, flags, param):      # mouse callback
    global im, point1, point2

    if event == cv2.EVENT_LBUTTONDOWN:
        point1 = (x, y)

    elif event == cv2.EVENT_MOUSEMOVE and (flags & cv2.EVENT_FLAG_LBUTTON):
        cv2.rectangle(im, point1, (x, y), (255, 0, 0), 2)

    elif event == cv2.EVENT_LBUTTONUP:
        point2 = (x, y)
        cv2.rectangle(im, point1, point2, (0, 255, 0), 2)

cv2.namedWindow('main')     # must match the window name used for display
cv2.setMouseCallback('main', draw_rectangle)    # register the mouse callback on that window

if __name__ == '__main__':

    video = cv2.VideoCapture('test_person.mp4')
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_cnt = 0       # frame counter, used to decide when trajectory drawing starts
    dict_box = dict()
    dic_id = dict()
    det = Detector()    # create the detector
    while True:
        _, im = video.read()
        if im is None:
            break
        # im = cv2.resize(im, (960, 640))
        frame_cnt += 1

        im, bboxes = det.detect(im)  # run detection

        mask = np.zeros((height, width, 3), np.uint8)       # mask starts out all zeros
        if point1 != [] and point2 != []:   # only once the mouse has supplied both corners
            cv2.rectangle(mask, (point1[0], point1[1]), (point2[0], point2[1]), 255, -1)    # fill the drawn region, used for the intrusion test

        if len(bboxes) > 0:   # only track when there are detections
            listboxs = tracker.update(bboxes, im, frame_cnt, dict_box, dic_id)  # feed the detections to the tracker
            im = tracker.draw_bboxes(im, listboxs, mask)  # draw the results onto the frame
            if point1 != [] and point2 != []:   # overlay the drawn region on the output frame
                im = draw_mask(im, point1, point2)

        cv2.imshow('main', im)
        if cv2.waitKey(30) & 0xFF == ord('q'):     # press q to quit
            break

    video.release()
    cv2.destroyAllWindows()
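To try it out: save the three parts above as detect.py, tracker.py and main.py inside the yolov5-master directory (these are the file names the imports assume), put a test video such as test_person.mp4 next to them, and run python main.py. While the video plays, hold the left mouse button and drag to draw the region: person boxes whose centers fall inside it turn red and the warning count appears.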


