Modifying the Faster R-CNN data input to read WIDER FACE

I wanted to train Faster R-CNN on the WIDER FACE face-detection dataset, so I simply cloned the py-faster-rcnn repository; the next step is to modify the data-loading functions.

Faster R-CNN exposes several entry points, e.g. /tools/demo.py or /tools/train_net.py. For how to modify them, see the next post: "faster-rcnn: how training data is prepared — generating imdb and roidb".

In practice it is easier to trace how the author calls things from these two entry-point files: /tools/train_faster_rcnn_alt_opt.py first calls train_rpn(), which calls get_roidb() in the same file, which in turn calls get_imdb() (under /lib/datasets). That covers most of what you need to know about data loading.

So what needs editing to read WIDER FACE?

If, like me, you don't want to heavily modify the original files, there are three places to change: first, the data loading and format conversion; second, registering the new loader with the calling function; third, the information in the launch script. (Tip: to find a file or keyword, use `grep -n -H -R "key-words"`.)

1. Data loading and format conversion

The WIDER FACE XML annotations follow the same format as Pascal VOC, so they can be treated as a special case of VOC. That makes things much easier: copy the VOC reader file over and make a few modifications.
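To make "VOC-style XML" concrete, here is a minimal, self-contained sketch of parsing one such annotation with ElementTree, mirroring what `_load_face_annotation` does below. The annotation content (file name and box coordinates) is invented purely for illustration:

```python
import xml.etree.ElementTree as ET

# A made-up VOC-style annotation for one WIDER FACE image.
xml_text = """
<annotation>
  <filename>0_Parade_marchingband_1_5.jpg</filename>
  <object>
    <name>face</name>
    <difficult>0</difficult>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>60</xmax><ymax>90</ymax></bndbox>
  </object>
</annotation>
"""

root = ET.fromstring(xml_text)
boxes = []
for obj in root.findall('object'):
    bb = obj.find('bndbox')
    # make pixel indexes 0-based, as the reader below does
    boxes.append([float(bb.find(t).text) - 1
                  for t in ('xmin', 'ymin', 'xmax', 'ymax')])

print(boxes)  # [[9.0, 19.0, 59.0, 89.0]]
```

Each `<object>` with a `face` name and a `bndbox` maps directly onto one row of the `boxes` array built by the reader.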

#coding:utf-8
# --------------------------------------------------------
# Date: 2018/1/5
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

import os
from datasets.imdb import imdb
import datasets.ds_utils as ds_utils
import xml.etree.ElementTree as ET
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import cPickle
import subprocess
import uuid
from voc_eval import voc_eval
from fast_rcnn.config import cfg

class wf(imdb):
    def __init__(self, image_set, devkit_path=None):
        # WIDER FACE is fairly simple: image_set is one of [train, test],
        # and the layout closely mirrors Pascal VOC
        print '^' * 10, 'this is wf', '^' * 10
        imdb.__init__(self, image_set)
        #self._year = year
        self._image_set = image_set
        # the WIDER FACE root holds three things: annotations, images,
        # and the image-set lists
        self._devkit_path = '/home1/widerface/widerface/'
        # not really needed any more: VOC stores data by year, WIDER FACE does not
        self._data_path = os.path.join(self._devkit_path, '')
        self._classes = ('__background__', 'face')  # always index 0; two classes in total
		
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        # a key function; no change needed, but note that an "index" is just
        # an image file name (without the extension)
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup'     : True,
                       'use_salt'    : True,
                       'use_diff'    : False,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 2}  # smallest box kept, in pixels; smaller boxes are discarded (arguably it should be larger; the original value is 2)

        assert os.path.exists(self._devkit_path), \
                'widerface path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
                'Path does not exist: {}'.format(self._data_path)

    def image_path_at(self, i):
        """
        Return the absolute path to image i in the image sequence.
        """
        return self.image_path_from_index(self._image_index[i])

    def image_path_from_index(self, index):
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, 'JPEGImages',
                                  index + self._image_ext)  # the image path is assembled here
        assert os.path.exists(image_path), \
                'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """
        Load the indexes listed in this dataset's image set file.
        """
        # Example path to image set file:
        # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
        image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                      self._image_set + '.txt')  # the indexes come from the entries in this txt file
        assert os.path.exists(image_set_file), \
                'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index



    def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future
        calls. (It is called by later functions and uses the cache_path
        property.)
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print '{} gt roidb loaded from {}'.format(self.name, cache_file)
            return roidb

        gt_roidb = [self._load_face_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote gt roidb to {}'.format(cache_file)

        return gt_roidb

    def selective_search_roidb(self):
        """
        Return the database of selective search regions of interest.
        Ground-truth ROIs are also included.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_selective_search_roidb.pkl')

        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print '{} ss roidb loaded from {}'.format(self.name, cache_file)
            return roidb

        if self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            ss_roidb = self._load_selective_search_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
        else:
            roidb = self._load_selective_search_roidb(None)
        with open(cache_file, 'wb') as fid:
            cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote ss roidb to {}'.format(cache_file)

        return roidb

    def rpn_roidb(self):
        if self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            rpn_roidb = self._load_rpn_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
        else:
            roidb = self._load_rpn_roidb(None)

        return roidb

    def _load_rpn_roidb(self, gt_roidb):
        filename = self.config['rpn_file']
        print 'loading {}'.format(filename)
        assert os.path.exists(filename), \
               'rpn data not found at: {}'.format(filename)
        with open(filename, 'rb') as f:
            box_list = cPickle.load(f)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_selective_search_roidb(self, gt_roidb):
        filename = os.path.abspath(os.path.join(cfg.DATA_DIR,
                                                'selective_search_data',
                                                self.name + '.mat'))
        assert os.path.exists(filename), \
               'Selective search data not found at: {}'.format(filename)
        raw_data = sio.loadmat(filename)['boxes'].ravel()

        box_list = []
        for i in xrange(raw_data.shape[0]):
            boxes = raw_data[i][:, (1, 0, 3, 2)] - 1
            keep = ds_utils.unique_boxes(boxes)
            boxes = boxes[keep, :]
            keep = ds_utils.filter_small_boxes(boxes, self.config['min_size'])
            boxes = boxes[keep, :]
            box_list.append(boxes)

        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_face_annotation(self, index):  # parses the XML; much simpler than my earlier hand-rolled xml.sax version
        """
        Load image and bounding boxes info from XML file in the wider face
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            non_diff_objs = [
                obj for obj in objs if int(obj.find('difficult').text) == 0]
            # if len(non_diff_objs) != len(objs):
            #     print 'Removed {} difficult objects'.format(
            #         len(objs) - len(non_diff_objs))
            objs = non_diff_objs
        num_objs = len(objs)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1  # my earlier VOC reader got this wrong and did not subtract 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes' : boxes,
                'gt_classes': gt_classes,
                'gt_overlaps' : overlaps,
                'flipped' : False,
                'seg_areas' : seg_areas}
    # The remaining functions are mostly evaluation boilerplate; WIDER FACE
    # and VOC behave the same here
    def _get_comp_id(self):
        comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
            else self._comp_id)
        return comp_id

    def _get_voc_results_file_template(self):
        # VOCdevkit/results/Main/<comp_id>_det_test_face.txt
        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
        path = os.path.join(
            self._devkit_path,
            'results',
            'Main',
            filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print 'Writing {} results file'.format(cls)
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    if dets == []:
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in xrange(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))
    # computes and prints the evaluation accuracy (AP)
    def _do_python_eval(self, output_dir = 'output'):
        annopath = os.path.join(
            self._devkit_path,
            'Annotations',
            '{:s}.xml')
        imagesetfile = os.path.join(
            self._devkit_path,
            'ImageSets',
            'Main',
            self._image_set + '.txt')
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC metric changed in 2010
        use_07_metric =  False
        print 'VOC07 metric? ' + ('Yes' if use_07_metric else 'No')
        if not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
                use_07_metric=use_07_metric)
            aps += [ap]
            print('AP for {} = {:.4f}'.format(cls, ap))
            with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f:
                cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        print('Mean AP = {:.4f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print('{:.3f}'.format(ap))
        print('{:.3f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('')
        print('--------------------------------------------------------------')
        print('Results computed with the **unofficial** Python eval code.')
        print('Results should be very close to the official MATLAB eval code.')
        print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
        print('-- Thanks, The Management')
        print('--------------------------------------------------------------')

    def _do_matlab_eval(self, output_dir='output'):
        print '-----------------------------------------------------'
        print 'Computing results with the official MATLAB eval code.'
        print '-----------------------------------------------------'
        path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
                            'VOCdevkit-matlab-wrapper')
        cmd = 'cd {} && '.format(path)
        cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
        cmd += '-r "dbstop if error; '
        cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \
               .format(self._devkit_path, self._get_comp_id(),
                       self._image_set, output_dir)
        print('Running:\n{}'.format(cmd))
        status = subprocess.call(cmd, shell=True)

    def evaluate_detections(self, all_boxes, output_dir):
        self._write_voc_results_file(all_boxes)
        self._do_python_eval(output_dir)
        if self.config['matlab_eval']:
            self._do_matlab_eval(output_dir)
        if self.config['cleanup']:
            for cls in self._classes:
                if cls == '__background__':
                    continue
                filename = self._get_voc_results_file_template().format(cls)
                os.remove(filename)

    def competition_mode(self, on):
        if on:
            self.config['use_salt'] = False
            self.config['cleanup'] = False
        else:
            self.config['use_salt'] = True
            self.config['cleanup'] = True

if __name__ == '__main__':
    from datasets.wf import wf
    d = wf('train')
    res = d.roidb
    from IPython import embed; embed()

2. Registering the loader with the factory

Having turned the VOC reader into a WIDER FACE reader, the factory module also needs updating — this is what tells get_imdb() where to get the dataset from (file: /py-faster-rcnn/lib/datasets/factory.py).

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Factory method for easily getting imdbs by name."""
import numpy as np
__sets = {}  # this dict just maps dataset names to loader constructors, a bit like function pointers


from datasets.pascal_voc import pascal_voc
from datasets.coco import coco
from datasets.wf import wf


# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))  # the default arguments bind split/year at definition time; calling the lambda returns a dataset object

# Set up coco_2014_<split>
for year in ['2014']:
    for split in ['train', 'val', 'minival', 'valminusminival']:
        name = 'coco_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: coco(split, year))

# Set up coco_2015_<split>
for year in ['2015']:
    for split in ['test', 'test-dev']:
        name = 'coco_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: coco(split, year))

# Set up wf_<split> using selective search "fast" mode
for split in ['train', 'trainval', 'test']:
    name = 'wf_{}'.format(split)  # remember this name: the launch script uses it to select the loader class
    __sets[name] = (lambda split=split: wf(split))  # split is bound here as a default argument
     

def get_imdb(name):
    """Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()

def list_imdbs():
    """List all registered imdbs."""
    return __sets.keys()
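The `lambda split=split: wf(split)` pattern in the factory matters: without the default argument, every lambda created in the loop would close over the same loop variable and end up returning the last split. A small self-contained demonstration:

```python
# Without default-argument binding, all closures see the loop's final value.
wrong = {}
right = {}
for split in ['train', 'trainval', 'test']:
    wrong[split] = lambda: split              # late binding: split looked up at call time
    right[split] = lambda split=split: split  # early binding: current value captured now

print(wrong['train']())  # 'test' -- every entry returns the loop's last value
print(right['train']())  # 'train'
```

This is why each `__sets` entry writes `split=split` (and `year=year` for VOC/COCO) rather than a bare closure.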

3. The launch script

The official code at https://github.com/rbgirshick/py-faster-rcnn explains how to launch training: ./experiments/scripts/faster_rcnn_end2end.sh [GPU_ID] [NET] [--set ...], or the other .sh script.
We need to add a third option for the custom dataset reader:
case $DATASET in
  pascal_voc)
    TRAIN_IMDB="voc_2007_trainval"
    TEST_IMDB="voc_2007_test"
    PT_DIR="pascal_voc"
    ITERS=40000
    ;;
  coco)
    echo "Not implemented: use experiments/scripts/faster_rcnn_end2end.sh for coco"
    exit
    ;;
  *)  # any other name; the same settings work for faster_rcnn_alt_opt.sh
    #echo "No dataset given"  # date: 2018.1.4
    #exit
    TRAIN_IMDB="wf_trainval"  # the name registered in the previous step; selects the data-loading class
    TEST_IMDB="wf_test"
    PT_DIR="pascal_voc"  # reuse the Pascal VOC model prototxts
    ITERS=40000
    ;;
esac

4. Modifying the training parameters

① Modifying the prototxt files

Besides adapting the input data, the network's output dimensions must also change. For the Pascal dataset the training network outputs 21 (20+1) classes and correspondingly 84 (4*21) bounding-box values, whereas the face dataset has only two classes: face or not. The places to modify can be traced from tools/train_faster_rcnn_alt_opt.py in the Faster R-CNN root directory.
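The arithmetic behind those numbers is simply classes = foreground + background, and 4 box coordinates per class. A tiny helper (the function name is made up, just to show the calculation):

```python
def output_dims(num_fg_classes):
    """Return (cls_score num_output, bbox_pred num_output)."""
    num_classes = num_fg_classes + 1   # +1 for __background__
    return num_classes, 4 * num_classes

print(output_dims(20))  # (21, 84) -- Pascal VOC
print(output_dims(1))   # (2, 8)   -- WIDER FACE: face vs. background
```

So for WIDER FACE, cls_score gets num_output 2 and bbox_pred gets num_output 8.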

Faster R-CNN wants the whole network to share parameters with Fast R-CNN, so training is split into four stages, each with its own solver hyperparameters and network definition. The prototxt files involved can be seen in the Python file mentioned above:

def get_solvers(net_name):
    # Faster R-CNN Alternating Optimization
    n = 'faster_rcnn_alt_opt'
    # Solver for each training stage
    # the four solvers used; each solver names its training prototxt,
    # all found under py-faster-rcnn/models/pascal_voc/VGG16/faster_rcnn_alt_opt/
    solvers = [[net_name, n, 'stage1_rpn_solver60k80k.pt'],
               [net_name, n, 'stage1_fast_rcnn_solver30k40k.pt'],
               [net_name, n, 'stage2_rpn_solver60k80k.pt'],
               [net_name, n, 'stage2_fast_rcnn_solver30k40k.pt']]
    solvers = [os.path.join(cfg.MODELS_DIR, *s) for s in solvers]
    # Iterations for each training stage
    max_iters = [80000, 40000, 80000, 40000]
    # max_iters = [100, 100, 100, 100]
    # Test prototxt for the RPN
    rpn_test_prototxt = os.path.join(
        cfg.MODELS_DIR, net_name, n, 'rpn_test.pt')
    return solvers, max_iters, rpn_test_prototxt

In those network definitions, the output dimensions of the bbox_pred and cls_score layers must be changed. In addition, besides the C++ Caffe layers, Faster R-CNN also uses Caffe layers defined in Python (a good opportunity to learn how Faster R-CNN uses Python layers), so layers of type 'Python' in the four stages' network definitions also need attention. For example, this layer in stage1_fast_rcnn_train.pt:

layer {
  name: 'data'
  type: 'Python'
  top: 'data'
  top: 'rois'
  top: 'labels'
  top: 'bbox_targets'
  top: 'bbox_inside_weights'
  top: 'bbox_outside_weights'
  python_param {
    module: 'roi_data_layer.layer' # lives at /py-faster-rcnn/lib/roi_data_layer/layer.py
    layer: 'RoIDataLayer' # the class in that file: class RoIDataLayer(caffe.Layer)
    param_str: "'num_classes': 2" # Pascal has 21 classes; we only need 2
  }
}

Make the same change in the other network definitions in that directory (module).

② Modifying the Python files

At runtime you will hit some numpy errors, almost all caused by matrix indexing; searching online shows they stem from numpy upgrades. The root cause is that the index is a float, while newer numpy versions require integer indexes (int, long, and so on). Fortunately you don't have to fix everything up front: these errors all surface early in training, so just cast the offending variables to an integer type as each one appears.
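A minimal reproduction of the kind of error this refers to, and the one-line fix:

```python
import numpy as np

a = np.arange(10)
i = 10 / 4.0          # arithmetic on sizes often produces a float (here 2.5)
try:
    a[i]              # newer numpy rejects float indexes
except (IndexError, TypeError) as e:
    print(type(e).__name__)

print(a[int(i)])      # cast to an integer type and indexing works again
```

In the Faster R-CNN code the casts typically go on variables derived from sizes, strides, or divisions just before they are used as array indexes.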

③ Later fixes

Unfortunately, although I kept telling myself to change every remaining parameter to two classes (instead of Pascal's 21) and an 8-dimensional bounding-box output, Faster R-CNN trains in multiple stages, and after the RPN network is trained there is a test phase for it; I had not modified the network parameters inside that phase. The earlier stages had already finished and had taken a long time, so I needed a way to reuse that work.

Following the launch script I was using, I commented out the already-completed stages in /tools/train_faster_rcnn_alt_opt.py (stage 1 RPN, stage 1 Fast R-CNN, stage 2 RPN, stage 2 Fast R-CNN, as applicable). Stages are connected only through the trained caffemodels, so it is enough to fill in the path of the previous stage's caffemodel and update the init_model entry in the corresponding dict. The relevant paths can be found in the log files.

5. A few characteristics of Faster R-CNN training

Faster R-CNN is intent on sharing network layers: besides fine-tuning from a VGG or ZF base network, it also trains in multiple stages.

① How the VGG base network's parameters are fine-tuned

Because training happens in multiple stages, part of the base network (e.g. VGG's fully connected layers fc6 and fc7) is not used in stage 1, so a dummy layer is created purely to carry parameters through. You can see this by inspecting the stage1_rpn_train.pt definition: the dummy_data layer there is a placeholder that passes the VGG parameters on to the later training structures.

② The RPN replaces selective search proposals for training the Fast R-CNN network, and the two are then trained alternately.

③ The RPN is a convolutional structure

Using a convolutional structure to generate bounding-box proposals also shares parameters, greatly reducing proposal time compared with selective search.

This part of the network is rpn_conv/3x3. A note on the output sizes: rpn_bbox_pred has 4*9 = 36 channels, because every position in the feature map is the center of 9 anchors; the corresponding rpn_cls_score (2*9: object or not, per anchor) has 18 channels. The 18-channel cls_score is fed into a softmax classifier, and RoIs are selected by score for the subsequent bounding-box regression, Fast R-CNN head, and so on. Note that these RoIs are not crops of the original image sent to the classifier: the features corresponding to each RoI are gathered directly from the feature map. The time cost scales with the number of RoIs, since each passes through the final classification stage independently, without sharing computation.
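The channel counts above follow directly from the anchor count: with k anchors per feature-map position (k = 9 here), the two RPN heads need 2k and 4k channels. A quick sketch (the helper name is made up):

```python
def rpn_output_channels(num_anchors):
    """Per-position channel counts for the two RPN heads."""
    cls_channels = 2 * num_anchors   # object / not-object score per anchor
    bbox_channels = 4 * num_anchors  # one (x1, y1, x2, y2) regression per anchor
    return cls_channels, bbox_channels

print(rpn_output_channels(9))  # (18, 36) -- rpn_cls_score, rpn_bbox_pred
```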

④ The parts that are not shared: the paper notes that the bounding-box regressors are not shared — there is one per anchor shape (i.e. 9 of them) to produce regressed boxes at different sizes — whereas OverFeat and Fast R-CNN do share parameters in their bounding-box regression.

6. Other handy tricks

When training on a server, the client may get shut down: the connection between client and server sometimes drops mid-training. A Linux command, tmux, helps here — tmux keeps the processes started from the client alive, and you can reattach at any time.
The commands I currently use:
tmux new -s session_name    # create a new session
tmux a -t session_name      # reattach to the named session
tmux ls                     # list sessions
# detach from the current session to go do other things: press Ctrl+b, then d
# kill a session: tmux kill-session -t session_name

However, tmux does not preserve the log output well, so combine it with the tee command, e.g.
./experiments/scripts/faster_rcnn_alt_opt.sh [GPU_ID] [NET] [--set ...] 2>&1 | tee log.log
This saves the runtime output to the specified log.log, in the directory from which the .sh script was invoked (the Caffe root). The 2>&1 redirects stderr as well, so the messages from a crashed run are also captured.

That said, when you use faster_rcnn_alt_opt.sh, the script already saves a log file automatically.

