faster rcnn训练

输入命令./experiments/scripts/faster_rcnn_alt_opt.sh 0 ZF pascal_voc

faster_rcnn_alt_opt.sh的训练代码如下:

time ./tools/train_faster_rcnn_alt_opt.py --gpu ${GPU_ID} \

  --net_name ${NET} \
  --weights data/imagenet_models/${NET}.v2.caffemodel \
  --imdb ${TRAIN_IMDB} \
  --cfg experiments/cfgs/faster_rcnn_alt_opt.yml \

  ${EXTRA_ARGS}

参数赋值:${GPU_ID} =0、${NET} =ZF、${TRAIN_IMDB} =pascal_voc

调用./tools/train_faster_rcnn_alt_opt.py文件

1、初始化参数

(1)args = parse_args(),

包括--gpu:这个参数指定训练使用的GPU设备,我的电脑只有一枚GPU,默认情况下自动开启,其gpu_id为0;

--net_name:训练的网络类型

--weights:这个参数指定了finetune的初始参数,我的电脑GPU不怎么高端,只能使用caffenet进行finetune;

--cfg:

--imdb:这个参数指定了训练所需要的训练数据,如果你需要训练自己的数据,那么这个参数是必须要指定的;

--set:


(2)rpn_test_prototxt = get_solvers(args.net_name)

在models/pascal_voc/netname/faster_rcnn_alt_opt文件夹下有stage1_rpn_solver60k80k.pt等不同阶段所对应的solver文件,并定义了各个阶段最大迭代次数,这里支持的net_name有VGG16、VGG_CNN_M_1024、ZF三种

def get_solvers(net_name):
    # Faster R-CNN Alternating Optimization
    n = 'faster_rcnn_alt_opt'
    # Solver for each training stage
    solvers = [[net_name, n, 'stage1_rpn_solver60k80k.pt'],
               [net_name, n, 'stage1_fast_rcnn_solver30k40k.pt'],
               [net_name, n, 'stage2_rpn_solver60k80k.pt'],
               [net_name, n, 'stage2_fast_rcnn_solver30k40k.pt']]
    solvers = [os.path.join(cfg.MODELS_DIR, *s) for s in solvers]
    # Iterations for each training stage
    #每一轮训练的最大迭代次数,建议测试时都设置为100
    max_iters = [80000, 40000, 80000, 40000]
    # max_iters = [100, 100, 100, 100]
    # Test prototxt for the RPN
    rpn_test_prototxt = os.path.join(
        cfg.MODELS_DIR, net_name, n, 'rpn_test.pt')
    return solvers, max_iters, rpn_test_prototxt


2、train_rpn函数

(1)_init_caffe:该函数作用便是初始化caffe对象,仅做了两步操作,第一步是初始化随机种子,第二步是设置GPU。

def _init_caffe(cfg):
    """Initialize pycaffe in a training process.
    """

    import caffe
    # fix the random seeds (numpy and caffe) for reproducibility
    np.random.seed(cfg.RNG_SEED)
    caffe.set_random_seed(cfg.RNG_SEED)
    # set up caffe
    caffe.set_mode_gpu()
    caffe.set_device(cfg.GPU_ID)

(2)roidb, imdb = get_roidb(imdb_name)  

->imdb = get_imdb(imdb_name)

imdb根据imdb_name(默认是“voc_2007_trainval)来获取,这里的imdb对象的获取采用了工厂模式,由\lib\datasets\factory.py生成,根据年份(2007)与切分的数据集(trainval)返回pascal_voc对象,pascal_voc与coco都继承于imdb对象。(\lib\datasets\pascal_voc.py+coco.py)

->imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)

设置gt方法。首先这里初始化了一些配置,比较容易忽略的的是PROPOSAL_METHOD设置成了’gt’,值得一提。这个设定是从get_roidb然后追溯到底层数据类pascal-voc得到体现,可以看到imdb(pascal_voc的父类)通过roidb_handler来决定用什么方式生成roidb,默认为selective_search,这里用了gt_roidb。 

->output_dir = get_output_dir(imdb)

确定路径

->roidb = get_training_roidb(imdb)

roidb是通过lib\fast_rcnn\train.py中的get_training_roidb来获取的,这个roidb是一个imdb的成员变量,包含了训练集图片中框出的每个区域。这个函数做了两件事情,一是将原有的roidb中的每张图片进行水平翻转然后添加回roidb中,第二件事是做一些准备工作(有一些让我很无语……),详细的将在相应的文件进行介绍


(3)model_paths = train_net(solver, roidb, output_dir, pretrained_model=init_model, max_iters=max_iters)

而这个 train_net() 函数是从 lib/fast_rcnn 文件夹中的 train.py 中 import 进来的。train.py主要由一个类SolverWrapper和两个函数get_training_roidb()和train_net()组成。 

def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network."""

    roidb = filter_roidb(roidb)
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    print 'Solving...'
    model_paths = sw.train_model(max_iters)
    print 'done solving'
    return model_paths

可以发现,该函数是通过调用类SolverWrapper来实现其主要功能的,类SolverWrapper初始化完成后,就是要调用train_model函数来进行网络训练。

3、rpn_generate函数

该函数的作用就是根据输入的数据与模型与prototxt产生proposal,可作为下一步的训练所用,也可作为测试。 该函数最最核心的一句代码是rpn_proposals = imdb_proposals(rpn_net, imdb),其他的都是作为参数准备,与输出的一些工作。如果仅仅浅尝辄止的读者,知道这个函数的功能就是对每张图片产生最多2000个roi proposal 与对应的scores然后缓存到某个文件夹即可。 希望打破砂锅问到底的读者,可以追到函数里以及prototxt文件去读其实现细节。其中proposal_layer我也会单独开一个篇章来讨论实现中的一些重要工作。

def rpn_generate(queue=None, imdb_name=None, rpn_model_path=None, cfg=None,
                 rpn_test_prototxt=None):
    """Use a trained RPN to generate proposals.
    """
    # 设置cfg文件
    cfg.TEST.RPN_PRE_NMS_TOP_N = -1     # no pre NMS filtering
    cfg.TEST.RPN_POST_NMS_TOP_N = 2000  # limit top boxes after NMS
    print 'RPN model: {}'.format(rpn_model_path)
    print('Using config:')
    pprint.pprint(cfg)

    import caffe
    _init_caffe(cfg)

    # NOTE: the matlab implementation computes proposals on flipped images, too.
    # We compute them on the image once and then flip the already computed
    # proposals. This might cause a minor loss in mAP (less proposal jittering).
    imdb = get_imdb(imdb_name)
    print 'Loaded dataset `{:s}` for proposal generation'.format(imdb.name)

    # Load RPN and configure output directory
    # 创建一个Net实例对象,以下为个人观点,对boost库以及wraper不熟,只说说个人理解,如有错误,后续更正,欢迎指正: 
    #a)会调用_caffe.cpp里面的Net_Init_Load函数,_caffe.cpp里面`.def("__init__", bp::make_constructor(&Net_Init_Load))`,
    #应该是将Net_Init_Load与__init__构造器绑定,所以在创建caffe.Net对象的时候,会调用Net_Init_Load,
    #终端上就会输出网络的初始化信息同时从pretrainedmodel上copy source layers
    rpn_net = caffe.Net(rpn_test_prototxt, rpn_model_path, caffe.TEST)
    output_dir = get_output_dir(imdb)
    print 'Output will be saved to `{:s}`'.format(output_dir)
    # Generate RPN proposals on all images in an imdb.
    #imdb_proposals函数在generate.py文件中
    #rpn_proposals是一个列表的列表,每个子列表
    rpn_proposals = imdb_proposals(rpn_net, imdb)#核心
    # Write proposals to disk and send the proposal file path through the
    # multiprocessing queue
    # splitext函数分离文件的文件名和拓展名
    rpn_net_name = os.path.splitext(os.path.basename(rpn_model_path))[0]
    rpn_proposals_path = os.path.join(
        output_dir, rpn_net_name + '_proposals.pkl')
    with open(rpn_proposals_path, 'wb') as f:
        cPickle.dump(rpn_proposals, f, cPickle.HIGHEST_PROTOCOL)
        # 将python专有的数据结构(比如rpn_proposals是一个列表)序列化,存储在磁盘中
    print 'Wrote RPN proposals to {}'.format(rpn_proposals_path)
    #返回生成的rpn proposal的存储地址。
    queue.put({'proposal_path': rpn_proposals_path})

4、train_fast_rcnn函数

这个函数就是训练fast-rcnn的部分,首先它将产生roidb的方法设置成rpn_roidb,工厂模式的获取roidb思想,在上文已提。 接下来就是准备一些参数、路径等等,用于送入网络训练,最后保存模型。具体的训练细节,需要阅读prototxt文件才能把它的过程弄得水落石出。

def train_fast_rcnn(queue=None, imdb_name=None, init_model=None, solver=None,
                    max_iters=None, cfg=None, rpn_file=None):
    """Train a Fast R-CNN using proposals generated by an RPN.
    """

    #这个参数的设置是为了提高代码的重用性。可以看到其他文件中,train_rpn和train_fast_rcnn的过程在实现时,会有重复代码,故设置该变量将其合并。
    #conv5后面现在接的是fast-rcnn
    cfg.TRAIN.HAS_RPN = False           # not generating prosals on-the-fly
    #roidb由刚刚训练完的RPN产生
    cfg.TRAIN.PROPOSAL_METHOD = 'rpn'   # use pre-computed RPN proposals instead
    #每个mini-batch包含两张图片,以及他们proposal的roi区域
    #每次训练fast-rcnn使用两张图片
    cfg.TRAIN.IMS_PER_BATCH = 2
    print 'Init model: {}'.format(init_model)
    print 'RPN proposals: {}'.format(rpn_file)
    print('Using config:')
    pprint.pprint(cfg)

    import caffe
    _init_caffe(cfg)

    roidb, imdb = get_roidb(imdb_name, rpn_file=rpn_file)
    output_dir = get_output_dir(imdb)
    print 'Output will be saved to `{:s}`'.format(output_dir)
    # Train Fast R-CNN
    model_paths = train_net(solver, roidb, output_dir,
                            pretrained_model=init_model,
                            max_iters=max_iters)
    # Cleanup all but the final model
    for i in model_paths[:-1]:
        os.remove(i)
    fast_rcnn_model_path = model_paths[-1]
    # Send Fast R-CNN model path over the multiprocessing queue
    queue.put({'model_path': fast_rcnn_model_path})

5、数据准备

不过,关于Fast-RCNN的重头戏我们其实还没开始——那就是如何准备训练数据。

在上面介绍训练的流程中,与此相关的函数是:imdb= get_imdb(args.imdb_name)

这个函数是从从lib/datasets/文件夹中的factory.py中import进来的,我们来看一下这个函数:

def get_imdb(name):
    """Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()
这个函数很简单,其实就是根据字典的key来取得训练数据。 那么这个字典是怎么形成的呢?看下面:

for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))
它本质上是通过lib/datasets/文件夹下面的pascal_voc.py引入的。所以,现在我们就得开始进入pascal_voc.py:

def __init__(self, image_set, year, devkit_path=None):
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path() if devkit_path is None \
                            else devkit_path
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__', # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        # self.selective_search_roidb是一个函数对象,把这个函数对象付给_roidb_handler属性
        self._roidb_handler = self.selective_search_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup'     : True,
                       'use_salt'    : True,
                       'use_diff'    : False,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 2}

        assert os.path.exists(self._devkit_path), \
                'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
                'Path does not exist: {}'.format(self._data_path)

在初始化自身的同时,先调用了父类的初始化方法,将imdb_name传入,例如(‘voc_2007_trainval’)
下面是成员变量的初始化:
{
    year:’2007’
    image _set:’trainval’
    devkit _path:’data/VOCdevkit2007’
    data _path:’data /VOCdevkit2007/VOC2007’
    classes:(…)_如果想要训练自己的数据,需要修改这里_
    class _to _ind:{…} _一个将类名转换成下标的字典 _
    image _ext:’.jpg’
    image _index: [‘000001’,’000003’,……]_根据trainval.txt获取到的image索引_
    roidb _handler: <Method gt_roidb >
    salt:  <Object uuid >
    comp _id:’comp4’
    config:{…}
}

类 init 构造完成后,会调用函数 roidb,这个函数是从类 imdb 中继承过来的,这个函数会调用 _roidb_handler 来处理,其中 _roidb_handler=self.selective_search_roidb,下面我们来看看这个函数:

    def selective_search_roidb(self):
        """
        Return the database of selective search regions of interest.
        Ground-truth ROIs are also included.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_selective_search_roidb.pkl')

        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print '{} ss roidb loaded from {}'.format(self.name, cache_file)
            return roidb

        if int(self._year) == 2007 or self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            ss_roidb = self._load_selective_search_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
        else:
            roidb = self._load_selective_search_roidb(None)
        with open(cache_file, 'wb') as fid:
            cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote ss roidb to {}'.format(cache_file)

        return roidb
这个函数在训练阶段会首先调用 get_roidb()  函数:

def rpn_roidb(self):
        if int(self._year) == 2007 or self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            # 求取rpn_roidb需要以gt_roidb作为参数才能得到
            rpn_roidb = self._load_rpn_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
        else:
            roidb = self._load_rpn_roidb(None)

        return roidb
如果存在cache_file,那么get_roidb()就会直接从cache_file中读取信息;如果不存在cache_file,那么会调用_load_pascal_annotation()来取得标注信息。_load_pascal_annotation函数如下所示:

 def _load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            non_diff_objs = [
                obj for obj in objs if int(obj.find('difficult').text) == 0]
            # if len(non_diff_objs) != len(objs):
            #     print 'Removed {} difficult objects'.format(
            #         len(objs) - len(non_diff_objs))
            objs = non_diff_objs
        num_objs = len(objs)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            # 从anatation直接载入图像的信息,因为本身就是ground-truth , 所以overlap直接设为1
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
        # overlaps为 num_objs * K 的数组, K表示总共的类别数, num_objs表示当前这张图片中box的个数
        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes' : boxes,
                'gt_classes': gt_classes,
                'gt_overlaps' : overlaps,
                'flipped' : False,
                'seg_areas' : seg_areas}
当处理完标注的数据后,接下来就要载入SS阶段获得的数据,通过如下函数完成:

    def _load_selective_search_roidb(self, gt_roidb):
        filename = os.path.abspath(os.path.join(cfg.DATA_DIR,
                                                'selective_search_data',
                                                self.name + '.mat'))
        assert os.path.exists(filename), \
               'Selective search data not found at: {}'.format(filename)
        raw_data = sio.loadmat(filename)['boxes'].ravel()

        box_list = []
        for i in xrange(raw_data.shape[0]):
            boxes = raw_data[i][:, (1, 0, 3, 2)] - 1
            keep = ds_utils.unique_boxes(boxes)
            boxes = boxes[keep, :]
            keep = ds_utils.filter_small_boxes(boxes, self.config['min_size'])
            boxes = boxes[keep, :]
            box_list.append(boxes)

        return self.create_roidb_from_box_list(box_list, gt_roidb)
    # 从XML文件载入图像信息,而且是ground-truth信息,比如boxes   
有一点需要注意的是,ss中获得的box的值,和fast-rcnn中认为的box值有点差别,那就是你需要交换box的x和y坐标。




评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值