基于MTCNN的人脸检测

一、人脸检测:

人脸检测为单类多目标检测

检测的图片有两种类型:一张图片中仅含有一个人脸

                                       一张图片中含有多个人脸

深度学习是基于仿生学,计算机视觉和人眼一样,通过感受野(人的视野)的扫描去寻找目标,当感受野特别大目标特别小时人和机器一样无法识别出是否是目标,所以要放大图片或者减少感受野才能看清。人脸检测便是从图片中找到人脸,因图片有大有小,人脸也是如此,为了能检测到图片中人脸,需要通过卷积神经网络来实现,卷积核便是感受野(人的视野),这时为了检测到所有人脸会有两种方法,第一种:感受野从小到大,图片大小不变,扫描整张图片;第二种:图片从大到小,感受野大小不变,扫描多张大小不同的图片。而基于MTCNN的人脸检测采用的是第二种方法,此种方法便是图像金字塔

图像金字塔:只在侦测时PNet前使用,将图片等比缩放(scale缩放比率为0.709(相当于面积的一半))

但是最终缩放到最小边长≥12,因为PNet最小输入图片size为(3,12,12)

红色框为感受野(12*12)

二、理解MTCNN(多任务卷积神经网络):

多任务包括:人脸的检测;特征点的对齐

人脸的检测包括:判断图片中是否存在人脸;若存在人脸需找到人脸的位置

特征点的对齐则是:找到人脸的两只眼睛,鼻尖,嘴角的两端5个特征点

级联思想:图片输入到PNet得到输出,将PNet的输出输入到RNet得到输出,再将RNet的输出输入到ONet得到最终的输出。

三、MTCNN流程:

从上图可以看出,将图片传入图像金字塔中得到数据形状为(N,3,H,W)的图片(注:min(H,W)≥12),将这些图片输入PNet得到形状为(N,5,h,w)(包括(N,1,h,w)的置信度和(N,4,h,w)的候选框四个坐标点的偏移量)的输出,将这些输出通过筛选置信度,再反算到原图片上的候选框,候选框再通过NMS丢掉一些框留下的数量为n,再将这些框做正方形转换,再进行裁剪,变形成(n,3,24,24)的图片集。将图片集传入到RNet得到形状为(n,5) (包括(n,1)的置信度和(n,4)的候选框四个坐标点的偏移量)的输出,再将这些输出通过筛选置信度,再反算原图上的候选框,候选框再通过NMS丢掉一些框留下的数量为m,再将这些框做正方形转换,再进行裁剪,变形成(m,3,48,48)的图片集。将图片集传入到ONet得到形状为(n,5) (包括(n,1)的置信度和(n,4)的候选框四个坐标点的偏移量)的输出,再将这些输出通过筛选置信度,再反算原图上的候选框,候选框再通过NMS(注:这里NMS所用的IOU为交集比最小集)丢掉一些框,剩下的为最终的预测框,再将预测框画在图片上。

四、制作样本

制作样本所用到的是CelebA数据集,可以通过http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html进行下载。

CelebA Anno:img_id x1 y1 w h, 我们需要计算成cx cy w h(cx, cy为中心点坐标)以方便后面的使用

1、建议框的选取:

如图

红色的框为标签框,红色的点为标签框的中心点(cx, cy),获取到标签框的最大边长max_side,我们通过偏移中心点和最大边长来得到建议框,偏移方法是先制定一个1:1:3或者1:3:5的种子(随机小数)(如:[0.1, 0.5, 0.85, 0.9, 0.95] 0.1偏移小的大概率为正样本,0.5偏移大概率为部分样本,0.9偏移大的大概率为负样本)

建议框的中心点=真实框中心点±种子*真实框中心点

建议框的边长=真实框最大边长±种子*真实框最大边长(建议框为正方形)

但为更好的区分样本,我们使用IOU(图像重合度)来区分样本,IOU为两个框交集面积与并集面积的比值,IOU小则重合程度低,IOU大则重合程度高,IOU>0.65为正样本 如绿色框,0.4<IOU<0.65为部分样本 如黄色框,IOU<0.3的为负样本 如蓝色框。

效果:正样本部分样本负样本

2、偏移量的计算:

如图

蓝色框为真实框,红色框为建议框,而坐标点x1,y1,x2,y2的偏移量offset=(真实框的坐标-建议框的坐标)/建议框的边长(建议框是正方形),为什么要除偏移框的边长而不是真实框的边长,因为我们训练的是建议框而非真实框。

3、样本文件和标签文件:

如图

保存的路径下包含3个size的子文件夹,每个子文件夹下包含positive正、negative负、part部分3个文件夹和标签txt文本文件

标签值为:sample/image_id  cls  x1_offset  y1_offset  x2_offset  y2_offset

正样本cls(置信度)为1,部分样本cls为2,负样本cls为0(且负样本的offset都可以为0)

五:网络结构

所有网络激活函数都是使用的PReLU,PReLU负半轴有斜率不会丢失负半轴的部分特征

所有网络的隐藏层都加BatchNormal,效果比不加好。

1、PNet训练12*12样本集(层数浅、图片尺寸小(参数量小)所以训练最快)

PNet为全卷积网络,因为在侦测时PNet的输入是图像金字塔变形后的图片集每张图片大小都不一样,使用全卷积网络才能训练这些图片,PNet只有3层隐藏层,层数很浅作用是为了侦测时更快的获取大量的候选框。(不放过图片中任何一个人脸)

训练时输入为(N,3,12,12)N为12*12(包含正、部分、负样本)的所有样本图片数

           输出为(N,1,1,1)置信度,图片中使用的是softmax输出函数,而我使用的是Sigmoid输出函数

                  和(N,4,1,1)4个坐标点的偏移量

侦测时输入为(N,3,H,W)N为做完图像金字塔所得到的所有图片数(min(H, W)≥12)

           输出为(N,1,h,w)置信度

                   和(N,4,h,w)4个坐标点的偏移量

2、RNet训练24*24样本集

Rnet为卷积+全连接网络,加全连接的目的是为了特征融合,在侦测时是为了筛选掉一部分不存在人脸的候选框。

训练时输入为(N,3,24,24)N为24*24(包含正、部分、负样本)的所有样本图片数

           输出为(N,1,1,1)置信度,图片中使用的是softmax输出函数,而我使用的是Sigmoid输出函数

                  和(N,4,1,1)4个坐标点的偏移量

侦测时输入为(n,3,24,24)n为PNet输出的所有候选框经过筛选、反算、NMS、正方形转换、裁剪、变形后的图片数

           输出为(n,1,1,1)置信度

                   和(n,4,1,1)4个坐标点的偏移量

3、ONet训练48*48样本集

ONet也为卷积+全连接网络,但是层数更深为5层为了提取更细节的特征,回归候选框更精确。

训练时输入为(N,3,48,48)N为48*48(包含正、部分、负样本)的所有样本图片数

           输出为(N,1,1,1)置信度,图片中使用的是softmax输出函数,而我使用的是Sigmoid输出函数

                  和(N,4,1,1)4个坐标点的偏移量

侦测时输入为(m,3,48,48)m为RNet输出的所有候选框经过筛选、反算、NMS、正方形转换、裁剪、变形后的图片数

           输出为(m,1,1,1)置信度

                   和(m,4,1,1)4个坐标点的偏移量

六、训练模型

MTCNN可以并行训练(3个网络同时训练,前提是内存够大)

损失函数:cls置信度输出函数为Sigmoid并且是判断是否存在人脸的分类问题,所以使用BCELoss(二分类交叉熵损失函数)

                  offset偏移量为回归候选框的左边点的偏移量使用的是MSELoss(均方差损失函数)

优化器:都使用的是Adam

cls置信度是使用的网络输出中索引属于正样本和负样本的数据进行回归的

offset偏移量是使用的网络输出中索引属于正样本和部分样本的数据进行回归的

PNet训练停止损失值为0.01,RNet训练停止损失值为0.001,ONet训练停止损失值为0.0005

七、工具

1、图像IOU(图像重合度):

候选框有以下几种情况:

图1前4种情况为两个候选框相交,第5种情况为大框套小框,第6种情况为两个候选框分离(这种很有可能两个框都是人脸框)

IOU为重合程度,计算方法有两种如图2

第一种为交并比IOU(交集面积/并集面积)

先找到交集,从图1第一种情况可以看出 交集框x1,y1为两个框x1,y1的最大值,交集框x2,y2为两个框x2,y2的最小值

交集面积 = (min(x2,x2') - max(x1,x1')) * (min(y2,y2') - max(y1,y1'))

并集面积 = 两个框的面积和 - 交集面积

第二种为最小集IOU(交集面积/最小框面积)

这种是计算方法是因为获取的所有候选框中,有大框套小框的存在,并且小框远小于大框(如上图),这样使用交并比IOU是无法去除非人脸框的,这是就只能用最小集IOU,因为大框套小框交集面积=最小集面积所以IOU=1,这种是可以将不合适的框滤除掉的。这种IOU计算方法只有O网络用到。

2、NMS(非极大值抑制):

NMS顾名思义是留下值大的,抑制(丢弃)值小的,而这个值在MTCNN中指的是cls置信度

如这张图中有多个框框在同一张人脸上,我们需要去除多余的框

去重方法如此图,先按照cls置信度将所有候选框的数据排序(从大到小),第一个框(也就是置信度最大的框)一定为人脸框需要保留存为预测框,剩下的框依次和第一个框做IOU,若重合度高于阈值则丢弃,低于阈值说明可能是另一个人脸的框所以保留

重复上述操作直到候选框中只剩下一个框或者没有框,若只剩下一个说明此框可能是人脸框也保留存储在预测框中

NMS后的效果:

3、图像坐标正方形转换

此工具在侦测时RNet和ONet之前使用,目的是因为PNet和RNet输出的框大可能是长方形的,但是RNet和ONet传入的是24*24和48*48固定的正方形,如果直接变形会导致人脸特征扭曲变形,不利于侦测,所以需要在变形之前转换成正方形。

如上图,如果直接传入长方形会少一部分特征,而且RNet和ONet需求是固定size的正方形;如果直接resize会使脸部特征扭曲变形不利于侦测;如果放入黑色正方形背景中特征不全会影响侦测结果;最好的办法是在原图上变成正方形再resize。

变形方法:计算长方形框的最大边长max_side,然后在计算正方形的x1,y1,x2,y2

                  x1, y1 = (x1'+x2')/2+max_side/2,  (y1'+y2')/2+max_side/2    x2, y2 = x1+max_side, y1+max_side

4、图像坐标反算

图像坐标反算每个Net都要进行,但是PNet不同,因为在传入PNet之前做了图像金字塔如果要反算回原图就要/scale(缩放比率)

如上图:可以看到左边为原图,框为建议框,右边为输出的结果(N,5,h,w)的结果,先筛选出置信度大于阈值的框

(PNet使用的是index = torch.nonzero(torch.gt(cls, 0.5))方法,RNet和ONet使用的是index,_ = np.where(cls>0.6)方法)

先反算建议框:x1' = (index[:,1] * stride) / scale     x2' = (index[:, 1] * stride + side) / scale   (PNet:strde=2,side=12)

                         y1' = (index[:,0] * stride) / scale     y2' = (index[:, 0] * stride + side) / scale

                         w' = x2' - x1'    h' = y2' - y1'

再反算预测框:根据偏移量计算公式:offset = (x - x') / w'

                         x1 = x1' + w' * offset[0, index[:,0], index[:,1]]     y1 = y1' + h' * offset[1, index[:,0], index[:,1]] 

                         x2 = x2' + w' * offset[2, index[:,0], index[:,1]]     y2 = y2' + h' * offset[3, index[:,0], index[:,1]] 

而RNet和ONet都是直接输入的建议框,所以不需要反算建议框,直接反算预测框

                         x1 = x1' + w' * offset[index, 0]     y1 = y1' + h' * offset[index, 1] 

                         x2 = x2' + w' * offset[index, 2]     y2 = y2' + h' * offset[index, 4] 

八、代码

工具代码:

import numpy as np


def iou(box, boxes, isMin=False):
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    w = np.maximum(0, x2 - x1)
    h = np.maximum(0, y2 - y1)
    inter_area = w * h
    if isMin:
        ratio = np.true_divide(inter_area, np.minimum(box_area, boxes_area))
    else:
        ratio = np.true_divide(inter_area, box_area + boxes_area - inter_area)
    return ratio


def nms(boxes, thresh=0.3, isMin=False):
    if boxes.shape[0] == 0:
        return np.array([])

    boxes = boxes[(-boxes[:, 4]).argsort()]
    _boxes = []

    while boxes.shape[0] > 1:
        max_box = boxes[0]
        other_boxes = boxes[1:]
        _boxes.append(max_box)
        index = np.where(iou(max_box, other_boxes, isMin) < thresh)
        boxes = other_boxes[index]

    if boxes.shape[0] > 0: 
        _boxes.append(boxes[0])
    return np.stack(_boxes)


def convert_to_square(boxes):
    square_boxes = boxes.copy() 
    if boxes.shape[0] == 0:
        return np.array([])
    w = boxes[:, 2] - boxes[:, 0]
    h = boxes[:, 3] - boxes[:, 1]
    max_side = np.maximum(w, h)
    square_boxes[:, 0] = boxes[:, 0] + w/2 - max_side/2
    square_boxes[:, 1] = boxes[:, 1] + h/2 - max_side/2
    square_boxes[:, 2] = square_boxes[:, 0] + max_side
    square_boxes[:, 3] = square_boxes[:, 1] + max_side
    return square_boxes


if __name__ == "__main__":
    a = np.array([1, 1, 11, 11])
    bs = np.array([[1, 1, 10, 10], [14, 15, 20, 20]])
    print(iou(a, bs))

    bs = np.array([[1, 1, 10, 10, 0.98], [1, 1, 9, 9, 0.8], [9, 8, 13, 20, 0.7], [6, 11, 18, 17, 0.85]])
    print((-bs[:, 4]).argsort())
    print(nms(bs))
    

制作样本代码:

import os
import traceback
from PIL import Image
import numpy as np
from tools.utils import iou, nms
import shutil


save_path = r"datasets"
img_path = r"E:\CelebA\Img\img_celeba"
anno_file = r"E:\CelebA\Anno\list_bbox_celeba.txt"
float_num = [0.1, 0.5, 0.5, 0.5, 0.95, 0.96, 0.97, 0.98, 0.99]


def sample(img_size, number):
    print(f"sample image size:{img_size}")
    positive_img_path = os.path.join(save_path, str(img_size), "positive")
    negative_img_path = os.path.join(save_path, str(img_size), "negative")
    part_img_path = os.path.join(save_path, str(img_size), "part")
    for path in [positive_img_path, negative_img_path, part_img_path]:
        if not os.path.exists(path):
            os.makedirs(path)
    positive_anno_filename = os.path.join(save_path, str(img_size), "positive.txt")
    negative_anno_filename = os.path.join(save_path, str(img_size), "negative.txt")
    part_anno_filename = os.path.join(save_path, str(img_size), "part.txt")
    positive_count = 0
    negative_count = 0
    part_count = 0

    try:
        positive_anno_file = open(positive_anno_filename, "w")
        negative_anno_file = open(negative_anno_filename, "w")
        part_anno_file = open(part_anno_filename, "w")
        for i, line in enumerate(open(anno_file)):
            if i < 2:
                continue
            try:
                strs = line.split()
                img_filename = strs[0].strip()
                img_file = os.path.join(img_path, img_filename)
                img = Image.open(img_file)
                img_w, img_h = img.size
                x1 = int(strs[1].strip())
                y1 = int(strs[2].strip())
                w = int(strs[3].strip())
                h = int(strs[4].strip())
                x2 = int(x1 + w)
                y2 = int(y1 + h)
                if max(w, h) < 40 or x1 < 0 or y1 < 0 or w < 0 or h < 0:
                    continue
                boxes = [[x1, y1, x2, y2]]
                cx = x1 + w / 2
                cy = y1 + h / 2
                max_side = max(w, h)
                seed = float_num[np.random.randint(0, len(float_num))]
                count = 0
                for _ in range(5):
                    _max_side = max_side + np.random.randint(int(-max_side * seed), int(max_side * seed))
                    _cx = cx + np.random.randint(int(-cx * seed), int(cx * seed))
                    _cy = cy + np.random.randint(int(-cy * seed), int(cy * seed))
                    _x1 = _cx - _max_side / 2
                    _y1 = _cy - _max_side / 2
                    _x2 = _x1 + _max_side
                    _y2 = _y1 + _max_side
                    if _x1 < 0 or _y1 < 0 or _x2 > img_w or _y2 > img_h:
                        continue
                    offset_x1 = (x1 - _x1) / _max_side
                    offset_y1 = (y1 - _y1) / _max_side
                    offset_x2 = (x2 - _x2) / _max_side
                    offset_y2 = (y2 - _y2) / _max_side
                    crop_box = [_x1, _y1, _x2, _y2]
                    img_crop = img.crop(crop_box)
                    img_resize = img_crop.resize((img_size, img_size))
                    value = iou(crop_box, np.array(boxes), isMin=False)[0]
                    if value > 0.65:
                        positive_anno_file.write(f"positive/{positive_count}.jpg {1} {offset_x1} {offset_y1} {offset_x2} {offset_y2}\n")
                        positive_anno_file.flush()
                        img_resize.save(os.path.join(positive_img_path, f"{positive_count}.jpg"))
                        positive_count += 1
                    elif 0.65 > value > 0.4:
                        part_anno_file.write(f"part/{part_count}.jpg {2} {offset_x1} {offset_y1} {offset_x2} {offset_y2}\n")
                        part_anno_file.flush()
                        img_resize.save(os.path.join(part_img_path, f"{part_count}.jpg"))
                        part_count += 1
                    elif value < 0.2:
                        negative_anno_file.write(f"negative/{negative_count}.jpg {0} {0} {0} {0} {0}\n")
                        negative_anno_file.flush()
                        img_resize.save(os.path.join(negative_img_path, f"{negative_count}.jpg"))
                        negative_count += 1
                    count = positive_count + negative_count + part_count
                    print(f"img_size:{img_size}, count:{count}, positive_count:{positive_count}, negative_count:{negative_count}, part_count:{part_count}")
                if count >= number:
                    break
            except:
                traceback.print_exc()
    except:
        traceback.print_exc()


if __name__ == "__main__":
    path = "datasets"
    if os.path.exists(path):
        shutil.rmtree(path)
    os.mkdir(path)
    sample(12, 20000)
    sample(24, 20000)
    sample(48, 20000)

数据集代码:

from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms
import torch


class DataSet(Dataset):
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    def __init__(self, path):
        self.path = path
        self.dataSet = []
        self.dataSet.extend(open(os.path.join(path, "positive.txt")).readlines())
        self.dataSet.extend(open(os.path.join(path, "negative.txt")).readlines())
        self.dataSet.extend(open(os.path.join(path, "part.txt")).readlines())

    def __len__(self):
        return len(self.dataSet)

    def __getitem__(self, index):
        datas = self.dataSet[index]
        strs = datas.strip().split()
        img_filename = strs[0]
        cls = [float(strs[1])]
        offsets = list(map(float, strs[2:]))
        img_file = Image.open(os.path.join(self.path, img_filename))
        img_data = transforms.ToTensor()(img_file)
        img_data = transforms.Normalize(DataSet.mean, DataSet.std)(img_data)
        cls = torch.tensor(cls, dtype=torch.float32)
        offsets = torch.tensor(offsets, dtype=torch.float32)
        return img_data, cls, offsets


if __name__ == "__main__":
    dataSet = DataSet(r"datasets/48")
    dataLoader = DataLoader(dataSet, batch_size=10, shuffle=True)
    for i, (x, cls, offset) in enumerate(dataLoader):
        print(x.shape)
        print(cls)
        print(offset)
        break

网络代码:

import torch
import torch.nn as nn


class PNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Sequential(
            nn.Conv2d(3, 10, 3, 1),  # (10, 10, 10)
            nn.BatchNorm2d(10),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),  # (10, 5, 5)
            nn.Conv2d(10, 16, 3, 1),  # (16, 3, 3)
            nn.BatchNorm2d(16),
            nn.PReLU(),
            nn.Conv2d(16, 32, 3, 1),  # (32, 1, 1)
            nn.BatchNorm2d(32),
            nn.PReLU()
        )
        self.layer4_1 = nn.Conv2d(32, 1, 1, 1)
        self.layer4_2 = nn.Conv2d(32, 4, 1, 1)

    def forward(self, x):
        x = self.layer3(x)
        cls = torch.sigmoid(self.layer4_1(x))
        offset = self.layer4_2(x)
        return cls, offset


class RNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Sequential(
            nn.Conv2d(3, 28, 3, 1),  # (28, 22, 22)
            nn.BatchNorm2d(28),
            nn.PReLU(),
            nn.MaxPool2d(3, 2, 1),  # (28, 11, 11)
            nn.Conv2d(28, 48, 3, 1),  # (48, 9, 9)
            nn.BatchNorm2d(48),
            nn.PReLU(),
            nn.MaxPool2d(3, 2),  # (48, 4, 4)
            nn.Conv2d(48, 64, 2, 1),  # (64, 3, 3)
            nn.BatchNorm2d(64),
            nn.PReLU()
        )
        self.layer4 = nn.Linear(64*3*3, 128)
        self.pReLU = nn.PReLU()
        self.layer5_1 = nn.Linear(128, 1)
        self.layer5_2 = nn.Linear(128, 4)

    def forward(self, x):
        x = self.layer3(x)
        # x = torch.reshape(x, (-1, 64*3*3))
        x = x.view(x.size(0), -1)
        x = self.pReLU(self.layer4(x))
        cls = torch.sigmoid(self.layer5_1(x))
        offset = self.layer5_2(x)
        return cls, offset


class ONet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer4 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1),  # (32, 46, 46)
            nn.BatchNorm2d(32),
            nn.PReLU(),
            nn.MaxPool2d(3, 2, 1),  # (32, 23, 23)
            nn.Conv2d(32, 64, 3, 1),  # (64, 21, 21)
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.MaxPool2d(3, 2),  # (64, 10, 10)
            nn.Conv2d(64, 64, 3, 1),  # (64, 8, 8)
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),  # (64, 4, 4)
            nn.Conv2d(64, 128, 2, 1),  # (128, 3, 3)
            nn.BatchNorm2d(128),
            nn.PReLU(),
        )
        self.layer5 = nn.Linear(128*3*3, 256)
        self.pReLU = nn.PReLU()
        self.layer6_1 = nn.Linear(256, 1)
        self.layer6_2 = nn.Linear(256, 4)

    def forward(self, x):
        x = self.layer4(x)
        # x = torch.reshape(x, (-1, 128*3*3))
        # x = x.reshape((-1, 128*3*3))
        x = x.view(x.size(0), -1)
        x = self.pReLU(self.layer5(x))
        cls = torch.sigmoid(self.layer6_1(x))
        offset = self.layer6_2(x)
        return cls, offset


if __name__ == "__main__":
    data = torch.randn(2, 3, 48, 48)
    net = ONet()
    cls, offset = net(data)
    print(cls.shape)
    print(offset.shape)

训练代码:

import torch
import torch.nn as nn
from dataset import DataSet, DataLoader
import os
from my_net import PNet, RNet, ONet


class Trainer:
    def __init__(self, data_path, net, save_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = net().to(self.device)
        self.save_path = save_path
        self.cls_loss_fn = nn.BCELoss()
        self.offset_loss_fn = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.net.parameters())
        self.dataSet = DataSet(data_path)
        self.dataLoader = DataLoader(self.dataSet, batch_size=512, shuffle=True, num_workers=4)

    def train(self, stop_value):
        if os.path.exists(self.save_path):
            self.net.load_state_dict(torch.load(self.save_path))
        else:
            print(f"NO {self.save_path}")

        loss = 0
        while True:
            for i, (x, cls, offset) in enumerate(self.dataLoader):
                x, cls, offset = x.to(self.device), cls.to(self.device), offset.to(self.device)
                _cls_out, _offset_out = self.net(x)
                cls_out, offset_out = _cls_out.view(-1, 1), _offset_out.view(-1, 4)

                cls_mask = torch.lt(cls, 2)
                cls_out = torch.masked_select(cls_out, cls_mask)
                cls_target = torch.masked_select(cls, cls_mask)
                cls_loss = self.cls_loss_fn(cls_out, cls_target)

                offset_mask = torch.gt(cls, 0)
                offset_out = torch.masked_select(offset_out, offset_mask)
                offset_target = torch.masked_select(offset, offset_mask)
                offset_loss = self.offset_loss_fn(offset_out, offset_target)
                loss = cls_loss + offset_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                print(f"loss:{loss.float()}, cls_loss:{cls_loss.float()}, offset_loss:{offset_loss.float()}")
            torch.save(self.net.state_dict(), self.save_path)
            print(f"{self.save_path}")

            if loss.float() < stop_value:
                break


def train_net(data_path, net, save_path, stop_value):
    path = "models"
    if not os.path.exists(path):
        os.mkdir(path)
    save_path = os.path.join(path, save_path)
    train = Trainer(data_path, net, save_path)
    train.train(stop_value)


if __name__ == "__main__":
    train_net("datasets/12", PNet, "p_net.pth", 0.01)
    train_net("datasets/24", RNet, "r_net.pth", 0.001)
    train_net("datasets/48", ONet, "o_net.pth", 0.0005)

侦测代码:

import torch
from torchvision import transforms
from my_net import PNet, RNet, ONet
from dataset import DataSet
import time
import numpy as np
from tools.utils import nms, convert_to_square
from PIL import Image, ImageDraw, ImageFont


class Detector:
    def __init__(self, isCuda=True):
        self.isCuda = isCuda
        self.p_net = PNet()
        self.r_net = RNet()
        self.o_net = ONet()
        if self.isCuda:
            self.p_net.cuda()
            self.r_net.cuda()
            self.o_net.cuda()
        self.p_net.load_state_dict(torch.load("models/p_net.pth"))
        self.r_net.load_state_dict(torch.load("models/r_net.pth"))
        self.o_net.load_state_dict(torch.load("models/o_net.pth"))
        self.p_net.eval()
        self.r_net.eval()
        self.o_net.eval()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(DataSet.mean, DataSet.std)
        ])

    def detect(self, image):
        p_start_time = time.time()
        p_boxes = self.__p_detect(image)
        if p_boxes.shape[0] == 0:
            return np.array([]), np.array([]), np.array([])
        p_end_time = time.time()
        p_time = p_end_time - p_start_time

        r_start_time = time.time()
        r_boxes = self.__r_detect(image, p_boxes)
        if r_boxes.shape[0] == 0:
            return p_boxes, np.array([]), np.array([])
        r_end_time = time.time()
        r_time = r_end_time - r_start_time

        o_start_time = time.time()
        o_boxes = self.__o_detect(image, r_boxes)
        if o_boxes.shape[0] == 0:
            return p_boxes, r_boxes, np.array([])
        o_end_time = time.time()
        o_time = o_end_time - o_start_time
        count_time = p_time + r_time + o_time
        print(f"count_time:{count_time}, p_time:{p_time}, r_time:{r_time}, o_time:{o_time}")
        return p_boxes, r_boxes, o_boxes

    def __p_detect(self, image, stride=2, side=12):
        scale = 1
        boxes = []
        w, h = image.size
        min_side = min(w, h)
        while min_side >= 12:
            img_data = self.transform(image)
            img_data = img_data.unsqueeze(0)
            if self.isCuda:
                img_data = img_data.cuda()
            cls, offset = self.p_net(img_data)  # offset.shape=[1, 4, 528, 795]
            _cls, _offset = cls[0][0].cpu().data, offset[0].cpu().data  # _offset.shape=[4, 528, 795]
            mask = torch.gt(_cls, 0.5)
            index = torch.nonzero(mask)  # (321,2)

            x1 = (index[:, 1] * stride) / scale  # (321)
            y1 = (index[:, 0] * stride) / scale
            x2 = x1 + side / scale
            y2 = y1 + side / scale
            _w = x2 - x1   # (321)
            _h = y2 - y1
            _x1 = _offset[0, index[:, 0], index[:, 1]] * _w + x1
            _y1 = _offset[1, index[:, 0], index[:, 1]] * _h + y1
            _x2 = _offset[2, index[:, 0], index[:, 1]] * _w + x2
            _y2 = _offset[3, index[:, 0], index[:, 1]] * _h + y2
            _cls = _cls[index[:, 0], index[:, 1]]
            box = torch.stack([_x1, _y1, _x2, _y2, _cls], dim=1)
            boxes.extend(box.numpy())
            scale *= 0.709
            img_w, img_h = int(scale * w), int(scale * h)
            image = image.resize((img_w, img_h))
            min_side = min(img_w, img_h)
        return nms(np.array(boxes), threshold=0.3, isMin=False)

    def __r_detect(self, image, p_boxes):
        p_boxes = convert_to_square(p_boxes)
        x1 = p_boxes[:, 0]  # (668.)
        y1 = p_boxes[:, 1]
        x2 = p_boxes[:, 2]
        y2 = p_boxes[:, 3]
        box = np.stack((x1, y1, x2, y2), axis=1)
        img_dataset = [self.transform(image.crop(x).resize((24, 24))) for x in box]
        img_dataset = torch.stack(img_dataset)
        if self.isCuda:
            img_dataset = img_dataset.cuda()
        cls, offset = self.r_net(img_dataset)
        cls, offset = cls.cpu().data.numpy(), offset.cpu().data.numpy()  # (668, 1) (668, 4)
        index, _ = np.where(cls > 0.6)  # (101,)
        box = p_boxes[index]  # (101, 5)
        x1 = box[:, 0]  # (101,)
        y1 = box[:, 1]
        x2 = box[:, 2]
        y2 = box[:, 3]
        w = x2 - x1  # (101,)
        h = y2 - y1
        _x1 = offset[index, 0] * w + x1
        _y1 = offset[index, 1] * h + y1
        _x2 = offset[index, 2] * w + x2
        _y2 = offset[index, 3] * h + y2
        _cls = cls[index, 0]
        boxes = np.stack((_x1, _y1, _x2, _y2, _cls), axis=1)
        return nms(boxes, threshold=0.3, isMin=False)

    def __o_detect(self, image, r_boxes):
        r_boxes = convert_to_square(r_boxes)
        x1 = r_boxes[:, 0]
        y1 = r_boxes[:, 1]
        x2 = r_boxes[:, 2]
        y2 = r_boxes[:, 3]
        r_box = np.stack((x1, y1, x2, y2), axis=1)
        img_dataset = [self.transform(image.crop(x).resize((48, 48))) for x in r_box]
        img_dataset = torch.stack(img_dataset)
        if self.isCuda:
            img_dataset = img_dataset.cuda()
        cls, offset = self.o_net(img_dataset)
        cls, offset = cls.cpu().data.numpy(), offset.cpu().data.numpy()
        index, _ = np.where(cls > 0.97)  # (44,)
        x1 = r_boxes[index, 0]  # (44,)
        y1 = r_boxes[index, 1]
        x2 = r_boxes[index, 2]
        y2 = r_boxes[index, 3]
        w = x2 - x1  # (44,)
        h = y2 - y1
        _x1 = offset[index, 0] * w + x1
        _y1 = offset[index, 1] * h + y1
        _x2 = offset[index, 2] * w + x2
        _y2 = offset[index, 3] * h + y2
        _cls = cls[index, 0]
        boxes = np.stack((_x1, _y1, _x2, _y2, _cls), axis=1)
        return nms(boxes, threshold=0.3, isMin=True)


if __name__ == "__main__":
    x = time.time()  # 侦测开始计时
    font = ImageFont.truetype(r"C:\Windows\Fonts\simhei", size=20)
    with torch.no_grad() as grad:
        image_file = r"imgs/img1.jpg"  # 图片路径
        detector = Detector()  # 实例化
        with Image.open(image_file) as img:
            p_img = img.copy()
            r_img = img.copy()
            o_img = img.copy()
            P_boxes, R_boxes, O_boxes = detector.detect(img)  # 将图片传入detect进行侦测,得到真实框的所有值
            print(P_boxes.shape, R_boxes.shape, O_boxes.shape)
            # imDraw = ImageDraw.Draw(p_img)  # 画P网络输出图
            # for box in P_boxes:  # 遍历所有P网络输出的真实框 box[4] 为置信度
            #     x1 = int(box[0])
            #     y1 = int(box[1])
            #     x2 = int(box[2])
            #     y2 = int(box[3])
            #     cls = box[4]
            #     imDraw.rectangle((x1, y1, x2, y2), outline='red')  # 画出侦测后的所有真实框
            # imDraw = ImageDraw.Draw(r_img)  # 画R网络输出图
            # for box in R_boxes:  # 遍历所有R网络输出的真实框 box[4] 为置信度
            #     x1 = int(box[0])
            #     y1 = int(box[1])
            #     x2 = int(box[2])
            #     y2 = int(box[3])
            #     cls = box[4]
            #     imDraw.rectangle((x1, y1, x2, y2), outline='red')  # 画出侦测后的所有真实框
            #     imDraw.text((x1, y1), "{:.3f}".format(cls), fill="red", font=font)
            imDraw = ImageDraw.Draw(o_img)  # 画O网络输出图
            for box in O_boxes:  # 遍历所有O网络输出的真实框 box[4] 为置信度
                x1 = int(box[0])
                y1 = int(box[1])
                x2 = int(box[2])
                y2 = int(box[3])
                cls = box[4]
                imDraw.rectangle((x1, y1, x2, y2), outline='red')  # 画出侦测后的所有真实框
                imDraw.text((x1, y1), "{:.3f}".format(cls), fill="red", font=font)
            y = time.time()  # 计算侦测总用时
            print(y - x)
            # p_img.show()
            # r_img.show()
            o_img.show()

 

九、效果

  • 5
    点赞
  • 76
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值