import cv2
import numpy as np
import onnxruntime
# Class names, indexed by the integer class id that draw() reads from the
# detections. Placeholder: replace with the full list of class names the
# model was trained on.
CLASSES = [""] * 2
class YOLOV5():
    """Thin wrapper around an ONNX Runtime session for a YOLOv5 export.

    Args:
        onnxpath: path to the exported ``.onnx`` model file.
        input_size: size each frame is resized to before inference, either an
            int (square) or a ``(width, height)`` pair. It must match the size
            the model was exported with. Defaults to 640x640, the previously
            hard-coded value.
        providers: optional list of ONNX Runtime execution providers, forwarded
            to ``InferenceSession``. ``None`` keeps ONNX Runtime's default.
    """

    def __init__(self, onnxpath, input_size=(640, 640), providers=None):
        self.onnx_session = onnxruntime.InferenceSession(onnxpath, providers=providers)
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        # Stored as (width, height): the order cv2.resize expects for dsize.
        self.input_size = tuple(int(v) for v in input_size)
        self.input_name = self.get_input_name()
        self.output_name = self.get_output_name()

    def get_input_name(self):
        """Return the names of all model inputs, in session order."""
        return [node.name for node in self.onnx_session.get_inputs()]

    def get_output_name(self):
        """Return the names of all model outputs, in session order."""
        return [node.name for node in self.onnx_session.get_outputs()]

    def get_input_feed(self, img_tensor):
        """Build the ``session.run`` feed dict, binding *img_tensor* to every input name."""
        return {name: img_tensor for name in self.input_name}

    def inference(self, image):
        """Run the model on one BGR frame.

        Args:
            image: HxWx3 image in OpenCV's BGR channel order.

        Returns:
            ``(pred, or_img)``: the first model output, and the frame resized
            to ``self.input_size`` (still BGR, uint8). The rest of this file
            draws the detections on ``or_img``.
        """
        or_img = cv2.resize(image, self.input_size)
        # BGR -> RGB and HWC -> CHW, then scale pixel values to [0, 1] float32.
        img = or_img[:, :, ::-1].transpose(2, 0, 1).astype(np.float32)
        img /= 255.0
        img = np.expand_dims(img, axis=0)  # add the batch axis -> (1, 3, H, W)
        pred = self.onnx_session.run(None, self.get_input_feed(img))[0]
        return pred, or_img
def xywh2xyxy(x):
    """Convert boxes from centre/size format to corner format.

    Columns 0-3 of each row are read as ``(cx, cy, w, h)`` and written to the
    result as ``(x1, y1, x2, y2)``. Any remaining columns are copied
    unchanged. The input array is not modified; a new array is returned.
    """
    cx, cy = x[:, 0], x[:, 1]
    half_w, half_h = x[:, 2] / 2, x[:, 3] / 2
    out = np.copy(x)
    out[:, 0] = cx - half_w
    out[:, 1] = cy - half_h
    out[:, 2] = cx + half_w
    out[:, 3] = cy + half_h
    return out
def filter_box(org_box, conf_thres, iou_thres):
    """Filter raw YOLOv5 predictions by confidence and apply per-class NMS.

    Args:
        org_box: raw model output. Its last axis is indexed as
            ``[cx, cy, w, h, objectness, class_score_0, class_score_1, ...]``
            (layout assumed from the indexing below; confirm against the export).
            Any leading (batch) axes are flattened together.
        conf_thres: rows whose column 4 is not above this value are dropped.
            Also passed to ``cv2.dnn.NMSBoxes`` as its score threshold.
        iou_thres: IoU threshold passed to ``cv2.dnn.NMSBoxes``.

    Returns:
        ``(K, 6)`` array of rows ``[x1, y1, x2, y2, score, class_id]``.
        ``score`` is column 4 of the raw row and ``class_id`` is the argmax
        over columns 5 onwards. Rows are grouped by ascending class id.
        ``K`` may be 0.
    """
    org_box = np.asarray(org_box)
    # Flatten leading axes instead of np.squeeze, so a single candidate row
    # stays 2-D rather than collapsing to 1-D.
    org_box = org_box.reshape(-1, org_box.shape[-1])
    # NOTE(review): only objectness is used as the score here. YOLOv5's own
    # postprocessing multiplies objectness by the class score; kept as-is
    # for backward compatibility.
    box = org_box[org_box[:, 4] > conf_thres]
    if len(box) == 0:
        return np.empty((0, 6), dtype=org_box.dtype)

    cls = np.argmax(box[:, 5:], axis=1)
    kept = []
    for curr_cls in np.unique(cls):
        # Boolean-mask indexing returns a copy, so `box` is not mutated.
        curr_cls_box = box[cls == curr_cls, :6]
        curr_cls_box[:, 5] = curr_cls  # overwrite column 5 with the class id
        curr_cls_box = xywh2xyxy(curr_cls_box)
        # cv2.dnn.NMSBoxes takes rects as (x, y, width, height), not corners.
        rects = curr_cls_box[:, :4].copy()
        rects[:, 2:] -= rects[:, :2]
        indices = cv2.dnn.NMSBoxes(rects.tolist(), curr_cls_box[:, 4].tolist(),
                                   conf_thres, iou_thres)
        # Older OpenCV returns an (N, 1) array, newer a flat sequence; may be empty.
        keep = np.asarray(indices, dtype=np.int64).reshape(-1)
        kept.append(curr_cls_box[keep])
    return np.concatenate(kept, axis=0)
def draw(image, box_data):
    """Draw detections onto *image* in place and print each detection's label.

    Args:
        image: image the box coordinates refer to (in this file, the resized
            frame returned by ``YOLOV5.inference``). It is modified in place.
        box_data: rows ``[x1, y1, x2, y2, score, class_id]``, as produced by
            ``filter_box``. A single 1-D row is also accepted.

    Class ids outside ``CLASSES`` are shown as their numeric id instead of
    raising IndexError. Only the console output uses the Chinese labels.
    """
    # English -> Chinese label mapping used for the console output.
    ENGLISH_TO_CHINESE = {
        'bird': '鸟',
        'elephant': '大象',
        'cat': '猫',
        'dog': '狗',
        'giraffe': '长颈鹿',
        'horse': '马',
        # Add further labels here as needed.
    }
    if len(box_data) == 0:
        print("没有检测到任何对象。")
        return
    box_data = np.asarray(box_data)
    box_data = box_data.reshape(-1, box_data.shape[-1])
    # .tolist() yields plain Python ints, which cv2 drawing calls accept.
    boxes = box_data[:, :4].astype(np.int32).tolist()
    scores = box_data[:, 4].tolist()
    classes = box_data[:, 5].astype(np.int32).tolist()
    for (x1, y1, x2, y2), score, cl in zip(boxes, scores, classes):
        # CLASSES is a user-filled placeholder; avoid IndexError for unlisted ids.
        english_label = CLASSES[cl] if 0 <= cl < len(CLASSES) else str(cl)
        # Chinese label if one is mapped, otherwise fall back to the English one.
        chinese_label = ENGLISH_TO_CHINESE.get(english_label, english_label)
        # The box and on-image text use the English label.
        cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
        cv2.putText(image, '{0} {1:.2f}'.format(english_label, score),
                    (x1, y1),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6, (0, 0, 255), 2)
        # The console output uses the Chinese label.
        print('类别: {}'.format(chinese_label))
def main(onnx_path='runs/train/5475/weights/best.onnx', camera_index=0,
         conf_thres=0.5, iou_thres=0.5):
    """Run live detection on a camera stream and show the annotated frames.

    Stops when the stream yields no frame or 'q' is pressed in the window.
    The camera is released and windows are closed even if an exception occurs.

    Args:
        onnx_path: path to the exported YOLOv5 ``.onnx`` model.
        camera_index: device index passed to ``cv2.VideoCapture``.
        conf_thres: confidence threshold forwarded to ``filter_box``.
        iou_thres: NMS IoU threshold forwarded to ``filter_box``.
    """
    model = YOLOV5(onnx_path)
    cap = cv2.VideoCapture(camera_index)
    try:
        if not cap.isOpened():
            print('Could not open camera {}'.format(camera_index))
            return
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            output, or_img = model.inference(frame)
            outbox = filter_box(output, conf_thres, iou_thres)
            draw(or_img, outbox)
            cv2.imshow('Video', or_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Start the webcam demo only when this file is executed directly,
    # not when it is imported as a module.
    main()
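# Yolov5部署onnx模型代码 (YOLOv5 ONNX model deployment code)
# 最新推荐文章于 2024-06-15 06:51:38 发布 (source page footer: "latest recommended article published 2024-06-15 06:51:38")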