闲来无聊,觉得onnx推理yolo的检测框好丑,强迫症犯了。自己写了一份补全代码。先给大家看看效果。个人觉得颜色不同好看一点,这里是把监控摄像头的rtsp流接进去了。检测效果还不错。
检测视频:
代码:
import argparse
from yolov5s_utils import *
import time
def main(opt):
    """Run YOLOv5s detection on a video source and display annotated frames.

    Args:
        opt: parsed CLI options with ``video_path``, ``frame_skip`` and
            ``detection_model`` attributes (see ``parse_opt``).

    Reads frames from ``opt.video_path``, runs the ONNX detector on every
    ``opt.frame_skip``-th frame, draws per-class colored boxes/labels and
    shows the result until the stream ends or 'q' is pressed.
    """
    # init model
    yolo_detector = YoloV5s(opt)
    yolo_cls = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
                "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
                "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
                "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
                "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
                "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
                "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
                "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
                "hair drier", "toothbrush"]
    colors = generate_colors(len(yolo_cls))

    # read video source
    cap = cv2.VideoCapture(opt.video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # RTSP streams often report 0 fps; fall back to 25 so the writer stays valid.
    src_fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
    fourcc = cv2.VideoWriter_fourcc('M', 'P', '4', 'V')
    out_frame = cv2.VideoWriter('output_video.mp4', fourcc, src_fps, (width, height))
    cv2.namedWindow('result', cv2.WINDOW_NORMAL)

    frame_id = 0
    try:
        while True:
            ret, img = cap.read()
            if not ret:
                # Stream ended or read error; break (not return) so cleanup runs.
                print('视频读取失败!')
                break
            frame_id += 1
            if frame_id % opt.frame_skip != 0:
                continue
            ori_img = img.copy()

            # inference
            start_time = time.time()
            det_object, cls_object = yolo_detector.detect(ori_img, cls=yolo_cls, conf_thres=0.20, iou_thres=0.5)
            end_time = time.time()
            # Guard against a zero elapsed time (very fast inference / clock granularity).
            elapsed = max(end_time - start_time, 1e-6)
            fps = int(1.0 / elapsed + 0.5)
            cv2.putText(img, 'CPU FPS:{}'.format(str(fps)), (10, 30), fontFace=0, fontScale=1,
                        color=(0, 0, 255), thickness=2)

            # draw detections: det is [x1, y1, x2, y2, conf, cls_id]
            if len(det_object):
                for index, det in enumerate(det_object):
                    cls_id = int(det[5])
                    x1, y1, x2, y2 = int(det[0]), int(det[1]), int(det[2]), int(det[3])
                    cv2.rectangle(img, (x1, y1), (x2, y2), colors[cls_id], thickness=2)
                    # Approximate label background width from the class-name length.
                    label_width = len(cls_object[cls_id]) * 8 * 2 + 5
                    cv2.rectangle(img, (x1, y1 - 25), (x1 + label_width, y1), colors[cls_id], thickness=-1)
                    cv2.putText(img, cls_object[cls_id], (x1 - 2, y1 - 3), fontFace=0, fontScale=1,
                                color=(0, 255, 0), thickness=2, lineType=cv2.LINE_AA)
                    cv2.putText(img, str(round(det[4], 2)), (x1 - 2, y1 + 10), fontFace=0, fontScale=0.5,
                                color=colors[cls_id], thickness=1, lineType=cv2.LINE_AA)

            # show
            # img = cv2.resize(img, (1200, 720))
            cv2.imshow('result', img)
            # out_frame.write(img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always release capture, writer and windows, even on early exit/errors.
        cap.release()
        out_frame.release()
        cv2.destroyAllWindows()
def parse_opt(args=None):
    """Build and parse the command-line options for the detector script.

    Args:
        args: optional list of argument strings; ``None`` (the default,
            backward-compatible) parses ``sys.argv`` as before.

    Returns:
        argparse.Namespace with ``video_path``, ``frame_skip`` and
        ``detection_model`` attributes.
    """
    parser = argparse.ArgumentParser()
    # 视频源和跳帧参数 (video source and frame-skip parameters)
    parser.add_argument('--video_path', default='./uniform.mp4', help='video path')
    # parser.add_argument('--video_path', default='rtsp://admin:LPRSIU@192.168.1.107:554/h264/ch1/main/av_stream', help='video path')
    # parser.add_argument('--video_path', default=0, help='video path')
    parser.add_argument('--frame_skip', type=int, default=1, help='Frame skip num')
    # yolov5s目标检测 (yolov5s object detection model)
    parser.add_argument('--detection_model', default='./yolov5s.onnx', help='model path')
    opt = parser.parse_args(args)
    return opt
if __name__ == "__main__":
opt = parse_opt()
main(opt)