yolov5与网络通信的结合
此篇博客记录在yolov5的实际应用过程中,使用socket网络编程,实现主程序给yolov5检测程序发送图片,yolov5检测程序接收图片进行目标检测并将目标信息发送回主程序的功能。
##1. predict.py程序
该程序由yolov5 detect.py修改而来,作为一个接口,可接收已被读取的numpy格式的图片(原detect.py的输入是直接接收摄像头的数据或是指定图片的路径)。此程序放在detect.py同一路径下。
import io
import numpy as np
import cv2
import torch
from PIL import Image
from numpy import random
'''
代码:由YOLOv5自带的detect.py 改编
'''
from utils.plots import Annotator, colors, save_one_box
from models.experimental import attempt_load
from utils.general import check_img_size, non_max_suppression, scale_coords, \
set_logging
from utils.torch_utils import select_device
from utils.plots import Colors
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
# Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
shape = img.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not scaleup: # only scale down, do not scale up (for better test mAP)
r = min(r, 1.0)
# Compute padding
ratio = r, r # width, height ratios
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, 32), np.mod(dh, 32) # wh padding
elif scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
return img, ratio, (dw, dh)
weights = 'runs/train/exp9/weights/best.pt' # 训练好的模型位置
opt_device = '0' # device = 'cpu' or '0' or '0,1,2,3'
imgsz = 640
opt_conf_thres = 0.25
opt_iou_thres = 0.5
# Initialize
set_logging()
device = select_device(opt_device)
half = device.type != 'cpu' # half precision only supported on CUDA
# 加载模型
model = attempt_load(weights, map_location=device) # load FP32 model
imgsz = check_img_size(imgsz, s=model.stride.max()) # check img_size
if half:
model.half() # to FP16
# Get names and colors
names = model.module.names if hasattr(model, 'module') else model.names # 获取标签
print(names)
colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
# def transform_image(image_bytes):
# image = Image.open(io.BytesIO(image_bytes))
# # img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
# # img = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR)
# # print(img)
# return image
# (接口中传输的二进制流)将二进制用cv2 读取流并转换成yolov5 可接受的图片
def bytes_img(image_bytes):
# 二进制数据流转np.ndarray [np.uint8: 8位像素]
img = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return rgb_img
# '''
# [[[ 60 65 68]
# [202 205 213]
# [207 208 228]
# '''
from utils.torch_utils import select_device
from models.experimental import attempt_load
from utils.general import check_img_size, non_max_suppression, scale_coords
# from utils.datasets import letterbox
# from utils.plots import plot_one_box
colors = Colors() # create instance for 'from utils.plots import colors'
# 此函数用来接收numpy格式的图片,进行检测并返回目标信息(类别,位置)
def predict_(pic_):
# Run inference
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once
# Set Dataloader & Run inference
im0s = pic_ # BGR # 蓝绿红
img = letterbox(im0s, new_shape=imgsz)[0]
# Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = np.ascontiguousarray(img)
img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
if img.ndimension() == 3:
img = img.unsqueeze(0)
# Inference
# pred = model(img, augment=opt.augment)[0]
pred = model(img)[0]
# Apply NMS
pred = non_max_suppression(pred, opt_conf_thres, opt_iou_thres)
# print(type(pred))
# Process detections
detect_info = []
for i, det in enumerate(pred): # detections per image
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0s.shape).round()
# Write results
for *xyxy, conf, cls in reversed(det):
# print(cls)
dic ={
'class':f'{names[int(cls)]}', # 检测目标对应的类别名
'location':torch.tensor(xyxy).view(1, 4).view(-1).tolist(), # 坐标信息,左上和右下的坐标
'score': round(float(conf) * 100, 2) # 目标检测分数
}
detect_info.append(dic)
#画框
c = int(cls)
color = colors(c, True)
p1,p2=(int(xyxy[0]),int(xyxy[1])),(int(xyxy[2]),int(xyxy[3]))
img=cv2.rectangle(img=pic_,pt1=p1,pt2=p2,color=color,thickness=2)
img=cv2.putText(img,str(names[c]),(p1[0], p1[1] - 2),0,0.75,color,2)
# cv2.imshow('frame',img)
# cv2.waitKey(1)
return detect_info
# img=cv2.imread(r'E:\01_school_study\01_python_project\yolo_project\traffic_sign\data\images\test\1_0.jpg')
# predict_(img)
# if __name__=='__main__':
# img=cv2.imread('data/images/test/1_0.jpg')
# print(predict_(img))
2.主程序
此程序为项目工程主控制程序,由此程序获取摄像头数据,将图片编码后使用udp发送给yolo检测程序,并接收检测程序发送回来的检测结果。此程序位置不需放在YOLO工程的目录下。
from socket import *
import time
import cv2
# import predict
print("=====================发送数据时间戳UDP服务器=====================")
#本进程的ip+端口号,本进程将把摄像头采集的图片发送给检测程序,并接收检测程序发送回的目标信息
HOST0 = '127.0.0.1'
PORT0 = 9966 # 端口号
UFSIZ = 4092 # 接收消息的缓冲大小
ADDR0 = (HOST0, PORT0)
udpSerSock1 = socket(AF_INET, SOCK_DGRAM) #创建udp服务器套接字
udpSerSock1.bind(ADDR0) #套接字与地址绑定
#目标检测程序的ip+端口号
HOST = '127.0.0.1' # 主机号为空白表示可以使用任何可用的地址。
PORT = 28888 # 端口号
BUFSIZ = 1024 * 1024 # 接收数据缓冲大小
ADDR_send = (HOST, PORT)
udpCliSock_send = socket(AF_INET, SOCK_DGRAM) # 创建客户端套接字
# vifdeo_path = "test.mp4"
# vifdeo_path = "/home/sha/Videos/11.18-001.avi"
capture = cv2.VideoCapture(0)
while True:
ref, frame = capture.read()
if ref:
#发送图片给检测程序
print(type(frame))
frame = cv2.resize(frame, (420, 260))
cv2.imshow("send", frame)
cv2.waitKey(1)
img_encode = cv2.imencode('.jpg', frame)[1]
data = img_encode.tobytes()
udpCliSock_send.sendto(data, ADDR_send)
time.sleep(0.1)
# 接收检测后的信息
try:
#设置超时,当在规定时间内没有接收数据时,执行except,若不设置超时,则程序将停留在recv_data = udpSerSock1.recvfrom(1024)。
udpSerSock1.settimeout(0.01)
recv_data = udpSerSock1.recvfrom(1024)
info=recv_data[0].decode('utf-8')
print(info)
except:
continue
3.yolo检测程序
该程序负责接收主程序发送过来的图片信息,将其解码,并进行检测,并将检测结果通过udp发送给主程序。此程序需放在detect.py同一目录下。
#-------------------------------------#
# 调用摄像头检测
#-------------------------------------#
from PIL import Image
import numpy as np
import cv2
import time
from socket import *
from time import ctime
from io import BytesIO
from PIL import Image
from predict import predict_
# 调用远程摄像头
print("=====================Timestamp UDP server=====================")
#本进程的ip+端口号,本进程接收图片,检测后并将目标信息(类别,左上和右下两点坐标)发送给主程序
HOST = '127.0.0.1' #主机号为空白表示可以使用任何可用的地址。
PORT = 28888 #端口号
BUFSIZ = 4092*4092 #接收数据缓冲大小
ADDR = (HOST, PORT)
udpSerSock = socket(AF_INET, SOCK_DGRAM) #创建udp服务器套接字
udpSerSock.bind(ADDR) #套接字与地址绑定
fps = 0.0
# 主程序ip+端口号
HOST_send = '127.0.0.1'
PORT_send = 9966 # 端口号
UFSIZ_send = 1024 # 接收消息的缓冲大小
ADDR_send = (HOST_send, PORT_send)
udpCliSock1 = socket(AF_INET, SOCK_DGRAM) # 创建客户端套接字
# remote
# udpCliSock_send.sendto(bytes('start_1', 'utf-8'), ADDR_send)
flag = True
while True:
# print('Waiting to receive data...')
try:
udpSerSock.settimeout(10)
data_org, addr = udpSerSock.recvfrom(BUFSIZ) #连续接收指定字节的数据,接收到的是字节数组
# print(data_org)
# img = cv2.imdecode(data_org,cv2.IMREAD_COLOR)
data = data_org.hex()
head_org = data_org[5:]
# head = head_org.hex()
# img=cv2.imdecode(head_org,cv2.IMREAD_COLOR)
buf = BytesIO(data_org)
# retval = cv2.imdecode(buf, 1)
img = np.array(Image.open(buf))
# print(img)
if True:
# if flag:
# udpCliSock_send.sendto(bytes('start_2', 'utf-8'), ADDR_send)
# flag = False
frame = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
t1 = time.time()
info=predict_(frame)
cv2.imshow('frame0',frame)
cv2.waitKey(1)
print(info)
send_data=str(info)
udpCliSock1.sendto(send_data.encode("utf-8"), ADDR_send)
# time.sleep(0.1)
except:
print("Timeout 10s, please check client!")
break
参考博客:http://t.csdn.cn/8QC85