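需要将 YOLOv5s 官方权重和模型结构打包在一起，便于使用；否则每次都要借助 yolov5 的源码先构建模型、再实时加载权重，不建议。注意：torch.save(model) 是用 pickle 保存整个模型对象，类定义按模块路径引用，加载时仍需能导入 yolov5 源码中的 models 包（下面的推理代码也依赖 yolov5 的 utils）；若要彻底摆脱对源码的依赖，可改用 TorchScript（torch.jit.trace）或导出 ONNX。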
模型打包以及查看模型计算量和参数量的代码如下（注：thop 的 profile 返回的是 MACs，即乘加次数，FLOPs 约为 MACs 的 2 倍）：
# Package the official YOLOv5s weights together with the model object into a
# single file, reload it as a sanity check, and report its compute cost and
# parameter count.
import torch
from models.common import DetectMultiBackend
from thop import profile, clever_format

# Official YOLOv5s checkpoint (adjust to your environment).
WEIGHTS = r"E:\Python_C++_Demo\yolov5-master\weights\yolov5s.pt"
# Output file holding the pickled DetectMultiBackend module.
SAVE_PATH = "yolov5s_save.pt"

# Dummy input of shape (1, 3, 640, 640). Zeros rather than torch.empty so the
# profiled forward pass never runs on uninitialised memory (possible NaN/Inf).
# Named dummy_input to avoid shadowing the builtin input().
dummy_input = torch.zeros((1, 3, 640, 640), dtype=torch.float, device="cpu")

# Build the model from the official weights and pickle the whole module.
# NOTE: torch.save(model) pickles classes by module path, so loading the file
# still requires the yolov5 `models` package to be importable.
model = DetectMultiBackend(weights=WEIGHTS)
torch.save(model, SAVE_PATH)

# Reload to verify the saved file. It holds a full pickled nn.Module rather than
# a state_dict, so weights_only=False is required (PyTorch >= 2.6 defaults it to
# True). Only load files you trust: unpickling can execute arbitrary code.
model = torch.load(SAVE_PATH, weights_only=False)
model.eval()

# thop counts multiply-accumulate operations (MACs), not FLOPs; one MAC is
# roughly two FLOPs, so FLOPs are reported as 2 * MACs.
macs, params = profile(model, inputs=(dummy_input,))
print("macs:", macs, "parameters:", params)

# Human-readable formatting (K/M/G suffixes).
macs_str, flops_str, params_str = clever_format([macs, 2 * macs, params], "%.3f")
print(f"Model MACs: {macs_str}")
print(f"Model FLOPs (~2x MACs): {flops_str}")
print(f"Model Params: {params_str}")
模型推理代码如下:
import cv2
import torch
from utils.general import non_max_suppression
from utils.augmentations import letterbox
import numpy as np
import yaml
# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Dataset config whose "names" entry maps a class index to its label.
COCO_YAML = r"E:\Python_C++_Demo\yolov5-master\data\coco.yaml"
with open(COCO_YAML, "r", encoding='utf-8') as cfg_file:
    classes = yaml.safe_load(cfg_file)["names"]
# Models already loaded by _load_model, keyed by weights path, so repeated
# detect() calls do not re-read the checkpoint from disk.
_MODEL_CACHE = {}


def _load_model(weights):
    """Return the pickled YOLOv5 model stored at *weights*, moved to `device` and in eval mode.

    The model is cached per path. Changing the file on disk after the first
    call will not be picked up.
    """
    model = _MODEL_CACHE.get(weights)
    if model is None:
        # The file holds a whole pickled nn.Module (not a state_dict), so
        # weights_only=False is required on PyTorch >= 2.6. Only load trusted files.
        model = torch.load(weights, map_location=device, weights_only=False)
        model.to(device)
        model.eval()
        _MODEL_CACHE[weights] = model
    return model


def _preprocess(img):
    """Letterbox a BGR image into a normalised float tensor of shape (1, 3, H, W).

    Returns (tensor, ratio, (dw, dh)) where ratio and padding come from letterbox().
    """
    im, ratio, pad = letterbox(img)
    im = im.transpose((2, 0, 1))[::-1]  # HWC -> CHW, BGR -> RGB
    tensor = torch.from_numpy(np.ascontiguousarray(im)).to(device).float() / 255
    return tensor.unsqueeze(0), ratio, pad


def _draw_detections(img, det, ratio, pad):
    """Draw one image's detections onto *img* in place, mapped back to original pixels.

    Each row of *det* is unpacked as [x1, y1, x2, y2, conf, cls] in letterboxed
    coordinates. NOTE(review): this layout is assumed from non_max_suppression's
    output; confirm it.
    """
    h, w = img.shape[:2]
    dw, dh = pad
    # Iterated in reverse. Presumably NMS output is sorted by descending score,
    # so the strongest boxes are drawn last and end up on top.
    for *xyxy, conf, cls_id in reversed(det.cpu().numpy()):
        # Undo letterbox: remove padding, rescale (ratio is (w, h)), then clip
        # so boxes that spill past the border stay inside the image.
        x1 = int(np.clip((xyxy[0] - dw) / ratio[0], 0, w - 1))
        y1 = int(np.clip((xyxy[1] - dh) / ratio[1], 0, h - 1))
        x2 = int(np.clip((xyxy[2] - dw) / ratio[0], 0, w - 1))
        y2 = int(np.clip((xyxy[3] - dh) / ratio[1], 0, h - 1))
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        label = f"{classes[int(cls_id)]}:{conf:.2f}"
        # Keep the label visible for boxes touching the top edge of the image.
        cv2.putText(img, label, (x1, max(y1 - 10, 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)


def detect(img_path, weights, show_img=True):
    """Run YOLOv5 detection on one image and draw the resulting boxes and labels.

    Args:
        img_path: Path of the image to read with cv2.imread.
        weights: Path of the packaged model file saved with torch.save(model).
        show_img: When True, display the annotated image and wait for a key press.

    Returns:
        The annotated BGR image (numpy array), so the result is usable with show_img=False.

    Raises:
        FileNotFoundError: If cv2.imread cannot read *img_path*.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"cannot read image: {img_path}")

    im, ratio, pad = _preprocess(img)
    model = _load_model(weights)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        pred = non_max_suppression(model(im))

    for det in pred:
        _draw_detections(img, det, ratio, pad)

    if show_img:
        cv2.imshow("img", img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    return img
if __name__ == '__main__':
    # Demo: annotate the sample bus image with the packaged model and show it.
    detect(
        r"E:\Python_C++_Demo\yolov5-master\data\images\bus.jpg",
        "yolov5s_save.pt",
        show_img=True,
    )
推理结果如下:
有更好的建议,欢迎指点哈~