Retinaface代码记录(二)(数据处理)

一、写在开头

这次主要记录关于Retinaface的数据处理部分。

下面是代码地址:
Retinaface代码地址

主要包括的脚本为:
wider_face.py
data_augment.py

也欢迎阅读其上一篇博客Retinaface代码记录(一)。可以帮助读者对本片博客可以有一个整体上的把握和理解。

二、主要内容

wider_face.py主要是一个WiderFaceDetection类:其继承了torch.utils.data.Dataset类,主要包含三个方法:初始化__init__,获取图像__getitem__,数据集数量 __len__
首先,初始化__init__。获得一个接下来要处理的文件txt的地址,这个txt包含了widerface数据图像的地址,以及每个图像中人脸的box的坐标值和关键点的坐标值。可以按照上述地址中的Readme下载,其主要是下面这个样子。然后是一个预处理类,preproc,主要用来对输入的图像做各种数据增强处理等,这个类位于data_augment.py脚本中。接下来,定义了一个用来存储图像路径的list和一个用来存放关于人脸box和关键点信息的list。下面便是对label.txt的操作,首先读取这个txt,获取每行信息,将其存储在lines中,对其进行遍历,对于遇到“#”开头的,便是图片的地址,将其放入img_path中,然后处理此图片的信息,如下图,每行信息代表如下:首先是box的x,y然后是w,h,接着是5个关键点信息,分别用0.0隔开。下面的操作就是将这些信息放入words中,而isFirst便是处理完一张图片的标志。
__getitem__,主要就是读取图片,并将图片信息一一放入Target中,最后返回这两个东西。
__len__呢,顾名思义,就是返回数据集的数量。

label.txt:

wider_face.py

import os
import os.path
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np

class WiderFaceDetection(data.Dataset):

    def __init__(self, txt_path, preproc=None):
        self.preproc = preproc
        self.imgs_path = []
        self.words = []

        f = open(txt_path,'r')
        lines = f.readlines()
        isFirst = True
        labels = []
        for line in lines:
            line = line.rstrip()
            if line.startswith('#'):
                if isFirst is True:
                    isFirst = False
                else:
                    labels_copy = labels.copy()
                    self.words.append(labels_copy)
                    labels.clear()
                path = line[2:]
                path = txt_path.replace('label.txt','images/') + path
                self.imgs_path.append(path)
            else:
                line = line.split(' ')
                label = [float(x) for x in line]
                labels.append(label)

        self.words.append(labels)

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index):
        img = cv2.imread(self.imgs_path[index])
        height, width, _ = img.shape

        labels = self.words[index]
        
        annotations = np.zeros((0, 15))
        if len(labels) == 0:
            return annotations
        for idx, label in enumerate(labels):
            annotation = np.zeros((1, 15))
            # bbox
            annotation[0, 0] = label[0]  # x1
            annotation[0, 1] = label[1]  # y1
            annotation[0, 2] = label[0] + label[2]  # x2
            annotation[0, 3] = label[1] + label[3]  # y2

            # landmarks
            annotation[0, 4] = label[4]    # l0_x
            annotation[0, 5] = label[5]    # l0_y
            annotation[0, 6] = label[7]    # l1_x
            annotation[0, 7] = label[8]    # l1_y
            annotation[0, 8] = label[10]   # l2_x
            annotation[0, 9] = label[11]   # l2_y
            annotation[0, 10] = label[13]  # l3_x
            annotation[0, 11] = label[14]  # l3_y
            annotation[0, 12] = label[16]  # l4_x
            annotation[0, 13] = label[17]  # l4_y
            if (annotation[0, 4]<0):
                annotation[0, 14] = -1
            else:
                annotation[0, 14] = 1

            annotations = np.append(annotations, annotation, axis=0)
        target = np.array(annotations)
        if self.preproc is not None:
            img, target = self.preproc(img, target)

        return torch.from_numpy(img), target

def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on 0 dim
    """
    targets = []
    imgs = []
    for _, sample in enumerate(batch):
        for _, tup in enumerate(sample):
            if torch.is_tensor(tup):
                imgs.append(tup)
            elif isinstance(tup, type(np.empty(0))):
                annos = torch.from_numpy(tup).float()
                targets.append(annos)

    return (torch.stack(imgs, 0), targets)

上述便是pytorch的数据处理的方法,只需要继承一个基类:torch.utils.data.Dataset。然后再改写其中的__init____len____getitem__等方法,你就可以实现一个自己定义的数据接口。
使用的话,如前面所写的位于train.py中相关的,用torch.utils.data.DataLoader类来做进一步封装即可。

images, targets=torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, num_workers=num_workers, collate_fn=detection_collate))

然后是data_augment.py。这个脚本主要包括一些数据增强的方法,裁剪,镜像,填充等等,和一个统一起来的数据处理的类,即wider_face.py中提到的preproc,在这个类里,将上面所有的数据增强的方法都用了一遍最后返回处理的图像和Target。

Tips:
上述不同于Faceboxes的方法有_crop_mirror,因为包含了对关键点信息的处理,更改的时候需要注意。

data_augment.py

import cv2
import numpy as np
import random
from utils.box_utils import matrix_iof


def _crop(image, boxes, labels, landm, img_dim):...
 
def _distort(image):...
   
def _expand(image, boxes, fill, p):...
 
def _mirror(image, boxes, landms):...
   
def _pad_to_square(image, rgb_mean, pad_image_flag):...
    
def _resize_subtract_mean(image, insize, rgb_mean):...
   
class preproc(object):

    def __init__(self, img_dim, rgb_means):
        self.img_dim = img_dim
        self.rgb_means = rgb_means

    def __call__(self, image, targets):
        assert targets.shape[0] > 0, "this image does not have gt"

        boxes = targets[:, :4].copy()
        labels = targets[:, -1].copy()
        landm = targets[:, 4:-1].copy()

        image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim) #
        image_t = _distort(image_t)
        image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag)
        
        image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)            #
        height, width, _ = image_t.shape
        image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
        boxes_t[:, 0::2] /= width
        boxes_t[:, 1::2] /= height

        landm_t[:, 0::2] /= width                                                #
        landm_t[:, 1::2] /= height

        labels_t = np.expand_dims(labels_t, 1)
        targets_t = np.hstack((boxes_t, landm_t, labels_t))                     #

        return image_t, targets_t

三、结尾

数据处理的介绍到这里就结束了。有什么疑问的,可以看下我的这篇博客,里面是一个综述。

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值