前言:本人最近在学着使用yolo,为了方便调用,今天看了下detect.py源码并做了点修改,不过由于训练的模型是单类型的,所以在推理在结果中还没看出哪个数值代表“类型”,所以以下内容只针对单类型的模型,后面琢磨出来了会更新。希望有大佬能指导一下。
在detect.py中添加如下代码块
class yolo_detector:
def __init__(self,
weights='./Weights/last.pt', # 用train.py训练出的.pt文件
imgsz=(640,640),
conf_thres=0.25,
iou_thres=0.45,
half=False,
):
self.conf_thres = conf_thres
self.iou_thres = iou_thres
self.device = select_device('0')
self.model = DetectMultiBackend(weights, device=self.device) # 加载模型
stride, names, pt = self.model.stride, self.model.names, self.model.pt
self.imgsz = check_img_size(imgsz, s=stride) # check image size
half &= pt and self.device.type != 'cpu