py-faster-rcnn源码解析之处理训练数据

本文详述了py-faster-rcnn在训练模型时处理数据的过程,从错误入手,深入剖析了imdb和roidb的生成,包括get_roidb、get_imdb等关键函数,揭示了数据预处理的各个环节,如目标框调整、数据增强等,最后解释了图像数据在神经网络前向传播时的读取方式。
摘要由CSDN通过智能技术生成

因为最近在使用py-faster-rcnn训练自己的数据时报如下错:

roidb[i]['image'] = imdb.image_path_at(i) 
IndexError: list index out of range 

看了网上的很多说法都是让删除py-faster-rcnn/data/cache下的pkl文件,但是该方法对我并没有起作用,于是就将py-faster-rcnn处理训练数据部分的代码跟踪了一下,这里和大家一起分享,也做个记录。
下面的解说都是以py-faster-rcnn目录为根目录,后面就不再重复了。
我是通过执行scripts/faster_rcnn_alt_opt.sh来训练模型的,从该脚本的第46行代码:

time ./tools/train_faster_rcnn_alt_opt.py --gpu ${GPU_ID} \

我们可以知道模型是通过tools/train_faster_rcnn_alt_opt.py进行训练的,接下里我们就去看这个py文件的源码。
模型训练分为两个Stage,每个Stage都是从RPN训练开始的,所以我们直接看train_rpn函数:

def train_rpn(queue=None, imdb_name=None, init_model=None, solver=None,
              max_iters=None, cfg=None):
    """Train a Region Proposal Network in a separate training process.
    """

    # Not using any proposals, just ground-truth boxes
    cfg.TRAIN.HAS_RPN = True
    cfg.TRAIN.BBOX_REG = False  # applies only to Fast R-CNN bbox regression
    cfg.TRAIN.PROPOSAL_METHOD = 'gt'
    cfg.TRAIN.IMS_PER_BATCH = 1
    print 'Init model: {}'.format(init_model)
    print('Using config:')
    pprint.pprint(cfg)

    import caffe
    _init_caffe(cfg)

    roidb, imdb = get_roidb(imdb_name)
    print 'roidb len: {}'.format(len(roidb))
    output_dir = get_output_dir(imdb)
    print 'Output will be saved to `{:s}`'.format(output_dir)

    model_paths = train_net(solver, roidb, output_dir,
                            pretrained_model=init_model,
                            max_iters=max_iters)
    # Cleanup all but the final model
    for i in model_paths[:-1]:
        os.remove(i)
    rpn_model_path = model_paths[-1]
    # Send final model path through the multiprocessing queue
    queue.put({
   'model_path': rpn_model_path})

前面的几行代码是进行训练的配置,一直到这几行代码开始准备数据:

    roidb, imdb = get_roidb(imdb_name)
    print 'roidb len: {}'.format(len(roidb))
    output_dir = get_output_dir(imdb)
    print 'Output will be saved to `{:s}`'.format(output_dir)

所以我们再跳去get_roidb函数去看它是如何实现的:

def get_roidb(imdb_name, rpn_file=None):
    imdb = get_imdb(imdb_name)
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
    if rpn_file is not None:
        imdb.config['rpn_file'] = rpn_file
    roidb = get_training_roidb(imdb)
    return roidb, imdb

Stage1 RPN, init from ImageNet model时输入参数imdb_name是voc_2007_trainval,rpn_file是None。从这个函数我们能够得到的信息是roidb是与imdb相关的,下面我们先看imdb是怎么得到的,即先看get_imdb函数,这个函数的代码在lib/datasets/factory.py中:

def get_imdb(name):
    """Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值