@: train_faster_rcnn_alt_opt.py
前言:
本文主要功能是实现Alternating optimization的训练方法,在NIPS的那篇paper上仅有较为简略的介绍,所以我认为通过源码来学习是很有必要的。
其他主要模块的解读也会在后续陆续放出。本人水平有限,如果有一些理解有偏差;或者说有些重要的点被一笔带过,读者想要深入了解的,可以在评论区留言,共同交流,共同进步!
该文件包含以下函数,并作简要介绍:
- parse_args()
传递参数
- get_roidb(imdb_name, rpn_file= None)
获取roidb与lmdb
- get_solvers(net_name)
获取solvers
- _init_caffe(cfg)
根据config加载caffe对象
- train_rpn(queue, imdb_name, init_model, solver=None,
max_iters, cfg)
训练RPN网络
- rpn_generate(queue, imdb_name, rpn_model_path, cfg,
rpn_test_prototxt)
用RPN网络生产proposal
- train_fast_rcnn(queue, imdb_name, init_model, solver,
max_iters, cfg, rpn_file)
用RPN网络产生的proposal来训练fast_rcnn网络
- 最后是主函数,详细的描述了四个训练步骤
get_roidb
imdb根据imdb_name(默认是“voc_2007_trainval)来获取,这里的imdb对象的获取采用了工厂模式,由\lib\datasets\factory.py生成,根据年份(2007)与切分的数据集(trainval)返回pascal_voc对象,pascal_voc与coco都继承于imdb对象。(\lib\datasets\pascal_voc.py+coco.py)
roidb是通过lib\faster_rcnn\train.py中的get_training_roidb来获取的,这个roidb是一个imdb的成员变量,包含了训练集图片中框出的每个区域。这个函数做了两件事情,一是将原有的roidb中的每张图片进行水平翻转然后添加回roidb中,第二件事是做一些准备工作(有一些让我很无语……),详细的将在相应的文件进行介绍
def get_roidb(imdb_name, rpn_file=None):
imdb = get_imdb(imdb_name)
print 'Loaded dataset {:s} for training'.format(imdb.name)
imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
if rpn_file is not None:
imdb.config['rpn_file'] = rpn_file
roidb = get_training_roidb(imdb)
return roidb, imdb
get_solvers
在models/pascal_voc/netname/faster_rcnn_alt_opt文件夹下有stage1_rpn_solver60k80k.pt等不同阶段所对应的solver文件,并定义了各个阶段最大迭代次数,这里支持的net_name有VGG16、VGG_CNN_M_1024、ZF三种
def get\_solvers(net\_name):
# Faster R-CNN Alternating Optimization
n = 'faster_rcnn_alt_opt'
# Solver for each training stage
solvers = [[net_name, n, 'stage1_rpn_solver60k80k.pt'],
[net_name, n, 'stage1_fast_rcnn_solver30k40k.pt'],
[net_name, n, 'stage2_rpn_solver60k80k.pt'],
[net_name, n, 'stage2_fast_rcnn_solver30k40k.pt']]
solvers = [os.path.join(cfg.MODELS_DIR, *s) for s in solvers]
# Iterations for each training stage
max_iters = [80000, 40000, 80000, 40000]
# max_iters = [100, 100, 100, 100]
# Test prototxt for the RPN
rpn_test_prototxt = os.path.join(
cfg.MODELS_DIR, net_name, n, 'rpn_test.pt')
return solvers, max_iters, rpn_test_prototxt
_init_caffe
该函数作用便是初始化caffe对象,仅做了两步操作,第一步是初始化随机种子,第二步是设置GPU。