import os.path
import cv2
import requests
import numpy as np
from ultralytics.utils import yaml_load
from ultralytics.utils.checks import check_yaml
class ImageDetect:
    """Object detection on a single image using an ONNX model run through OpenCV DNN.

    The network is loaded lazily on the first call to ``parse_image`` and then
    reused for later calls on the same instance.
    """

    # Side length, in pixels, of the square input fed to the network.
    INPUT_SIZE = 640
    # Minimum best-class score for a prediction to be kept (also passed to NMS).
    SCORE_THRESHOLD = 0.25
    # Overlap threshold and eta coefficient passed to cv2.dnn.NMSBoxes.
    NMS_THRESHOLD = 0.45
    NMS_ETA = 0.5
    # Seconds to wait for a remote image before giving up.
    REQUEST_TIMEOUT = 10

    def __init__(self):
        """Resolve the model path and class names; the model itself is not loaded here."""
        self.MODEL_BASE_DIR = os.path.dirname(__file__)
        self.CLASSES = yaml_load(check_yaml("coco128.yaml"))["names"]
        # One random colour per class, used only when drawing boxes.
        self.colors = np.random.uniform(0, 255, size=(len(self.CLASSES), 3))
        self.model_path = os.path.join(self.MODEL_BASE_DIR, "model/yolov8n.onnx")
        # Populated by load_model(); holds the cv2.dnn network once loaded.
        self.onnx_model = None

    def load_model(self):
        """Load the ONNX model at ``self.model_path`` into ``self.onnx_model``."""
        self.onnx_model = cv2.dnn.readNetFromONNX(self.model_path)

    def draw_bounding_box(self, img, class_id, confidence, x, y, x_plus_w, y_plus_h):
        """
        Draw one labelled bounding box on ``img`` in place.

        Args:
            img (numpy.ndarray): The image to draw on.
            class_id (int): Class ID of the detected object.
            confidence (float): Confidence score of the detected object.
            x (int): X-coordinate of the top-left corner of the box.
            y (int): Y-coordinate of the top-left corner of the box.
            x_plus_w (int): X-coordinate of the bottom-right corner of the box.
            y_plus_h (int): Y-coordinate of the bottom-right corner of the box.
        """
        label = f"{self.CLASSES[class_id]} ({confidence:.2f})"
        color = self.colors[class_id]
        cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
        cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    def _read_image(self, image: str) -> np.ndarray:
        """Return the image at a local path or http(s) URL as a decoded array."""
        # Match the full scheme so a local file such as "http_cat.jpg" is not
        # mistaken for a URL.
        if image.startswith(("http://", "https://")):
            # A timeout keeps a stalled server from blocking the caller forever,
            # and the status check stops an error page being decoded as an image.
            response = requests.get(image, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            buffer = np.frombuffer(response.content, dtype=np.uint8)
            decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        else:
            decoded = cv2.imread(image)
        # OpenCV signals a missing or undecodable image by returning None.
        if decoded is None:
            raise ValueError(f"Could not read an image from {image!r}")
        return decoded

    @staticmethod
    def _pad_to_square(original_image: np.ndarray, input_size: int) -> tuple:
        """Place the image in the top-left of a black square; return (square, scale)."""
        height, width, _ = original_image.shape
        length = max(height, width)
        square = np.zeros((length, length, 3), np.uint8)
        square[0:height, 0:width] = original_image
        # Multiplying a network-space coordinate by this factor gives a
        # coordinate in the padded (and therefore the original) image.
        return square, length / input_size

    @staticmethod
    def _extract_candidates(outputs, score_threshold: float) -> tuple:
        """Return (boxes, scores, class_ids) for predictions at or above the threshold."""
        # After the transpose each row is one prediction, read as
        # [centre_x, centre_y, width, height, score_class_0, score_class_1, ...].
        predictions = np.asarray(outputs)[0].T
        class_scores = predictions[:, 4:]
        best_class = class_scores.argmax(axis=1)
        best_score = class_scores[np.arange(len(predictions)), best_class]
        selected = best_score >= score_threshold
        kept = predictions[selected]
        sizes = kept[:, 2:4]
        # Convert centre-based boxes to top-left based [x, y, width, height].
        top_left = kept[:, 0:2] - 0.5 * sizes
        boxes = np.hstack((top_left, sizes)).tolist()
        return boxes, best_score[selected].tolist(), best_class[selected].tolist()

    def parse_image(self, image: str, show: bool = False):
        """
        Run the detector on one image and return the detections that survive NMS.

        Args:
            image (str): Local file path or http(s) URL of the image.
            show (bool): If True, draw the detections on the image and display it
                in a window until a key is pressed.

        Returns:
            list: One dict per detection with the keys ``class_id``,
                ``class_name``, ``confidence``, ``box`` (``[x, y, width, height]``
                in network input coordinates) and ``scale`` (multiply box values
                by it to get coordinates in the original image).

        Raises:
            ValueError: If ``image`` is empty or no image can be read from it.
        """
        if not image:
            raise ValueError("image must be a non-empty path or URL")
        if self.onnx_model is None:
            self.load_model()

        original_image = self._read_image(image)
        square_image, scale = self._pad_to_square(original_image, self.INPUT_SIZE)

        # Preprocess the image and run inference.
        blob = cv2.dnn.blobFromImage(
            square_image,
            scalefactor=1 / 255,
            size=(self.INPUT_SIZE, self.INPUT_SIZE),
            swapRB=True,
        )
        self.onnx_model.setInput(blob)
        outputs = self.onnx_model.forward()

        boxes, scores, class_ids = self._extract_candidates(outputs, self.SCORE_THRESHOLD)

        # Apply non-maximum suppression; nothing to suppress without candidates.
        if boxes:
            kept = cv2.dnn.NMSBoxes(
                boxes, scores, self.SCORE_THRESHOLD, self.NMS_THRESHOLD, self.NMS_ETA
            )
        else:
            kept = ()

        detections = []
        # Flatten so both the 1-D and the N x 1 index layouts are handled.
        for index in np.asarray(kept, dtype=int).reshape(-1).tolist():
            box = boxes[index]
            class_id = class_ids[index]
            confidence = scores[index]
            detections.append({
                "class_id": class_id,
                "class_name": self.CLASSES[class_id],
                "confidence": confidence,
                "box": box,
                "scale": scale})
            if show:
                self.draw_bounding_box(
                    original_image, class_id, confidence,
                    round(box[0] * scale),
                    round(box[1] * scale),
                    round((box[0] + box[2]) * scale),
                    round((box[1] + box[3]) * scale),
                )

        if show:
            # Display the image with bounding boxes until a key is pressed.
            cv2.imshow('image', original_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return detections
# Shared module-level detector. Building it reads the class-name list; the
# ONNX model is only loaded on the first parse_image() call.
detect = ImageDetect()

if __name__ == '__main__':
    # Smoke run against the bundled sample image; the result is discarded.
    detect.parse_image(image="image/bus.jpg")
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.
- 98.
- 99.
- 100.
- 101.
- 102.
- 103.
- 104.
- 105.
- 106.
- 107.
- 108.
- 109.
- 110.
- 111.
- 112.
- 113.
- 114.
- 115.
- 116.
- 117.
- 118.
- 119.
- 120.
- 121.
- 122.
- 123.
- 124.
- 125.
- 126.
- 127.
- 128.
- 129.
- 130.
- 131.
- 132.
- 133.
- 134.
- 135.
- 136.
- 137.
- 138.
- 139.
- 140.
- 141.
该方案占用的服务器资源少,但默认数据集的效果不太好,建议针对自己的场景进行训练。
如果可以接受更大的资源占用,推荐一个效果比较好的项目:https://github.com/xinyu1205/recognize-anything
作者:一石数字欠我15w!!!