使用faster rcnn训练umdfaces数据集

实验目的:

基于caffe框架使用faster rcnn方法进行人脸检测;实验所使用的数据集为umdfaces ,总共有三个文件,一共8000+个类别,总共36W张人脸图片,全都是经过标注的样本,标注信息保存在csv文件中,除了人脸的box,还有人脸特征点的方位信息,强力推荐!


实验平台及准备:

训练服务器: NVIDIA Tesla K80

预测终端    : NVIDIA TX1

框架            : caffe

方法            : faster rcnn

训练样本    : umdfaces人脸数据库


实验步骤:

在服务器上和预测终端上分别部署caffe环境,一定要使用faster rcnn作者GitHub上的那一个版本,地址:https://github.com/rbgirshick/py-faster-rcnn

对于环境的部署还有测试,网上有很多教程,我就不重复叙述了,这里只介绍记录一下对于本次实验的一些步骤。不过大家有什么问题可以联系我,有时间一定会帮忙解答。


一、数据库准备

到umdfaces官网去下载,链接地址在上面给出来了,文件比较大,三个压缩包,每一个都是10几个G,其中有一个居然到了20G,而且还是在谷歌云盘,国内的话我估计大部分人都不要想了,大家还是寻求在国外的小伙伴帮助,下载以后通过QQ或者其他方式发给你,我就是这么干的。。。

联系方式:

Q#Q:43597717#0(去掉#)


二、样本数据的处理

使用过faster rcnn的小伙伴都知道,作者打开了caffe的python layer,因此在caffe搭建网络训练样本的时候肯定少不了与python代码的交互的,这里不同于官网版本的caffe框架,对于其中的有些层,layer,作者是用了python定义的,比如说,数据层,跟官方提供方法先制作lmdb文件不一样,这里使用了python定义了文件的读取还有roi的标志。


所以我们的umdfaces还需要制作成VOC2007格式(我习惯这么叫了),在faster rcnn的示例教程中可以看到,每一个训练集,除了时间,命名,还要将其注册到工厂类中,更重要的是对于数据集最好有一个自己的类,以该数据集命名,比如我的就是face。VOC2007数据集格式为下面图片所示:



废话不多说,把数据集格式了解清楚以后就要开始去准备和处理了,umdfaces提供了标注信息,但是保存在csv文件中,大家要做的就是通过脚本或者程序将umdfaces数据集改变成faster rcnn默认支持的格式。这里为了方便大家复现还有,方便我继续写这一篇博文,我把自己的处理脚本上传到了GitHub上,有需要可以去下载,地址:luuuyi/umdfaces2VOC2007  。介绍一下,也就是先自己把上图中的第一个和第三个文件夹创建好以后,将数据集路径修改对就好了,需要对Python有一定的了解,代码的一个片段如下:

#!/usr/bin/env python

from tool_csv import loadCSVFile
from tool_lxml import createXML
import cv2
import os

FILEDIR = "/media/scs4450/hard/umdfaces_batch1/"         #
IMGSTORE = "/media/scs4450/hard/JPEGImages/"
FILENAME = "umdfaces_batch1_ultraface.csv"               #
ANNOTATIONDIR = "/media/scs4450/hard/Annotations/"

if __name__ == "__main__":
    csv_content = loadCSVFile(FILEDIR+FILENAME)
    cvs_content_part = csv_content[1:,1:10]
    i=1
    base=3000000                                         #
    limit = 1000000                                       #
    
    for info in cvs_content_part:
        if i==limit:
            print "Reach Limit, Stop..."
            break

        print "Process No." + str(i) + " Data...."

        str_splite = '/'
        str_spilte_list = str(info[0]).split(str_splite)

        jpg_path = info[0]
        #jpg_file = str_spilte_list[len(str_spilte_list)-1]
        jpg_file = str(base+i)+'.jpg'
        os.system('cp '+ FILEDIR+jpg_path + ' ' + IMGSTORE+jpg_file)

        img = cv2.imread(FILEDIR+jpg_path)
        sp = img.shape
        #print sp
        height = sp[0]                 #height(rows) of image
        width = sp[1]                  #width(colums) of image
        depth = sp[2]                  #the pixels value is made up of three primary colors
        #print 'width: %d \nheight: %d \nnumber: %d' %(width,height,depth)

        xmin = int(float(info[3]))
        ymin = int(float(info[4]))
        xmax = int(float(info[3])+float(info[5]))
        ymax = int(float(info[4])+float(info[6]))
        #print 'xmin: %d \nymin: %d \nxmax: %d \nymax: %d' %(xmin,ymin,xmax,ymax)

        transf = dict()
        transf['folder'] = "FACE2016"
        transf['filename'] = jpg_file
        transf['width'] = str(width)
        transf['height'] = str(height)
        transf['depth'] = str(depth)
        transf['xmin'] = str(xmin)
        transf['ymin'] = str(ymin)
        transf['xmax'] = str(xmax)
        transf['ymax'] = str(ymax)

        print "Create No." + str(i) + " XML...."
        createXML(transf,ANNOTATIONDIR)
        i = i + 1
        #print jpg_path, jpg_file
        #jpg
    
    print "Done..."

    

对于第二个文件夹,也就是ImageSets文件夹,其中的内容生成我借鉴了这篇博客的方法: 将数据集做成VOC2007格式用于Faster-RCNN训练   主要是为了快速开发,对于这个文件夹内容生成的脚本我就偷懒了,不过我后续会更新到自己的github主页上的(求给个星鼓励下!!),在原博主的方法中,可能需要下载一下MATLAB,因为他的脚本是用matlab来写的。


三、faster rcnn训练代码修改

先定义一下,faster rcnn在Linux系统中大家添加个环境变量吧,这里方便描述将其定义为 $FASTERRCNN,大家都知道在 $FASTERRCNN/experiments/scripts中有一个训练脚本,这里第一个修改的文件就是他:

在其中的DATASET部分添加一个自己的face,如下图所示:


其余地方不变,好了,入口修改好了,这下去看python代码,首先到 $FASTERRCNN/lib/datasets目录,这里的改变为,添加一个face.py文件,该文件的内容其实就是仿造pascal_voc.py仿写的一个类,后续我把全部的代码都贴出来吧,省得一个一个的去截图,这里先说一下factory.py这个文件,顾名思义这是个工厂类,需要在其中注册之后自己写的face类,改动如下:



这里是对face类的一个注册,之后就是对face.py文件的一个创建修改了,这里索性代码全贴上来吧:

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

import os
from datasets.imdb import imdb
import datasets.ds_utils as ds_utils
import xml.etree.ElementTree as ET
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import cPickle
import subprocess
import uuid
from voc_eval import voc_eval
from fast_rcnn.config import cfg

class face(imdb):    #luyi
    def __init__(self, image_set, year, devkit_path=None):
        imdb.__init__(self, 'face_' + year + '_' + image_set)   #luyi
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path() if devkit_path is None \
                            else devkit_path
        self._data_path = os.path.join(self._devkit_path, 'FACE' + self._year)  #luyi
        self._classes = ('__background__', # always index 0
                         'face')           #luyi
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup'     : True,
                       'use_salt'    : True,
                       'use_diff'    : False,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 16}     #luyi

        assert os.path.exists(self._devkit_path), \
                'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
                'Path does not exist: {}'.format(self._data_path)

    def image_path_at(self, i):
        """
        Return the absolute path to image i in the image sequence.
        """
        return self.image_path_from_index(self._image_index[i])

    def image_path_from_index(self, index):
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, 'JPEGImages',
                                  index + self._image_ext)
        assert os.path.exists(image_path), \
                'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """
        Load the indexes listed in this dataset's image set file.
        """
        # Example path to image set file:
        # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
        image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
                'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index

    def _get_default_path(self):
        """
        Return the default path where PASCAL VOC is expected to be installed.
        """
        return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + '2007')

    def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print '{} gt roidb loaded from {}'.format(self.name, cache_file)
            return roidb

        gt_roidb = [self._load_pascal_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote gt roidb to {}'.format(cache_file)

        return gt_roidb

    def selective_search_roidb(self):
        """
        Return the database of selective search regions of interest.
        Ground-truth ROIs are also included.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_selective_search_roidb.pkl')

        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print '{} ss roidb loaded from {}'.format(self.name, cache_file)
            return roidb

        if int(self._year) == 2007 or self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            ss_roidb = self._load_selective_search_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
        else:
            roidb = self._load_selective_search_roidb(None)
        with open(cache_file, 'wb') as fid:
            cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote ss roidb to {}'.format(cache_file)

        return roidb

    def rpn_roidb(self):
        if int(self._year) == 2007 or self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            rpn_roidb = self._load_rpn_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
        else:
            roidb = self._load_rpn_roidb(None)

        return roidb

    def _load_rpn_roidb(self, gt_roidb):
        filename = self.config['rpn_file']
        print 'loading {}'.format(filename)
        assert os.path.exists(filename), \
               'rpn data not found at: {}'.format(filename)
        with open(filename, 'rb') as f:
            box_list = cPickle.load(f)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_selective_search_roidb(self, gt_roidb):
        filename = os.path.abspath(os.path.join(cfg.DATA_DIR,
                                                'selective_search_data',
                                                self.name + '.mat'))
        assert os.path.exists(filename), \
               'Selective search data not found at: {}'.format(filename)
        raw_data = sio.loadmat(filename)['boxes'].ravel()

        box_list = []
        for i in xrange(raw_data.shape[0]):
            boxes = raw_data[i][:, (1, 0, 3, 2)] - 1
            keep = ds_utils.unique_boxes(boxes)
            boxes = boxes[keep, :]
            keep = ds_utils.filter_small_boxes(boxes, self.config['min_size'])
            boxes = boxes[keep, :]
            box_list.append(boxes)

        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            non_diff_objs = [
                obj for obj in objs if int(obj.find('difficult').text) == 0]
            # if len(non_diff_objs) != len(objs):
            #     print 'Removed {} difficult objects'.format(
            #         len(objs) - len(non_diff_objs))
            objs = non_diff_objs
        num_objs = len(objs)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text)        #luyi
            y1 = float(bbox.find('ymin').text)        #luyi
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes' : boxes,
                'gt_classes': gt_classes,
                'gt_overlaps' : overlaps,
                'flipped' : False,
                'seg_areas' : seg_areas}

    def _get_comp_id(self):
        comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
            else self._comp_id)
        return comp_id

    def _get_voc_results_file_template(self):
        # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
        path = os.path.join(
            self._devkit_path,
            'results',
            'VOC' + '2007',    #luyi
            'Main',
            filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print 'Writing {} VOC results file'.format(cls)
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    if dets == []:
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in xrange(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))

    def _do_python_eval(self, output_dir = 'output'):
        annopath = os.path.join(
            self._devkit_path,
            'FACE' + self._year,       #LUYI
            'Annotations',
            '{:s}.xml')
        imagesetfile = os.path.join(
            self._devkit_path,
            'FACE' + self._year,      #LUYI
            'ImageSets',
            'Main',
            self._image_set + '.txt')
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC metric changed in 2010
        use_07_metric = True if int('2007') < 2010 else False      #luyi
        print 'VOC07 metric? ' + ('Yes' if use_07_metric else 'No')
        if not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
                use_07_metric=use_07_metric)
            aps += [ap]
            print('AP for {} = {:.4f}'.format(cls, ap))
            with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f:
                cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        print('Mean AP = {:.4f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print('{:.3f}'.format(ap))
        print('{:.3f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('')
        print('--------------------------------------------------------------')
        print('Results computed with the **unofficial** Python eval code.')
        print('Results should be very close to the official MATLAB eval code.')
        print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
        print('-- Thanks, The Management')
        print('--------------------------------------------------------------')

    def _do_matlab_eval(self, output_dir='output'):
        print '-----------------------------------------------------'
        print 'Computing results with the official MATLAB eval code.'
        print '-----------------------------------------------------'
        path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
                            'VOCdevkit-matlab-wrapper')
        cmd = 'cd {} && '.format(path)
        cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
        cmd += '-r "dbstop if error; '
        cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \
               .format(self._devkit_path, self._get_comp_id(),
                       self._image_set, output_dir)
        print('Running:\n{}'.format(cmd))
        status = subprocess.call(cmd, shell=True)

    def evaluate_detections(self, all_boxes, output_dir):
        self._write_voc_results_file(all_boxes)
        self._do_python_eval(output_dir)
        if self.config['matlab_eval']:
            self._do_matlab_eval(output_dir)
        if self.config['cleanup']:
            for cls in self._classes:
                if cls == '__background__':
                    continue
                filename = self._get_voc_results_file_template().format(cls)
                os.remove(filename)

    def competition_mode(self, on):
        if on:
            self.config['use_salt'] = False
            self.config['cleanup'] = False
        else:
            self.config['use_salt'] = True
            self.config['cleanup'] = True

if __name__ == '__main__':
    from datasets.face import face           #luyi
    d = face('trainval', '2016')             #luyi
    res = d.roidb
    from IPython import embed; embed()

修改后的完全版代码在这儿,大家各取所需,对于每一个修改的部分我都在后面注释过,有问题大家一起交流学习。

在训练过程中,有个地方有些经验分享一下吧

1)使用faster rcnn训练自己样本的时候,一定记得在每一次训练之前把上一次的缓存文件给删了,在两个目录下有这些缓存文件:

$FASTERRCNN/data/cache
$FASTERRCNN/data/VOCdevkit2007

这两个目录下的缓存文件一定要删掉才能继续训练


2)记得多去看训练之后的日志,对着日志的每一行去阅读代码,我觉得是一个很好的流程认识的方法,日志的路径一般在

$FASTERRCNN/experiment/logs

四、结果

这里展示的结果是对umdfaces第三个batch训练完的结果,大概是9w张图片,迭代次数为3W次,最终的一个test结果如下:

平均预测值为0.888,虽然还没有达到百分之九十,但是这是第一次接触,感觉还算不错,能接受,之后可以继续学习调整。

使用faster rcnn的demo演示,测试同样的几张图片,结果如下:

可以看出来,总共6张图片,只检测到其中两张有人脸,一张正确的,一张误检了猫的脸,后续还有很大的改进空间,比如全部36w张图一起训练,然后迭代次数到7w次,或者调解网络的参数,方法很多。


五、总结

博文最好的地方就是想怎么写就怎么写,没总结,但是最后想说的是,这一次实验带给我的不仅仅是结果上的一个呈现,最重要的是一周前,我还是对于深度学习,对于caffe啥都不懂的一个路人,经过动手去学习尝试之后,慢慢的对于深度学习有了一个大概的认识,这是个痛苦的过程,不过现在回头望,的确也成长了不少。还有很重要的一点就是大家要是想在这一块做研究,python一定要懂一点,就这样,有问题留言沟通。

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
使用Faster RCNN训练自己的数据集的步骤如下: 1. 配置电脑环境:确保电脑配置满足要求,包括安装好所需的支持包和软件。\[2\] 2. 准备数据集:收集并标注自己的数据集,确保每个图像都有对应的标注框和类别信息。 3. 修改配置文件:根据自己的数据集和需求,修改Faster RCNN源码中的配置文件,包括类别数、路径等参数。 4. 数据预处理:将数据集进行预处理,包括图像的缩放、裁剪、归一化等操作,以适应模型的输入要求。 5. 训练模型:使用修改后的配置文件和预处理后的数据集,进行模型的训练。可以使用训练的模型作为初始权重,然后进行迭代训练。 6. 模型评估:训练完成后,使用测试集对模型进行评估,计算模型的准确率、召回率等指标,以评估模型的性能。 7. 模型保存:将训练得到的最终模型保存下来,可以将其拷贝到指定的目录中,以备后续使用。\[3\] 需要注意的是,训练自己的数据集需要一定的时间和计算资源,并且需要对Faster RCNN的源码和配置文件进行一定的了解和修改。同时,还需要对数据集进行充分的标注和预处理,以获得更好的训练效果。 #### 引用[.reference_title] - *1* *2* [【目标检测】用自己的数据训练Faster RCNN的详细全过程(步骤很详细很直观,小白可入)](https://blog.csdn.net/qq_38391210/article/details/104607895)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [faster rcnn 训练自己的数据](https://blog.csdn.net/hanpengpeng329826/article/details/64905021)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值