py-faster-rcnn源码解读系列(一)——train_faster_rcnn_alt_opt.py

本文详细解读py-faster-rcnn的train_faster_rcnn_alt_opt.py源码,涵盖get_roidb、get_solvers、train_rpn和train_fast_rcnn等关键函数,深入理解RPN和Fast R-CNN的训练流程。
摘要由CSDN通过智能技术生成

@: train_faster_rcnn_alt_opt.py

前言:

本文主要功能是实现Alternating optimization的训练方法,在NIPS的那篇paper上仅有较为简略的介绍,所以我认为通过源码来学习是很有必要的。
其他主要模块的解读也会在后续陆续放出。本人水平有限,如果有一些理解有偏差;或者说有些重要的点被一笔带过,读者想要深入了解的,可以在评论区留言,共同交流,共同进步!

该文件包含以下函数,并作简要介绍:
- parse_args()
传递参数
- get_roidb(imdb_name, rpn_file= None)
获取roidb与lmdb
- get_solvers(net_name)
获取solvers
- _init_caffe(cfg)
根据config加载caffe对象
- train_rpn(queue, imdb_name, init_model, solver=None,
 max_iters, cfg)
训练RPN网络
- rpn_generate(queue, imdb_name, rpn_model_path, cfg,
 rpn_test_prototxt)
 用RPN网络生产proposal
- train_fast_rcnn(queue, imdb_name, init_model, solver,
 max_iters, cfg, rpn_file)
用RPN网络产生的proposal来训练fast_rcnn网络
- 最后是主函数,详细的描述了四个训练步骤

get_roidb

imdb根据imdb_name(默认是“voc_2007_trainval)来获取,这里的imdb对象的获取采用了工厂模式,由\lib\datasets\factory.py生成,根据年份(2007)与切分的数据集(trainval)返回pascal_voc对象,pascal_voc与coco都继承于imdb对象。(\lib\datasets\pascal_voc.py+coco.py)

roidb是通过lib\faster_rcnn\train.py中的get_training_roidb来获取的,这个roidb是一个imdb的成员变量,包含了训练集图片中框出的每个区域。这个函数做了两件事情,一是将原有的roidb中的每张图片进行水平翻转然后添加回roidb中,第二件事是做一些准备工作(有一些让我很无语……),详细的将在相应的文件进行介绍

def get_roidb(imdb_name, rpn_file=None):
 imdb = get_imdb(imdb_name)
 print 'Loaded dataset {:s} for training'.format(imdb.name)
 imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
 print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
 if rpn_file is not None:
     imdb.config['rpn_file'] = rpn_file
 roidb = get_training_roidb(imdb)
 return roidb, imdb

get_solvers

在models/pascal_voc/netname/faster_rcnn_alt_opt文件夹下有stage1_rpn_solver60k80k.pt等不同阶段所对应的solver文件,并定义了各个阶段最大迭代次数,这里支持的net_name有VGG16、VGG_CNN_M_1024、ZF三种

def get\_solvers(net\_name):
 # Faster R-CNN Alternating Optimization
 n = 'faster_rcnn_alt_opt'
 # Solver for each training stage
 solvers = [[net_name, n, 'stage1_rpn_solver60k80k.pt'],
            [net_name, n, 'stage1_fast_rcnn_solver30k40k.pt'],
            [net_name, n, 'stage2_rpn_solver60k80k.pt'],
            [net_name, n, 'stage2_fast_rcnn_solver30k40k.pt']]
 solvers = [os.path.join(cfg.MODELS_DIR, *s) for s in solvers]
 # Iterations for each training stage
 max_iters = [80000, 40000, 80000, 40000]
 # max_iters = [100, 100, 100, 100]
 # Test prototxt for the RPN
 rpn_test_prototxt = os.path.join(
     cfg.MODELS_DIR, net_name, n, 'rpn_test.pt')
 return solvers, max_iters, rpn_test_prototxt

_init_caffe

该函数作用便是初始化caffe对象,仅做了两步操作,第一步是初始化随机种子,第二步是设置GPU。

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值