# -*- coding: utf-8 -*-
# @Time : 2021/6/9 10:03
# @Author : Johnson
#设置工作路径
import matplotlib
matplotlib.use("Agg")
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import paddlex as pdx
# Notebook leftover: lists the work directory and discards the result.
# NOTE(review): this path differs from the dataset root used below
# ('/home/aistudio/work/...') - confirm it is still wanted.
os.listdir('/home/work/')

# Generate the TXT file lists for the dataset.
# PaddleX supports VOC-format data. The training and validation sets are each
# described by a txt file whose lines hold an image path and its annotation
# path, for example:
#   JPEGImages/2009_003143.jpg Annotations/2009_003143.xml
#   JPEGImages/2012_001604.jpg Annotations/2012_001604.xml
from random import shuffle, seed


def _write_voc_list(list_path, image_names):
    """Write one 'JPEGImages/<image> Annotations/<stem>.xml' line per image."""
    with open(list_path, 'w') as f:
        for name in image_names:
            # splitext copes with any extension length, unlike slicing [:-4].
            stem = os.path.splitext(name)[0]
            f.write('JPEGImages/{} Annotations/{}.xml\n'.format(name, stem))


base = '/home/aistudio/work/pascalvoc/VOCdevkit/VOC2012/'
# os.listdir returns entries in arbitrary order; sorting first makes the
# seeded shuffle below produce the same split on every machine.
imgs = sorted(os.listdir(os.path.join(base, 'JPEGImages')))
print('total:', len(imgs))
seed(666)
shuffle(imgs)

train_imgs = imgs[:5000]
val_imgs = imgs[-1000:]
if len(imgs) < 6000:
    # The two slices share images when the folder holds fewer than 6000 files.
    print('warning: fewer than 6000 images, train and val lists overlap')

_write_voc_list(os.path.join(base, 'train_list.txt'), train_imgs)
_write_voc_list(os.path.join(base, 'val_list.txt'), val_imgs)

CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
           'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
           'train', 'tvmonitor']
# One class name per line; written to the current working directory.
with open('labels.txt', 'w') as f:
    f.writelines(v + '\n' for v in CLASSES)
# Data preprocessing pipelines.
# Training augments with image mixup, random pixel distortion, random
# expansion, random cropping and random horizontal flipping.
from paddlex.det import transforms

_train_ops = [
    transforms.MixupImage(mixup_epoch=250),
    transforms.RandomDistort(),
    transforms.RandomExpand(),
    transforms.RandomCrop(),
    transforms.Resize(target_size=512, interp='RANDOM'),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(),
]
train_transforms = transforms.Compose(_train_ops)

# Evaluation only resizes and normalizes.
_eval_ops = [
    transforms.Resize(target_size=512, interp='CUBIC'),
    transforms.Normalize(),
]
eval_transforms = transforms.Compose(_eval_ops)
# Define the training and validation datasets.
base = '/home/aistudio/work/pascalvoc/VOCdevkit/VOC2012/'
train_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list=os.path.join(base, 'train_list.txt'),
    label_list='labels.txt',
    # Pass the composed training pipeline, not the imported `transforms` module.
    transforms=train_transforms,
    shuffle=True,
)
eval_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list=os.path.join(base, 'val_list.txt'),
    label_list='labels.txt',
    transforms=eval_transforms,
)
# Define and train the model.
# PaddleX's YOLOv3 has no background class, so the class count is exactly the
# number of labels; the extra "+ 1" is only needed by the R-CNN style models.
num_classes = len(train_dataset.labels)
print('class num:', num_classes)
model = pdx.det.YOLOv3(
    num_classes=num_classes,
    backbone='MobileNetV3_large',
)
model.train(
    num_epochs=60,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.00025,
    lr_decay_epochs=[20, 40],
    save_interval_epochs=4,
    log_interval_steps=100,
    save_dir='./YOLOv3',
    use_vdl=True,
)
# Evaluate the model stored under ./YOLOv3/best_model.
model = pdx.load_model('./YOLOv3/best_model')
# Keep and print the metrics: outside a notebook a bare expression shows nothing.
eval_metrics = model.evaluate(
    eval_dataset, batch_size=1, epoch_id=None, metric=None,
    return_details=False)
print('eval metrics:', eval_metrics)
### Test the model's detections on a single image
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib inline

image_name = './test.jpg'
start = time.time()
result = model.predict(image_name, eval_transforms)
print('infer time:{:.6f}s'.format(time.time()-start))
print('detected num:', len(result))

im = cv2.imread(image_name)
font = cv2.FONT_HERSHEY_SIMPLEX
threshold = 0.01
for value in result:
    # Each bbox is unpacked as [xmin, ymin, width, height]. Plain int() is
    # used because the np.int alias no longer exists in NumPy >= 1.24.
    xmin, ymin, w, h = (int(v) for v in value['bbox'])
    cls = value['category']
    score = value['score']
    if score < threshold:
        continue
    cv2.rectangle(im, (xmin, ymin), (xmin+w, ymin+h), (0, 255, 0), 4)
    cv2.putText(im, '{:s} {:.3f}'.format(cls, score),
                (xmin, ymin), font, 0.5, (255, 0, 0), thickness=2)

# The annotated image on disk is the persistent output of this step.
cv2.imwrite('result.jpg', im)
plt.figure(figsize=(15,12))
# Reverse the channel order (BGR from cv2.imread -> RGB) for matplotlib.
plt.imshow(im[:, :, [2,1,0]])
# NOTE(review): the file selects the non-GUI "Agg" backend at the top, so
# this call will not open a window when run as a script.
plt.show()
#添加目标追踪
# pip install dlib
import dlib
import cv2
def plot_bboxes(image, bboxes, line_thickness=None):
    """Draw every labelled box onto ``image`` in place and return it.

    Args:
        image: array exposing ``shape``; it is drawn on with OpenCV.
        bboxes: iterable of ``(x1, y1, x2, y2, cls_id, pos_id)`` tuples, where
            (x1, y1) and (x2, y2) are opposite corners of the box.
        line_thickness: box line thickness; when falsy it is derived from the
            image height and width.

    Returns:
        The same ``image`` object that was passed in.
    """
    tl = line_thickness or round(
        0.002 * (image.shape[0] + image.shape[1]) / 2) + 1  # line/font thickness
    tf = max(tl - 1, 1)  # font thickness
    for (x1, y1, x2, y2, cls_id, pos_id) in bboxes:
        # These three classes get (0, 0, 255), everything else (0, 255, 0).
        if cls_id in ['smoke', 'phone', 'eat']:
            color = (0, 0, 255)
        else:
            color = (0, 255, 0)
        if cls_id == 'eat':
            cls_id = 'eat-drink'
        label = '{} ID-{}'.format(cls_id, pos_id)
        c1, c2 = (x1, y1), (x2, y2)
        cv2.rectangle(image, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
        # Measure the text that is actually drawn so the filled background
        # covers the whole label, including the " ID-n" suffix.
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(image, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(image, label, (c1[0], c1[1] - 2), 0, tl / 3,
                    [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
    return image
def _centers_overlap(box, other):
    """Return True when each (x, y, w, h) box contains the other's centre."""
    x, y, w, h = box
    o_x, o_y, o_w, o_h = other
    c_x, c_y = x + 0.5 * w, y + 0.5 * h
    o_cx, o_cy = o_x + 0.5 * o_w, o_y + 0.5 * o_h
    return (o_x <= c_x <= (o_x + o_w) and o_y <= c_y <= (o_y + o_h)
            and x <= o_cx <= (x + w) and y <= o_cy <= (y + h))


def update_tracker(target_detector, image):
    """Advance all trackers by one frame, optionally re-detect, and draw.

    ``target_detector`` must provide the attributes ``frameCounter``,
    ``stride``, ``currentCarID`` and the dicts ``faceTracker``,
    ``faceClasses``, ``faceLocation1`` and ``faceLocation2`` (all keyed by
    track id), plus a ``detect(image)`` method returning
    ``(image, [(x1, y1, x2, y2, cls_id, score), ...])``.

    This function only reads ``frameCounter`` (and wraps it back to 0 once it
    exceeds 2e4); it never increments it, so the caller has to advance it per
    frame for ``stride`` to skip any detections.

    Args:
        target_detector: state holder described above; mutated in place.
        image: frame passed to the trackers and the detector; boxes are drawn
            onto it by ``plot_bboxes``.

    Returns:
        ``(image, bboxes2draw)`` where ``bboxes2draw`` is a list of
        ``(x1, y1, x2, y2, cls_id, track_id)`` tuples, one per live tracker.
    """
    if target_detector.frameCounter > 2e+4:
        target_detector.frameCounter = 0

    # Update every tracker and collect those whose quality score fell below 8.
    lost_ids = []
    for track_id, tracker in target_detector.faceTracker.items():
        if tracker.update(image) < 8:
            lost_ids.append(track_id)

    for track_id in lost_ids:
        target_detector.faceTracker.pop(track_id, None)
        target_detector.faceLocation1.pop(track_id, None)
        target_detector.faceLocation2.pop(track_id, None)
        target_detector.faceClasses.pop(track_id, None)

    # Run the detector only on frames whose counter is a multiple of stride.
    if not (target_detector.frameCounter % target_detector.stride):
        _, bboxes = target_detector.detect(image)
        for (x1, y1, x2, y2, cls_id, _) in bboxes:
            x = int(x1)
            y = int(y1)
            w = int(x2 - x1)
            h = int(y2 - y1)

            matched_id = None
            for track_id, tracker in target_detector.faceTracker.items():
                pos = tracker.get_position()
                tracked_box = (int(pos.left()), int(pos.top()),
                               int(pos.width()), int(pos.height()))
                if _centers_overlap((x, y, w, h), tracked_box):
                    # No early exit: the last matching tracker wins.
                    matched_id = track_id

            if matched_id is None:
                # Newly appeared target: start a fresh correlation tracker.
                tracker = dlib.correlation_tracker()
                tracker.start_track(
                    image, dlib.rectangle(x, y, x + w, y + h))
                matched_id = target_detector.currentCarID
                target_detector.faceTracker[matched_id] = tracker
                target_detector.faceLocation1[matched_id] = [x, y, w, h]
                target_detector.currentCarID = matched_id + 1

            # Matched tracks take the class of the latest detection.
            target_detector.faceClasses[matched_id] = cls_id

    # Publish the current position of every live tracker and draw it.
    bboxes2draw = []
    for track_id, tracker in target_detector.faceTracker.items():
        pos = tracker.get_position()
        t_x = int(pos.left())
        t_y = int(pos.top())
        t_w = int(pos.width())
        t_h = int(pos.height())
        cls_id = target_detector.faceClasses[track_id]
        target_detector.faceLocation2[track_id] = [t_x, t_y, t_w, t_h]
        bboxes2draw.append(
            (t_x, t_y, t_x + t_w, t_y + t_h, cls_id, track_id)
        )

    image = plot_bboxes(image, bboxes2draw)
    return image, bboxes2draw
from os import walk
import cv2
import paddlex as pdx
class baseDet(object):
    """PaddleX detector bundled with the tracking state used by update_tracker."""

    def __init__(self):
        """Load the model from ./YOLOv3/best_model and reset tracking state."""
        self.img_size = 640  # image size (not read anywhere in this class)
        self.threshold = 0.01  # detections must score above this to be kept
        self.stride = 2  # detection stride (frame sampling)
        self.model = pdx.load_model('./YOLOv3/best_model')
        self.build_config()

    def build_config(self):
        """Initialise (or reset) every variable needed for tracking."""
        self.faceTracker = {}  # track id -> tracker object
        self.faceClasses = {}  # track id -> class name
        self.faceLocation1 = {}  # track id -> [x, y, w, h] at track creation
        self.faceLocation2 = {}  # track id -> latest [x, y, w, h]
        self.frameCounter = 0
        self.currentCarID = 0  # next track id to hand out
        self.walk_dict = {}  # not used in the visible code
        self.recorded = []  # not used in the visible code
        self.font = cv2.FONT_HERSHEY_SIMPLEX

    def feedCap(self, im):
        """Process one frame and return ``(annotated_image, bboxes)``."""
        im, bboxes = update_tracker(self, im)
        # Advance the frame counter so `stride` really skips detections;
        # counting after the update lets the very first frame run detection.
        self.frameCounter += 1
        return im, bboxes  # detection/tracking result

    def detect(self, im):
        """Run the model on ``im`` and keep boxes scoring above the threshold.

        Returns:
            ``(im, pred_boxes)`` where each box is
            ``(x1, y1, x2, y2, category, score)`` with integer corners.
        """
        result = self.model.predict(im)
        pred_boxes = []
        for value in result:
            # bbox is unpacked as [x, y, width, height]; int() replaces the
            # np.int alias that no longer exists in NumPy >= 1.24.
            x1, y1, w, h = (int(v) for v in value['bbox'])
            cls = value['category']
            score = value['score']
            if score > self.threshold:
                pred_boxes.append(
                    (x1, y1, x1+w, y1+h, cls, score)
                )
        return im, pred_boxes
# Run the detector/tracker once on the test image and show before/after.
DET = baseDet()
import matplotlib.pyplot as plt
im = cv2.imread("./test.jpg")
# Reverse the channel order (BGR from cv2.imread -> RGB) for matplotlib.
plt.imshow(im[:, :, [2,1,0]])
plt.show()
import numpy as np
res_im,bboxes = DET.feedCap(im)
plt.imshow(res_im[:, :, [2,1,0]])
plt.show()
# Print each track id with its latest [x, y, w, h] location.
for k, v in DET.faceLocation2.items():
    print(k, v)
import os
from tqdm import tqdm
class VideoCapture(object):
    """Sequential reader over a directory of numbered ``NNNNNN.jpg`` frames."""

    def __init__(self, img_path, base='../MOT20/images/test/{}/img1'):
        """Locate the frame directory for one sequence.

        Args:
            img_path: sequence name; substituted into ``base`` and also kept
                as ``self.name``.
            base: format string with one ``{}`` placeholder giving the frame
                directory for a sequence name.
        """
        self.name = img_path
        self.base = base
        self.img_path = self.base.format(img_path)
        # Count only JPEG frames so stray files do not inflate the total.
        self.num = sum(
            1 for entry in os.listdir(self.img_path)
            if entry.lower().endswith('.jpg'))
        self.count = 0

    def read(self):
        """Read the next frame; frames are numbered from 000001.jpg.

        Returns:
            ``(success, image)`` where ``success`` is False when
            ``cv2.imread`` returned None.
        """
        self.count += 1
        img = os.path.join(self.img_path, '{:06}.jpg'.format(self.count))
        image = cv2.imread(img)
        return image is not None, image
# Track a whole sequence and export the result as a MOTChallenge text file.
cap = VideoCapture('MOT20-04')
font = cv2.FONT_HERSHEY_SIMPLEX

# Start from clean tracking state so that tracks and ids created by earlier
# calls on other images do not leak into this sequence.
DET.build_config()

# One row per tracked box per frame, in this column order:
# <frame>,<id>,<bb_left>,<bb_top>,<bb_width>,<bb_height>,<conf>,<x>,<y>,<z>
mot_rows = []
for fid in tqdm(range(cap.num)):
    success, frame = cap.read()
    if not success:
        break
    res_im, bboxes = DET.feedCap(frame)
    for id_, output in DET.faceLocation2.items():
        x1, y1 = output[0], output[1]
        w, h = output[2], output[3]
        conf_ = 1.0
        # MOTChallenge frame numbers are 1-based, matching the 000001.jpg
        # naming of the frames being read.
        # NOTE(review): MOTChallenge target ids are 1-based as well, while
        # track ids here start at 0 - confirm whether the evaluator needs
        # an offset.
        mot_rows.append([fid + 1, id_, x1, y1, w,
                         h, conf_, -1, -1, -1])

# Rows are accumulated across all frames and written once at the end.
with open(cap.name + '.txt', 'w') as f:
    for box in mot_rows:
        f.write(','.join(str(v) for v in box) + '\n')
# paddlex object-detection demo
# (scraped page footer) Latest recommended article published on 2024-07-12 17:17:59