import argparse
import time, os
from pathlib import Path
import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random
from models.experimental import attempt_load
from utils.datasets import *
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized
class RetailF(object):
    """Thin wrapper around a YOLOv5 checkpoint that returns detection boxes.

    The model is loaded once in the constructor; ``predict`` then runs a
    single image through letterboxing, the network and NMS, and returns the
    surviving boxes in the coordinate frame of the input image.
    """

    def __init__(self, imgsz=800, weights='', device='CPU', logger=None, vis_img=False):
        """Load the checkpoint and prepare it for inference.

        Args:
            imgsz: Requested inference size; adjusted by ``check_img_size``
                using the model's maximum stride.
            weights: Path to the checkpoint handed to ``attempt_load``.
            device: Device string forwarded to ``select_device`` (for
                example ``'cpu'`` or ``'0'``). Case-insensitive; an empty or
                ``None`` value falls back to CPU.
            logger: Optional logger, stored on the instance.
            vis_img: Visualisation flag, stored on the instance.
        """
        self.vis_img = vis_img
        self.logger = logger

        # NOTE(review): the stock YOLOv5 defaults are conf=0.25 / iou=0.45;
        # these two values look transposed -- confirm they are intentional.
        self.iou_thres = 0.25
        self.conf_thres = 0.45
        # Passed as the boolean ``agnostic`` flag of non_max_suppression.
        # The previous value (0.5) was merely truthy, so True is equivalent.
        self.agnostic_nms = True
        self.classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

        # Honour the caller's device choice instead of always forcing CPU.
        device_name = str(device).lower() if device else 'cpu'
        self.device = select_device(device_name)

        self.model = attempt_load(weights, map_location=self.device)
        self.model.eval()
        self.stride = int(self.model.stride.max())
        self.imgsz = check_img_size(imgsz, s=self.stride)

        # Half precision is only enabled when not running on CPU.
        self.half = self.device.type != 'cpu'
        if self.half:
            self.model.half()

        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        # One random colour triple per class name.
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]

    def _preprocess(self, im0):
        """Turn an H x W x 3 image array into a normalised 1 x 3 x H x W tensor."""
        img = letterbox(im0, new_shape=self.imgsz)[0]
        # Reverse the channel axis (BGR to RGB) and move channels first.
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)  # add the batch dimension
        return img

    def predict(self, im0):
        """Detect objects in one image.

        Args:
            im0: Image array indexed as (row, column, channel) with the
                channels in BGR order, e.g. the result of ``cv2.imread``.

        Returns:
            A list of ``[x1, y1, x2, y2]`` integer boxes expressed in the
            coordinates of ``im0``. The list is empty when nothing passes
            the confidence / NMS thresholds.
        """
        img = self._preprocess(im0)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            pred = self.model(img, augment=False)[0]
            pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                       classes=None, agnostic=self.agnostic_nms)

        result_list = []
        for det in pred:
            if not len(det):
                continue
            # Rescale boxes from the letterboxed size back to the im0 size.
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
            result_list.extend(
                [int(x1), int(y1), int(x2), int(y2)] for x1, y1, x2, y2, *_ in det
            )
        return result_list
if __name__ == '__main__':
    # Demo: run the detector over every image in a folder and save a copy of
    # each image with the detected boxes drawn on it.
    retailf = RetailF(weights="./weights/0715pricetagbest.pt", vis_img=True)
    path = "./data/pritag1/"
    out_dir = "./data/pricetagresult/"

    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(out_dir, exist_ok=True)

    for img_name in os.listdir(path):
        img = cv2.imread(os.path.join(path, img_name))
        if img is None:
            # Not a readable image (sub-directory, text file, corrupt data).
            print(f"skip unreadable file: {img_name}")
            continue

        result = retailf.predict(img)
        for x1, y1, x2, y2 in result:
            cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 5)
        cv2.imwrite(os.path.join(out_dir, img_name), img)
# YOLOv5 inference class
# First published 2021-07-15 22:17:32