YOLOv5 inference class

The class below loads a trained YOLOv5 model once and wraps single-image inference: predict() letterboxes the input, runs the forward pass and NMS, and returns the detected boxes scaled back to original-image coordinates.

import os

import cv2
import numpy as np
import torch
from numpy import random

from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import check_img_size, non_max_suppression, scale_coords
from utils.plots import plot_one_box
from utils.torch_utils import select_device


class RetailF(object):
    def __init__(self, imgsz=800, weights='', device='cpu', logger=None, vis_img=False):
        self.vis_img = vis_img
        self.logger = logger
        # Standard YOLOv5 thresholds: confidence cutoff 0.25, NMS IoU 0.45
        self.conf_thres = 0.25
        self.iou_thres = 0.45
        self.agnostic_nms = False  # class-agnostic NMS is a flag, not a threshold
        self.device = select_device(device)

        self.model = attempt_load(weights, map_location=self.device)  # load FP32 model
        self.model.eval()
        self.stride = int(self.model.stride.max())  # model stride
        self.imgsz = check_img_size(imgsz, s=self.stride)  # verify imgsz is a multiple of stride
        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA
        if self.half:
            self.model.half()
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]

    def predict(self, im0):
        """Run detection on a single BGR image (as read by cv2.imread) and return
        a list of [x1, y1, x2, y2] boxes in original-image pixel coordinates."""
        # Resize and pad the input to the model size (letterbox preserves aspect ratio)
        img = letterbox(im0, new_shape=self.imgsz)[0]

        # Convert: BGR -> RGB, HWC -> CHW
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img)

        
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        pred = self.model(img, augment=False)[0]

        # Apply NMS
        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                   classes=None, agnostic=self.agnostic_nms)

        # Process detections
        result_list = list()

        for det in pred:
            if len(det):
                # Rescale boxes from the letterboxed size back to the original image size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                for *xyxy, conf, cls in det:
                    # Keep only the box corners; class index and confidence are not returned
                    result_list.append([int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])])

                    # Optionally draw each detection on the original image
                    if self.vis_img:
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        plot_one_box(xyxy, im0, label=label, color=self.colors[int(cls)], line_thickness=3)

        return result_list

if __name__ == '__main__':
    retailf = RetailF(weights="./weights/0715pricetagbest.pt", vis_img=True)
    path = "./data/pritag1/"

    for img_name in os.listdir(path):
        img = cv2.imread(path + img_name)
        if img is None:  # skip files that OpenCV cannot decode
            continue
        result = retailf.predict(img)

        # Draw each detected box on the image and save the annotated result
        for item in result:
            cv2.rectangle(img, (item[0], item[1]), (item[2], item[3]), (0, 255, 0), 5)

        cv2.imwrite("./data/pricetagresult/" + img_name, img)
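Since predict() returns plain [x1, y1, x2, y2] lists, downstream steps only need basic NumPy slicing. As one illustration, here is a minimal sketch that crops each detected price tag out of the source image and saves it as its own file; the helper name crop_detections, the output directory and the file-name prefix are assumptions for the example, not part of the original code.

import os

import cv2


def crop_detections(img, boxes, out_dir="./data/pricetag_crops/", prefix="tag"):
    """Crop every [x1, y1, x2, y2] box returned by RetailF.predict() out of the
    original BGR image and save each crop as a separate JPEG."""
    os.makedirs(out_dir, exist_ok=True)
    saved = []
    for idx, (x1, y1, x2, y2) in enumerate(boxes):
        crop = img[y1:y2, x1:x2]  # NumPy slicing: rows are y, columns are x
        if crop.size == 0:        # skip degenerate or empty boxes
            continue
        out_path = os.path.join(out_dir, f"{prefix}_{idx}.jpg")
        cv2.imwrite(out_path, crop)
        saved.append(out_path)
    return saved

Called as crop_detections(img, result) right after retailf.predict(img), this would produce one image file per detected tag, which is convenient when a second model is meant to run on the crops.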

