Training-sample cropping, background cropping, and YOLOv5 detection on cropped samples

Training and inference images can have fairly high resolutions. If they are simply resized down to the model input and then trained or detected on, small-object performance tends to suffer. Multi-scale training helps, and so does cropping the samples into tiles; YOLT applies exactly this idea to object detection in high-resolution satellite imagery.

import os
import pandas as pd
import numpy as np
import cv2
import json
import random
import glob
import math
import shutil
from tqdm import trange
from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split
import matplotlib.pylab as plt

def drawImg(img, resizeCo=50):
    # Convert BGR (OpenCV) to RGB and display with matplotlib;
    # figsize is (width, height), so use shape[1] then shape[0]
    b, g, r = cv2.split(img)
    image_RGB = cv2.merge([r, g, b])
    plt.figure(figsize=(image_RGB.shape[1] / resizeCo, image_RGB.shape[0] / resizeCo))
    plt.imshow(image_RGB)
    plt.show()

def mat_inter(box1, box2):
    # Test whether two axis-aligned rectangles intersect
    # box = (xA, yA, xB, yB)
    (x01, y01, x02, y02) = box1
    (x11, y11, x12, y12) = box2

    # Distance between the two box centers along each axis
    lx = abs((x01 + x02) / 2 - (x11 + x12) / 2)
    ly = abs((y01 + y02) / 2 - (y11 + y12) / 2)
    # Side lengths of each box
    sax = abs(x01 - x02)
    sbx = abs(x11 - x12)
    say = abs(y01 - y02)
    sby = abs(y11 - y12)
    # The boxes intersect iff the center distance is no larger than
    # half the summed side lengths on both axes
    return lx <= (sax + sbx) / 2 and ly <= (say + sby) / 2


def solve_coincide(box1, box2):
    # Overlap of box2 with box1: intersection area divided by the area of
    # box2 (not IoU), i.e. the fraction of box2 that lies inside box1
    # box = (xA, yA, xB, yB)
    if mat_inter(box1, box2):
        (x01, y01, x02, y02) = box1
        (x11, y11, x12, y12) = box2
        col = min(x02, x12) - max(x01, x11)
        row = min(y02, y12) - max(y01, y11)
        intersection = col * row
        area2 = (x12 - x11) * (y12 - y11)
        return intersection / area2
    else:
        return 0
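
A quick sanity check of these two helpers (box values chosen arbitrarily for illustration):

box1 = (0, 0, 100, 100)
box2 = (50, 0, 150, 100)   # right half of box1 overlaps the left half of box2
print(mat_inter(box1, box2))       # True
print(solve_coincide(box1, box2))  # 0.5 -> half of box2 lies inside box1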


def expandBox(bbox, expandRatio):
    # Scale a box about its center by expandRatio
    boxCenter = ((bbox[0]+bbox[2])/2, (bbox[1]+bbox[3])/2)
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    nw = w * expandRatio
    nh = h * expandRatio

    x1 = boxCenter[0] - nw / 2
    x2 = boxCenter[0] + nw / 2
    y1 = boxCenter[1] - nh / 2
    y2 = boxCenter[1] + nh / 2
    return [x1, y1, x2, y2]


def resizeBox(bbox, expandRatio):
    """
    Scale an annotation box by a resize ratio
    """
    x1 = bbox[0]*expandRatio
    x2 = bbox[2]*expandRatio
    y1 = bbox[1]*expandRatio
    y2 = bbox[3]*expandRatio
    return [x1, y1, x2, y2]


def prepare_dirs(prefix='/output/'):
    """
    """
    img_dir = os.path.join(prefix, "images")
    label_dir = os.path.join(prefix, "labels")
    if os.path.exists(img_dir) and os.path.isdir(img_dir):
        shutil.rmtree(img_dir)
    if os.path.exists(label_dir) and os.path.isdir(label_dir):
        shutil.rmtree(label_dir)
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)
    return img_dir, label_dir


def covertYoloLabelxyxy(imWidth, imHeight, x1, y1, x2, y2):
    # Convert pixel (x1, y1, x2, y2) to normalized YOLO (cx, cy, w, h)
    dw = 1 / imWidth
    dh = 1 / imHeight
 
    centerX = (x1 + x2) / 2.0 
    centerY = (y1 + y2) / 2.0
    w = (x2 - x1)
    h = (y2 - y1)
    
    centerX = centerX * dw 
    centerY = centerY * dh
    w = w * dw
    h = h * dh
    
    return [centerX, centerY, w, h]

def drawBBox(image, color, objBbox):
    # Draw a rectangle onto the image in place
    image = cv2.rectangle(image, (int(objBbox[0]), int(objBbox[1])), (int(objBbox[2]), int(objBbox[3])), color, 2)
    

def Yolo2RealLabel(imWidth, imHeight, centerX, centerY, w, h):
    # Convert normalized YOLO (cx, cy, w, h) back to pixel (x1, y1, x2, y2)
    centerX = centerX * imWidth 
    centerY = centerY * imHeight
    w = w * imWidth
    h = h * imHeight
    
    x1 = centerX - w/2
    x2 = centerX + w/2
    y1 = centerY - h/2
    y2 = centerY + h/2
    
    return [x1, y1, x2, y2]
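
A round trip through the two converters, with toy numbers for illustration:

# 1000x800 image, box from (100, 200) to (300, 400)
bb = covertYoloLabelxyxy(1000, 800, 100, 200, 300, 400)
print(bb)                              # [0.2, 0.375, 0.2, 0.25]
print(Yolo2RealLabel(1000, 800, *bb))  # [100.0, 200.0, 300.0, 400.0]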

def calLabelAddRate(ratio):
    # Map an aspect ratio in [1, inf) to a threshold in (0.5, 0.9]:
    # atan(1) = pi/4 gives 0.9 and atan(inf) = pi/2 gives 0.5, so very
    # elongated boxes get a lower overlap threshold
    coe = (math.atan(ratio) - math.pi / 4) / (math.pi / 2 - math.pi / 4)
    coe = 0.9 - coe * 0.4
    return coe
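
For a square box (ratio 1) this yields the full 0.9 threshold; the threshold decays toward 0.5 as the box gets more elongated:

print(calLabelAddRate(1.0))   # 0.9
print(calLabelAddRate(4.0))   # ~0.62
print(calLabelAddRate(20.0))  # ~0.53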

Cropping a single image

def cropImgSlidingWindow(img, cropSize, repetitionRate):
    """
    Crop an image with a sliding window.
    img: <h,w,c> or <h,w>
    cropSize: crop width/height (square window)
    repetitionRate: overlap ratio between adjacent windows
    return: crop boxes, cropped images
    """
    height = img.shape[0]
    width = img.shape[1]

    croppedBoxs = []
    # Regular sliding-window grid
    for i in range(int((height - cropSize * repetitionRate) / (cropSize * (1 - repetitionRate)))):
        for j in range(int((width - cropSize * repetitionRate) / (cropSize * (1 - repetitionRate)))):

            ys = int(i * cropSize * (1 - repetitionRate))
            ye = int(i * cropSize * (1 - repetitionRate)) + cropSize
            xs = int(j * cropSize * (1 - repetitionRate))
            xe = int(j * cropSize * (1 - repetitionRate)) + cropSize
            if xe > width:
                print('sliding-window crop out of bounds')
                print((xs, ys, xe, ye))
            croppedBoxs.append((xs, ys, xe, ye))

    # Last column: shift the window back so it ends at the right edge
    for i in range(int((height - cropSize * repetitionRate) / (cropSize * (1 - repetitionRate)))):
        ys = int(i * cropSize * (1 - repetitionRate))
        ye = int(i * cropSize * (1 - repetitionRate)) + cropSize
        xs = (width - cropSize)
        xe = width

        if xe > width or (xe - xs) != cropSize:
            print('img.shape:', height, width)
            print('last column (shifted back)')
            print((xs, ys, xe, ye))
        croppedBoxs.append((xs, ys, xe, ye))

    # Last row: shift the window up so it ends at the bottom edge
    for j in range(int((width - cropSize * repetitionRate) / (cropSize * (1 - repetitionRate)))):
        ys = (height - cropSize)
        ye = height
        xs = int(j * cropSize * (1 - repetitionRate))
        xe = int(j * cropSize * (1 - repetitionRate)) + cropSize
        if xe > width or (xe - xs) != cropSize:
            print('img.shape:', height, width)
            print('last row (shifted up)')
            print((xs, ys, xe, ye))
        croppedBoxs.append((xs, ys, xe, ye))

    # Bottom-right corner
    ys = (height - cropSize)
    ye = height
    xs = (width - cropSize)
    xe = width
    if xe > width or (xe - xs) != cropSize:
        print('img.shape:', height, width)
        print('bottom-right corner')
        print((xs, ys, xe, ye))
    croppedBoxs.append((xs, ys, xe, ye))

    cropImgs = []
    for croppedBox in croppedBoxs:
        (xs, ys, xe, ye) = croppedBox

        # Single-band image
        if len(img.shape) == 2:
            cropped = img[ys:ye, xs:xe]
        # Multi-band image
        else:
            cropped = img[ys:ye, xs:xe, :]
        if (ye - ys) != cropped.shape[0] or (xe - xs) != cropped.shape[1]:
            raise Exception("Error: cropped image size does not match the crop window")

        cropImgs.append(cropped)
    if len(cropImgs) != len(croppedBoxs):
        raise Exception("Error: number of cropped images does not match number of crop boxes")
    return croppedBoxs, cropImgs
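
A minimal sketch of the window layout on a dummy image (sizes chosen purely for illustration):

dummy = np.zeros((1200, 1600, 3), dtype=np.uint8)  # hypothetical 1600x1200 image
boxes, tiles = cropImgSlidingWindow(dummy, 1024, 0.3)
print(len(boxes))           # grid windows plus last-column, last-row, and corner windows
print(boxes[0], boxes[-1])  # first grid window and the bottom-right corner window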

Adding annotation-box handling

def cropLabelImg(imgSavePath, labelSavePath, imgName, img, width, height, cropSize, repetitionRate, labelAddRate, labels, bboxs):
    """
    imgSavePath: directory for the cropped images
    labelSavePath: directory for the label files
    imgName: original file name
    img: image data
    width: image width
    height: image height
    cropSize: crop window size
    repetitionRate: overlap ratio between crop windows
    labelAddRate: threshold for keeping a box; a box is added to a crop's
                  labels when its intersection with the crop window covers
                  more than this fraction of the box's area
    labels: original annotation classes
    bboxs: original annotation boxes
    """

    # Upscale images smaller than the crop window so they can be cropped;
    # this works better than resizing them directly to the model input size
    resizeRatio = 1
    edgeMin = min(width, height)
    if edgeMin < cropSize:
        resizeRatio = cropSize / edgeMin
    # Downscale very large images so the shorter edge is 1600
    if edgeMin > 1600:
        resizeRatio = 1600 / edgeMin

    # Resize the image
    nimg = img
    if resizeRatio != 1:
        width = int(round(width * resizeRatio))
        height = int(round(height * resizeRatio))
        nimg = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)

    # Resize the annotation boxes accordingly
    nbboxs = []
    for bbox in bboxs:
        nbbox = resizeBox(bbox, resizeRatio)
        nbboxs.append(nbbox)

    croppedBoxs, cropImgs = cropImgSlidingWindow(nimg, cropSize, repetitionRate)

    cropBoxId = 0
    for croppedBox, cropImg in zip(croppedBoxs, cropImgs):
        (xs, ys, xe, ye) = croppedBox

        # Write the cropped image
        cropImgPath = imgSavePath + "/" + imgName + "_" + str(cropBoxId) + ".jpg"
        cv2.imwrite(cropImgPath, cropImg)
        cv2.waitKey(1)

        # Write the label file
        cropImgLabelPath = labelSavePath + "/" + imgName + "_" + str(cropBoxId) + ".txt"
        f_Label = open(cropImgLabelPath, "w+")
        croppedDraw = cropImg.copy()
        for label, nbbox in zip(labels, nbboxs):
            boxAddRate = labelAddRate
            # For the operating-bar class the object itself fills very little
            # of its box, so even with a large intersection the object may lie
            # entirely outside the crop window; calLabelAddRate sets the
            # overlap threshold dynamically from the box's aspect ratio
            if label == '4':  # operatingbar
                w = nbbox[2] - nbbox[0]
                h = nbbox[3] - nbbox[1]
                boxAddRate = calLabelAddRate(max(w, h) / min(w, h))
            if solve_coincide(croppedBox, nbbox) > boxAddRate:
                x1, y1, x2, y2 = nbbox[0], nbbox[1], nbbox[2], nbbox[3]
                # Clip the box to the crop window
                x1 = max(x1, xs)
                x2 = min(x2, xe)
                y1 = max(y1, ys)
                y2 = min(y2, ye)

                # Shift into crop-window coordinates
                x1 = x1 - xs
                x2 = x2 - xs
                y1 = y1 - ys
                y2 = y2 - ys

                drawBBox(croppedDraw, (0, 0, 255), (x1, y1, x2, y2))

                bb = covertYoloLabelxyxy(cropImg.shape[1], cropImg.shape[0], x1, y1, x2, y2)
                f_Label.writelines(str(label) + " " + " ".join(str(b) for b in bb) + '\n')
        f_Label.close()

        # Next file index
        cropBoxId = cropBoxId + 1
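
A minimal driver sketch, assuming each image's annotations are already available as class/xyxy pairs; `loadAnnotations` and the dataset paths are hypothetical placeholders:

img_dir, label_dir = prepare_dirs('/output/train/')       # hypothetical output root
for imgPath in glob.glob('/data/train/images/*.jpg'):      # hypothetical dataset
    img = cv2.imread(imgPath)
    h, w = img.shape[:2]
    name = Path(imgPath).stem
    labels, bboxs = loadAnnotations(imgPath)  # hypothetical: class ids + [x1,y1,x2,y2] boxes
    cropLabelImg(img_dir, label_dir, name, img, w, h,
                 cropSize=1024, repetitionRate=0.3, labelAddRate=0.9,
                 labels=labels, bboxs=bboxs)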

Cropping background-only images with no annotation boxes

def cropBgImg(imgSavePath, labelSavePath, imgName, img, width, height, cropSize, repetitionRate, labelAddRate, labels, bboxs):
    """
    Extract background crops from the training set.
    imgSavePath: directory for the cropped images
    labelSavePath: directory for the label files
    imgName: original file name
    img: image data
    width: image width
    height: image height
    cropSize: crop window size
    repetitionRate: overlap ratio between crop windows
    labelAddRate: overlap threshold; a crop is rejected as background if any
                  annotation box intersects the crop window by more than this
                  fraction of the box's area
    labels: original annotation classes
    bboxs: original annotation boxes
    """

    # Upscale images smaller than the crop window so they can be cropped;
    # this works better than resizing them directly to the model input size
    resizeRatio = 1
    edgeMin = min(width, height)
    if edgeMin < cropSize:
        resizeRatio = cropSize / edgeMin
    # Downscale very large images so the shorter edge is 1600
    if edgeMin > 1600:
        resizeRatio = 1600 / edgeMin

    # Resize the image
    nimg = img
    if resizeRatio != 1:
        width = int(round(width * resizeRatio))
        height = int(round(height * resizeRatio))
        nimg = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)

    # Resize the annotation boxes accordingly
    nbboxs = []
    for bbox in bboxs:
        nbbox = resizeBox(bbox, resizeRatio)
        nbboxs.append(nbbox)

    croppedBoxs, cropImgs = cropImgSlidingWindow(nimg, cropSize, repetitionRate)

    cropBoxId = 0
    for croppedBox, cropImg in zip(croppedBoxs, cropImgs):
        (xs, ys, xe, ye) = croppedBox

        # Count the boxes that fall below the overlap threshold;
        # the crop counts as background only if every box does
        labelNoCropCount = 0
        for label, nbbox in zip(labels, nbboxs):
            if solve_coincide(croppedBox, nbbox) < labelAddRate:
                labelNoCropCount = labelNoCropCount + 1
        if labelNoCropCount == len(nbboxs):
            # Write the cropped image
            cropImgPath = imgSavePath + "/" + imgName + "_" + str(cropBoxId) + ".jpg"
            cv2.imwrite(cropImgPath, cropImg)
            cv2.waitKey(1)

            # Write an empty label file (pure background)
            cropImgLabelPath = labelSavePath + "/" + imgName + "_" + str(cropBoxId) + ".txt"
            f_Label = open(cropImgLabelPath, "w+")
            f_Label.close()

        # Next file index
        cropBoxId = cropBoxId + 1
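
Continuing the driver sketch above, the same call pattern works for background crops; a lower threshold makes the background filter stricter (the values and output root here are illustrative):

bg_img_dir, bg_label_dir = prepare_dirs('/output/background/')  # hypothetical output root
cropBgImg(bg_img_dir, bg_label_dir, name, img, w, h,
          cropSize=1024, repetitionRate=0.3, labelAddRate=0.1,
          labels=labels, bboxs=bboxs)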

YOLOv5 detection on cropped images

# detect-crop.py

import argparse
import time
from pathlib import Path
import glob
import os
import cv2
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
# note: weighted_boxes_fusion is not part of stock YOLOv5's utils.general;
# it has to be added there (e.g. from the ensemble-boxes package)
from utils.general import check_img_size, check_requirements, weighted_boxes_fusion, non_max_suppression, apply_classifier, \
    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.plots import colors, plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized


def cropImgSlidingWindow(img, cropSize, repetitionRate):
    """
    Crop an image with a sliding window (channel-first variant).
    img: <c,h,w> or <h,w>
    cropSize: crop width/height (square window)
    repetitionRate: overlap ratio between adjacent windows
    return: crop boxes, cropped images
    """
    height = img.shape[1]
    width = img.shape[2]

    croppedBoxs = []
    # Regular sliding-window grid
    for i in range(int((height - cropSize * repetitionRate) / (cropSize * (1 - repetitionRate)))):
        for j in range(int((width - cropSize * repetitionRate) / (cropSize * (1 - repetitionRate)))):

            ys = int(i * cropSize * (1 - repetitionRate))
            ye = int(i * cropSize * (1 - repetitionRate)) + cropSize
            xs = int(j * cropSize * (1 - repetitionRate))
            xe = int(j * cropSize * (1 - repetitionRate)) + cropSize
            if xe > width:
                print('sliding-window crop out of bounds')
                print((xs, ys, xe, ye))
            croppedBoxs.append((xs, ys, xe, ye))

    # Last column: shift the window back so it ends at the right edge
    for i in range(int((height - cropSize * repetitionRate) / (cropSize * (1 - repetitionRate)))):
        ys = int(i * cropSize * (1 - repetitionRate))
        ye = int(i * cropSize * (1 - repetitionRate)) + cropSize
        xs = (width - cropSize)
        xe = width

        if xe > width or (xe - xs) != cropSize:
            print('img.shape:', height, width)
            print('last column (shifted back)')
            print((xs, ys, xe, ye))
        croppedBoxs.append((xs, ys, xe, ye))

    # Last row: shift the window up so it ends at the bottom edge
    for j in range(int((width - cropSize * repetitionRate) / (cropSize * (1 - repetitionRate)))):
        ys = (height - cropSize)
        ye = height
        xs = int(j * cropSize * (1 - repetitionRate))
        xe = int(j * cropSize * (1 - repetitionRate)) + cropSize
        if xe > width or (xe - xs) != cropSize:
            print('img.shape:', height, width)
            print('last row (shifted up)')
            print((xs, ys, xe, ye))
        croppedBoxs.append((xs, ys, xe, ye))

    # Bottom-right corner
    ys = (height - cropSize)
    ye = height
    xs = (width - cropSize)
    xe = width
    if xe > width or (xe - xs) != cropSize:
        print('img.shape:', height, width)
        print('bottom-right corner')
        print((xs, ys, xe, ye))
    croppedBoxs.append((xs, ys, xe, ye))

    cropImgs = []
    for croppedBox in croppedBoxs:
        (xs, ys, xe, ye) = croppedBox

        # Single-band image
        if len(img.shape) == 2:
            cropped = img[ys: ye, xs: xe]
        # Multi-band (channel-first) image
        else:
            cropped = img[:, ys: ye, xs: xe]
        if (ye - ys) != cropped.shape[1] or (xe - xs) != cropped.shape[2]:
            raise Exception("Error: cropped image size does not match the crop window")

        cropImgs.append(cropped)
    if len(cropImgs) != len(croppedBoxs):
        raise Exception("Error: number of cropped images does not match number of crop boxes")
    return croppedBoxs, cropImgs


img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes


class LoadImagesforCrop:  # for inference
    def __init__(self, path, img_size=1600, crop_size=1024):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        ni = len(images)

        self.img_size = img_size
        self.files = images
        self.nf = ni  # number of files
        self.crop_size = crop_size
        self.mode = 'image'

        assert self.nf > 0, f'No images found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        # Read image
        self.count += 1
        img0 = cv2.imread(path)  # BGR
        assert img0 is not None, 'Image Not Found ' + path
        print(f'image {self.count}/{self.nf} {path}: ', end='')

        height = img0.shape[0]
        width = img0.shape[1]

        # Upscale if the shorter edge is below the crop size;
        # downscale if it exceeds the resize target
        resize_ratio = 1
        edge_min = min(width, height)
        if edge_min < self.crop_size:
            resize_ratio = self.crop_size / edge_min

        if edge_min > self.img_size:
            resize_ratio = self.img_size / edge_min

        # Resize the image
        nimg = img0
        if resize_ratio != 1:
            width = int(round(width * resize_ratio))
            height = int(round(height * resize_ratio))
            nimg = cv2.resize(img0, (width, height), interpolation=cv2.INTER_LINEAR)

        # Convert
        img = nimg[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return path, img, img0, resize_ratio

    def __len__(self):
        return self.nf  # number of files


def resizeBox(bbox, expandRatio):
    x1 = bbox[0]*expandRatio
    x2 = bbox[2]*expandRatio
    y1 = bbox[1]*expandRatio
    y2 = bbox[3]*expandRatio
    return [x1, y1, x2, y2]

@torch.no_grad()
def detect(weights='yolov5s.pt',  # model.pt path(s)
           source='data/images',  # file/dir/URL/glob, 0 for webcam
           resize_target=1600,  # shorter-edge resize target (by wxf)
           crop_size=1024,  # crop window size (by wxf)
           repetition_rate=0.3,  # overlap ratio between crop windows (by wxf)
           imgsz=640,  # inference size (pixels)
           conf_thres=0.25,  # confidence threshold
           iou_thres=0.45,  # NMS IOU threshold
           max_det=1000,  # maximum detections per image
           device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
           view_img=False,  # show results
           save_txt=False,  # save results to *.txt
           save_conf=False,  # save confidences in --save-txt labels
           save_crop=False,  # save cropped prediction boxes
           nosave=False,  # do not save images/videos
           classes=None,  # filter by class: --class 0, or --class 0 2 3
           agnostic_nms=False,  # class-agnostic NMS
           augment=False,  # augmented inference
           update=False,  # update all models
           project='runs/detect',  # save results to project/name
           name='exp',  # save results to project/name
           exist_ok=False,  # existing project/name ok, do not increment
           line_thickness=3,  # bounding box thickness (pixels)
           hide_labels=False,  # hide labels
           hide_conf=False,  # hide confidences
           half=False,  # use FP16 half-precision inference
           ):
    save_img = not nosave and not source.endswith('.txt')  # save inference images
    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://', 'https://'))

    # Directories
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Initialize
    set_logging()
    device = select_device(device)
    half &= device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names
    if half:
        model.half()  # to FP16

    # Second-stage classifier
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()

    # Set Dataloader
    vid_path, vid_writer = None, None

    dataset = LoadImagesforCrop(source, img_size=resize_target, crop_size=crop_size)

    # Run inference
    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
    t0 = time.time()

    noCheckPaths = []
    for path, img, im0s, resize_ratio in dataset:

        # Compute sliding-window crop boxes and cropped images
        croppedBoxs, cropImgs = cropImgSlidingWindow(img, crop_size, repetition_rate)

        # Run the detector on each crop
        predSrcImg = []
        for croppedBox, cropImg in zip(croppedBoxs, cropImgs):
            (xs, ys, xe, ye) = croppedBox
        
            img = torch.from_numpy(cropImg).to(device)
            img = img.half() if half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            t1 = time_synchronized()
            pred = model(img, augment=augment)[0]
            # Apply NMS
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
            # pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms) # wxf modify
            t2 = time_synchronized()

            # Map detections from crop coordinates back to the original image
            for p in pred[0]:
                p[0] = (p[0] + xs) / resize_ratio
                p[2] = (p[2] + xs) / resize_ratio
                p[1] = (p[1] + ys) / resize_ratio
                p[3] = (p[3] + ys) / resize_ratio

            predSrcImg.extend(pred[0])

        # Merge all crop detections with weighted boxes fusion; WBF expects
        # normalized coordinates, so divide by 4096 (a hard-coded assumed
        # upper bound on the original image size) and scale back afterwards
        predSrcImg = torch.stack(predSrcImg)
        wbf_boxes = predSrcImg[:, 0:4].clone() / 4096.0
        wbf_scores = predSrcImg[:, 4]
        wbf_classId = predSrcImg[:, 5]
        wbf_b, wbf_s, wbf_c = weighted_boxes_fusion([wbf_boxes], [wbf_scores], [wbf_classId], iou_thr=0.1, skip_box_thr=0.65, weights=None)

        wbf_b = wbf_b * 4096
        wbf_b = torch.from_numpy(wbf_b).cuda()
        wbf_s = torch.from_numpy(wbf_s).view(len(wbf_s), -1).cuda()
        wbf_c = torch.from_numpy(wbf_c).view(len(wbf_s), -1).cuda()
        pred = [torch.cat((wbf_b, wbf_s, wbf_c), 1)]

        # Apply Classifier
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0, frame = path[i], f'{i}: ', im0s[i].copy(), dataset.count
            else:
                p, s, im0, frame = path, '', im0s.copy(), getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # img.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            imc = im0.copy() if save_crop else im0  # for save_crop
            if len(det):
                # Rescale boxes from img_size to im0 size
                # det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        # line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format

                        line = (cls, *xyxy, conf) if save_conf else (cls, *xyxy)  # label format modify by wxf
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    if save_img or save_crop or view_img:  # Add bbox to image
                        c = int(cls)  # integer class
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                        plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=line_thickness)
                        if save_crop:
                            save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
            else:
                pass
            # Print time (inference + NMS)
            print(f'{s}Done. ({t2 - t1:.3f}s)')
            if 'person' not in s: # add by wxf:check person
                print("NO ret !!!!!!!!!!!!!")
                noCheckPaths.append(path)
                print(len(noCheckPaths))

            # Stream results
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)

    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        print(f"Results saved to {save_dir}{s}")

    if update:
        strip_optimizer(weights)  # update model (to fix SourceChangeWarning)
    print(noCheckPaths)
    print(f'Done. ({time.time() - t0:.3f}s)')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--resize-target', type=int, default=1600, help='shorter-edge resize target (by wxf)')
    parser.add_argument('--crop-size', type=int, default=1024, help='crop window size (by wxf)')
    parser.add_argument('--repetition-rate', type=float, default=0.3, help='overlap ratio between crop windows (by wxf)')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold')
    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='show results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    opt = parser.parse_args()
    print(opt)
    check_requirements(exclude=('tensorboard', 'thop'))

    detect(**vars(opt))
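
A minimal sketch of invoking the detector directly from Python instead of the CLI; the weights and source paths are hypothetical placeholders, and the crop parameters match the ones used for the training crops:

detect(weights='runs/train/exp/weights/best.pt',  # hypothetical weights path
       source='/data/test/images',                # hypothetical image directory
       resize_target=1600, crop_size=1024, repetition_rate=0.3,
       conf_thres=0.25, iou_thres=0.45)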

In my own testing some problems remain. The main one is overly large targets: a crop window may capture only part of an object, so this cropping approach is not suited to cases where a target's area is large relative to the original image. The other weak spot is merging the detection boxes: plain NMS does not feel good enough here, and when I find time I will test the popular alternatives Soft-NMS, WBF, and DIoU-NMS.
