PaddleX Object Detection Demo

# -*- coding: utf-8 -*-
# @Time    : 2021/6/9 10:03
# @Author  : Johnson
# Environment setup: headless matplotlib backend and GPU selection
import matplotlib
matplotlib.use("Agg")
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import paddlex as pdx

print(os.listdir('/home/work/'))  # inspect the working directory

#Generate the dataset list files
'''
PaddleX accepts VOC-format data. The training and validation sets are each
described by a txt file in which every line pairs an image path with its
annotation path, separated by a space, e.g.:
JPEGImages/2009_003143.jpg Annotations/2009_003143.xml
JPEGImages/2012_001604.jpg Annotations/2012_001604.xml
'''
from random import shuffle, seed

base = '/home/aistudio/work/pascalvoc/VOCdevkit/VOC2012/'

imgs = os.listdir(os.path.join(base, 'JPEGImages'))
print('total:', len(imgs))

seed(666)
shuffle(imgs)

with open(os.path.join(base, 'train_list.txt'), 'w') as f:
    for im in imgs[:5000]:
        info = 'JPEGImages/'+im+' '
        info += 'Annotations/'+im[:-4]+'.xml\n'
        f.write(info)

with open(os.path.join(base, 'val_list.txt'), 'w') as f:
    for im in imgs[-1000:]:
        info = 'JPEGImages/'+im+' '
        info += 'Annotations/'+im[:-4]+'.xml\n'
        f.write(info)
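
A quick optional sanity check (not in the original workflow): confirm that every annotation file listed in train_list.txt actually exists before training starts.

missing = 0
with open(os.path.join(base, 'train_list.txt')) as f:
    for line in f:
        img_rel, ann_rel = line.split()
        if not os.path.exists(os.path.join(base, ann_rel)):
            missing += 1
print('missing annotation files:', missing)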

CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
           'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
           'train', 'tvmonitor']

with open('labels.txt', 'w') as f:
    for v in CLASSES:
        f.write(v+'\n')

# Define the data preprocessing pipelines
# Augmentations used for training: image mixup, random pixel distortion,
# random expansion, random cropping, and random horizontal flipping

from paddlex.det import transforms
train_transforms = transforms.Compose([
    transforms.MixupImage(mixup_epoch=250),
    transforms.RandomDistort(),
    transforms.RandomExpand(),
    transforms.RandomCrop(),
    transforms.Resize(target_size=512, interp='RANDOM'),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(),
])

eval_transforms = transforms.Compose([
    transforms.Resize(target_size=512, interp='CUBIC'),
    transforms.Normalize(),
])
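
Note that the evaluation pipeline only resizes (with deterministic cubic interpolation) and normalizes; the random augmentations are applied at training time only.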

# Define the training and validation datasets
base = '/home/aistudio/work/pascalvoc/VOCdevkit/VOC2012/'

train_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list=os.path.join(base, 'train_list.txt'),
    label_list='labels.txt',
    transforms=train_transforms,
    shuffle=True
)

eval_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list=os.path.join(base, 'val_list.txt'),
    label_list='labels.txt',
    transforms=eval_transforms
)

# Define and train the model
num_classes = len(train_dataset.labels)  # PaddleX's YOLOv3 counts no background class
print('class num:', num_classes)
model = pdx.det.YOLOv3(
    num_classes=num_classes,
    backbone='MobileNetV3_large'
)
model.train(
    num_epochs=60,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.00025,
    lr_decay_epochs=[20, 40],
    save_interval_epochs=4,
    log_interval_steps=100,
    save_dir='./YOLOv3',
    use_vdl=True)
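
With use_vdl=True, PaddleX writes VisualDL logs under save_dir while training; assuming the default vdl_log subdirectory (it can vary across PaddleX versions), progress can be watched live with: visualdl --logdir ./YOLOv3/vdl_log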

# Evaluate the model: load the best checkpoint saved during training
model = pdx.load_model('./YOLOv3/best_model')
eval_metrics = model.evaluate(eval_dataset, batch_size=1)
print(eval_metrics)


### Test the model's detection results
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib inline

image_name = './test.jpg'
start = time.time()
result = model.predict(image_name, eval_transforms)
print('infer time:{:.6f}s'.format(time.time()-start))
print('detected num:', len(result))

im = cv2.imread(image_name)
font = cv2.FONT_HERSHEY_SIMPLEX
threshold = 0.01  # score threshold for drawing detections

for value in result:
    xmin, ymin, w, h = np.array(value['bbox']).astype(int)  # PaddleX bbox format is [x, y, w, h]
    cls = value['category']
    score = value['score']
    if score < threshold:
        continue
    cv2.rectangle(im, (xmin, ymin), (xmin+w, ymin+h), (0, 255, 0), 4)
    cv2.putText(im, '{:s} {:.3f}'.format(cls, score),
                    (xmin, ymin), font, 0.5, (255, 0, 0), thickness=2)

cv2.imwrite('result.jpg', im)
plt.figure(figsize=(15,12))
plt.imshow(im[:, :, [2,1,0]])
plt.show()
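
PaddleX also ships a built-in renderer that can replace the manual OpenCV drawing above. A minimal sketch, assuming the PaddleX 1.x pdx.det.visualize API:

pdx.det.visualize(image_name, result, threshold=0.3, save_dir='./')  # writes the annotated image to save_dir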


# Add object tracking
# pip install dlib

import dlib
import cv2


def plot_bboxes(image, bboxes, line_thickness=None):
    # Plot the tracked bounding boxes (class label + track ID) on the image
    tl = line_thickness or round(
        0.002 * (image.shape[0] + image.shape[1]) / 2) + 1  # line/font thickness
    for (x1, y1, x2, y2, cls_id, pos_id) in bboxes:
        # draw these specific classes in red, everything else in green
        if cls_id in ['smoke', 'phone', 'eat']:
            color = (0, 0, 255)
        else:
            color = (0, 255, 0)
        if cls_id == 'eat':
            cls_id = 'eat-drink'
        c1, c2 = (x1, y1), (x2, y2)
        cv2.rectangle(image, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(cls_id, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(image, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(image, '{} ID-{}'.format(cls_id, pos_id), (c1[0], c1[1] - 2), 0, tl / 3,
                    [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)

    return image


def update_tracker(target_detector, image):

    raw = image.copy()
    target_detector.frameCounter += 1  # advance the counter that drives the detection stride

    if target_detector.frameCounter > 2e+4:
        target_detector.frameCounter = 0

    faceIDtoDelete = []

    # update every live tracker on the new frame; flag low-quality tracks for removal
    for faceID in target_detector.faceTracker.keys():
        trackingQuality = target_detector.faceTracker[faceID].update(image)

        if trackingQuality < 8:
            faceIDtoDelete.append(faceID)

    for faceID in faceIDtoDelete:
        target_detector.faceTracker.pop(faceID, None)
        target_detector.faceLocation1.pop(faceID, None)
        target_detector.faceLocation2.pop(faceID, None)
        target_detector.faceClasses.pop(faceID, None)

    new_faces = []

    # run the detector only on every `stride`-th frame; trackers cover the rest
    if not (target_detector.frameCounter % target_detector.stride):

        _, bboxes = target_detector.detect(image)

        for (x1, y1, x2, y2, cls_id, _) in bboxes:
            x = int(x1)
            y = int(y1)
            w = int(x2-x1)
            h = int(y2-y1)

            x_bar = x + 0.5 * w
            y_bar = y + 0.5 * h

            matchCarID = None

            for faceID in target_detector.faceTracker.keys():
                trackedPosition = target_detector.faceTracker[faceID].get_position()

                t_x = int(trackedPosition.left())
                t_y = int(trackedPosition.top())
                t_w = int(trackedPosition.width())
                t_h = int(trackedPosition.height())

                t_x_bar = t_x + 0.5 * t_w
                t_y_bar = t_y + 0.5 * t_h

                if t_x <= x_bar <= (t_x + t_w) and t_y <= y_bar <= (t_y + t_h):
                    if x <= t_x_bar <= (x + w) and y <= t_y_bar <= (y + h):
                        matchCarID = faceID

            if matchCarID is None:
                # a newly appeared target: start a dlib correlation tracker for it
                tracker = dlib.correlation_tracker()
                tracker.start_track(
                    image, dlib.rectangle(x, y, x + w, y + h))

                target_detector.faceTracker[target_detector.currentCarID] = tracker
                target_detector.faceLocation1[target_detector.currentCarID] = [
                    x, y, w, h]

                matchCarID = target_detector.currentCarID
                target_detector.currentCarID = target_detector.currentCarID + 1

                if cls_id == 'face':
                    pad_x = int(w * 0.15)
                    pad_y = int(h * 0.15)
                    if x > pad_x:
                        x = x-pad_x
                    if y > pad_y:
                        y = y-pad_y
                    face = raw[y:y+h+pad_y*2, x:x+w+pad_x*2]
                    new_faces.append((face, matchCarID))

                target_detector.faceClasses[matchCarID] = cls_id

    bboxes2draw = []
    for faceID in target_detector.faceTracker.keys():
        trackedPosition = target_detector.faceTracker[faceID].get_position()

        t_x = int(trackedPosition.left())
        t_y = int(trackedPosition.top())
        t_w = int(trackedPosition.width())
        t_h = int(trackedPosition.height())
        cls_id = target_detector.faceClasses[faceID]
        target_detector.faceLocation2[faceID] = [t_x, t_y, t_w, t_h]
        bboxes2draw.append(
            (t_x, t_y, t_x+t_w, t_y+t_h, cls_id, faceID)
        )

    image = plot_bboxes(image, bboxes2draw)

    return image, bboxes2draw
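
The association rule above matches a fresh detection to an existing tracker only when the center of each box lies inside the other box. Pulled out as a standalone helper for clarity (a sketch, not part of the original code):

def centers_mutually_contained(a, b):
    # a and b are (x, y, w, h) boxes; True when each box's center lies
    # inside the other box -- the matching test used in update_tracker
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    a_cx, a_cy = ax + 0.5 * aw, ay + 0.5 * ah
    b_cx, b_cy = bx + 0.5 * bw, by + 0.5 * bh
    return (bx <= a_cx <= bx + bw and by <= a_cy <= by + bh and
            ax <= b_cx <= ax + aw and ay <= b_cy <= ay + ah)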


import cv2
import numpy as np
import paddlex as pdx

class baseDet(object):

    def __init__(self):

        self.img_size = 640  # input image size
        self.threshold = 0.01  # detection score threshold
        self.stride = 2  # detection stride: run the detector every `stride` frames
        self.model = pdx.load_model('./YOLOv3/best_model')
        self.build_config()

    def build_config(self):
        # initialize the state used by the tracker
        self.faceTracker = {}
        self.faceClasses = {}
        self.faceLocation1 = {}
        self.faceLocation2 = {}
        self.frameCounter = 0
        self.currentCarID = 0
        self.walk_dict = {}
        self.recorded = []

        self.font = cv2.FONT_HERSHEY_SIMPLEX

    def feedCap(self, im):

        im, bboxes = update_tracker(self, im)

        return im, bboxes  # annotated frame plus the tracked boxes

    def detect(self, im):
        result = self.model.predict(im)
        pred_boxes = []
        for value in result:
            x1, y1, w, h = np.array(value['bbox']).astype(int)
            cls = value['category']
            score = value['score']
            if score > self.threshold:
                pred_boxes.append(
                    (x1, y1, x1+w, y1+h, cls, score)
                )
        return im, pred_boxes


DET = baseDet()

import matplotlib.pyplot as plt

im = cv2.imread("./test.jpg")

plt.imshow(im[:, :, [2,1,0]])
plt.show()


import numpy as np
res_im,bboxes = DET.feedCap(im)
plt.imshow(res_im[:, :, [2,1,0]])
plt.show()


for k, v in DET.faceLocation2.items():
    print(k, v)
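
feedCap processes a single frame, so the same detector/tracker pair also works on a video stream. A minimal sketch, assuming a hypothetical local file input.mp4, that writes the annotated frames back out with OpenCV:

video = cv2.VideoCapture('input.mp4')  # hypothetical input video
fps = video.get(cv2.CAP_PROP_FPS) or 25
w = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
writer = cv2.VideoWriter('tracked.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
while True:
    ok, frame = video.read()
    if not ok:
        break
    vis, _ = DET.feedCap(frame)  # detect every `stride` frames, track in between
    writer.write(vis)
video.release()
writer.release()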


import os
from tqdm import tqdm

# A cv2.VideoCapture-like reader over a MOT image-sequence directory
class VideoCapture(object):

    def __init__(self, img_path):
        self.name = img_path
        self.base = '../MOT20/images/test/{}/img1'
        self.img_path = self.base.format(img_path)
        self.num = len(os.listdir(self.img_path))
        self.count = 0

    def read(self):
        self.count += 1
        img = os.path.join(self.img_path, '{:06}.jpg'.format(self.count))
        image = cv2.imread(img)
        return image is not None, image

cap = VideoCapture('MOT20-04')
font = cv2.FONT_HERSHEY_SIMPLEX

results = []  # accumulate MOT-format rows across all frames
for fid in tqdm(range(cap.num)):
    success, frame = cap.read()
    if not success:
        break
    res_im, bboxes = DET.feedCap(frame)
    for id_, output in DET.faceLocation2.items():
        x1, y1 = output[0], output[1]
        w, h = output[2], output[3]
        conf_ = 1.0
        results.append([fid, id_, x1, y1, w, h, conf_, -1, -1, -1])
        # <frame>, <id>, <bb_left>, <bb_top>, <bb_width>, <bb_height>, <conf>, <x>, <y>, <z>

with open(cap.name + '.txt', 'w') as f:
    for box in results:
        f.write(','.join(str(v) for v in box) + '\n')
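
A quick read-back of the first few lines confirms the MOT-challenge layout (frame, id, bb_left, bb_top, bb_width, bb_height, conf, x, y, z):

with open(cap.name + '.txt') as f:
    for line in f.readlines()[:3]:
        print(line.strip())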
