论文地址
代码地址
视频讲解
代码包结构介绍
- data (存放训练和测试数据)
- docker (存放docker文件,配置环境用)
- experiments (存放些参数文件、日志、实验脚本)
- lib (库文件)
- tools (训练测试评估脚本)
tools文件夹下的trainval_net.py文件的main函数
首先从tools文件夹下的trainval_net.py文件的main函数开始读代码
if __name__ == '__main__':
    # Entry point of the training script: parse the command line first.
    args = parse_args()  # parse the command-line arguments
    print('Called with args:')
    print(args)  # show the parsed arguments
parse_args()函数
def parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        The ``argparse.Namespace`` produced by the parser, with the
        attributes ``cfg_file``, ``weight``, ``imdb_name``, ``imdbval_name``,
        ``max_iters``, ``tag``, ``net`` and ``set_cfgs``.

    When the script is started without any argument, the help text is
    printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--weight', dest='weight',
                        help='initialize with pretrained model weights',
                        type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='voc_2007_trainval', type=str)
    parser.add_argument('--imdbval', dest='imdbval_name',
                        help='dataset to validate on',
                        default='voc_2007_test', type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=70000, type=int)
    parser.add_argument('--tag', dest='tag',
                        help='tag of the model',
                        default=None, type=str)
    parser.add_argument('--net', dest='net',
                        help='vgg16, res50, res101, res152, mobile',
                        default='res50', type=str)
    # Everything after --set is collected verbatim as a flat list of
    # alternating config keys and values.
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)

    # No arguments at all: show usage instead of silently using defaults.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args
本段代码主要是 ‘argparse’ 这个包的使用。
argparse 是 Python 标准库中的命令行选项与参数解析模块:它为程序定义一组可选的命令行参数,并把解析结果保存在一个对象中,便于后续代码读取和修改。
接下来main中的代码用到两个子程序 cfg_from_file(args.cfg_file)、cfg_from_list(args.set_cfgs) 以及变量 cfg,它们都定义在 lib/model/config.py 中。
# Apply overrides from the optional config file first, then from the
# key/value pairs given after --set, and finally show the resulting config.
if args.cfg_file is not None:
    cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
    cfg_from_list(args.set_cfgs)

print('Using config:')
pprint.pprint(cfg)
首先看一下: cfg_from_file(args.cfg_file)
def cfg_from_file(filename):
    """Load a YAML config file and merge it into the default options.

    Args:
        filename: path of the YAML file to read.

    The parsed content is wrapped in an EasyDict and merged into the
    module-level ``__C`` by ``_merge_a_into_b``.
    """
    import yaml

    with open(filename, 'r') as f:
        # safe_load instead of load: yaml.load() without an explicit Loader
        # is rejected by PyYAML >= 6 and can construct arbitrary objects on
        # older versions.
        yaml_cfg = edict(yaml.safe_load(f))

    _merge_a_into_b(yaml_cfg, __C)
看一下字典 __C 的数据内容
from easydict import EasyDict as edict #---easydict的作用:可以使得以属性的方式去访问字典的值!
想了解 edict 的使用例子,请查看博客: link
# Root configuration container (an EasyDict).
__C = edict()
# cfg is an alias of the same object, so a change made through __C is also
# visible through cfg.
cfg = __C
# Other modules obtain the configuration with an import such as:
#   from fast_rcnn_config import cfg
# NOTE(review): the module name in this example should match where this
# file actually lives — confirm.
# Training options
# NOTE(review): most options below carry no description here; their meaning
# is implied by name only — confirm against the code that reads them.
__C.TRAIN = edict()
# Initial learning rate
__C.TRAIN.LEARNING_RATE = 0.001
# Momentum: a commonly used acceleration parameter in gradient descent
__C.TRAIN.MOMENTUM = 0.9
__C.TRAIN.WEIGHT_DECAY = 0.0001
__C.TRAIN.GAMMA = 0.1
__C.TRAIN.STEPSIZE = [30000]
__C.TRAIN.DISPLAY = 10
__C.TRAIN.DOUBLE_BIAS = True
__C.TRAIN.TRUNCATED = False
__C.TRAIN.BIAS_DECAY = False
__C.TRAIN.USE_GT = False
__C.TRAIN.ASPECT_GROUPING = False
__C.TRAIN.SNAPSHOT_KEPT = 3
__C.TRAIN.SUMMARY_INTERVAL = 180
__C.TRAIN.SCALES = (600,)
__C.TRAIN.MAX_SIZE = 1000
__C.TRAIN.IMS_PER_BATCH = 1
__C.TRAIN.BATCH_SIZE = 128
__C.TRAIN.FG_FRACTION = 0.25
__C.TRAIN.FG_THRESH = 0.5
__C.TRAIN.BG_THRESH_HI = 0.5
__C.TRAIN.BG_THRESH_LO = 0.1
__C.TRAIN.USE_FLIPPED = True
__C.TRAIN.BBOX_REG = True
__C.TRAIN.BBOX_THRESH = 0.5
__C.TRAIN.SNAPSHOT_ITERS = 5000
__C.TRAIN.SNAPSHOT_PREFIX = 'res101_faster_rcnn'
__C.TRAIN.BBOX_NORMALIZE_TARGETS = True
__C.TRAIN.BBOX_INSIDE_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
__C.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED = True
__C.TRAIN.BBOX_NORMALIZE_MEANS = (0.0, 0.0, 0.0, 0.0)
__C.TRAIN.BBOX_NORMALIZE_STDS = (0.1, 0.1, 0.2, 0.2)
__C.TRAIN.PROPOSAL_METHOD = 'gt'
__C.TRAIN.HAS_RPN = True
__C.TRAIN.RPN_POSITIVE_OVERLAP = 0.7
__C.TRAIN.RPN_NEGATIVE_OVERLAP = 0.3
__C.TRAIN.RPN_CLOBBER_POSITIVES = False
__C.TRAIN.RPN_FG_FRACTION = 0.5
__C.TRAIN.RPN_BATCHSIZE = 256
__C.TRAIN.RPN_NMS_THRESH = 0.7
__C.TRAIN.RPN_PRE_NMS_TOP_N = 12000
__C.TRAIN.RPN_POST_NMS_TOP_N = 2000
__C.TRAIN.RPN_BBOX_INSIDE_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
__C.TRAIN.RPN_POSITIVE_WEIGHT = -1.0
__C.TRAIN.USE_ALL_GT = True
# Testing options
__C.TEST = edict()
__C.TEST.SCALES = (600,)
__C.TEST.MAX_SIZE = 1000
__C.TEST.NMS = 0.3
__C.TEST.SVM = False
__C.TEST.BBOX_REG = True
__C.TEST.HAS_RPN = False
__C.TEST.PROPOSAL_METHOD = 'gt'
__C.TEST.RPN_NMS_THRESH = 0.7
__C.TEST.RPN_PRE_NMS_TOP_N = 6000
__C.TEST.RPN_POST_NMS_TOP_N = 300
__C.TEST.MODE = 'nms'
__C.TEST.RPN_TOP_N = 5000
# ResNet options
__C.RESNET = edict()
__C.RESNET.MAX_POOL = False
__C.RESNET.FIXED_BLOCKS = 1
# MobileNet options
__C.MOBILENET = edict()
__C.MOBILENET.REGU_DEPTH = False
__C.MOBILENET.FIXED_LAYERS = 5
__C.MOBILENET.WEIGHT_DECAY = 0.00004
__C.MOBILENET.DEPTH_MULTIPLIER = 1.
# Miscellaneous top-level options.
# Pixel mean values (BGR order) as a (1, 1, 3) array
# We use the same pixel mean for all networks even though it's not exactly what
# they were trained with
__C.PIXEL_MEANS = np.array([[[102.9801, 115.9465, 122.7717]]])
# For reproducibility
__C.RNG_SEED = 3
# Root directory of project
# (resolved as two levels above the directory that contains this file)
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))
# Data directory (the 'data' folder directly under ROOT_DIR)
__C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data'))
# Name (or path to) the matlab executable
__C.MATLAB = 'matlab'
# Place outputs under an experiments directory
__C.EXP_DIR = 'default'
# Use GPU implementation of non-maximum suppression
__C.USE_GPU_NMS = True
# Use an end-to-end tensorflow model.
# Note: models in E2E tensorflow mode have only been tested in feed-forward mode,
# but these models are exportable to other tensorflow instances as GraphDef files.
__C.USE_E2E_TF = True
# Default pooling mode, only 'crop' is available
__C.POOLING_MODE = 'crop'
# Size of the pooled region after RoI pooling
__C.POOLING_SIZE = 7
# Anchor scales for RPN
__C.ANCHOR_SCALES = [8,16,32]
# Anchor ratios for RPN
__C.ANCHOR_RATIOS = [0.5,1,2]
# Number of filters for the RPN layer
__C.RPN_CHANNELS = 512
接下来看一下: cfg_from_list(args.set_cfgs)。cfg_from_list(cfg_list)的函数功能为将cfg_list中的值赋给__C中对应的键。
def cfg_from_list(cfg_list):
    """Set config keys via a flat list (e.g., from the command line).

    Args:
        cfg_list: alternating keys and values, ``[key1, value1, key2, ...]``.
            A key is a dotted path into ``__C`` such as ``'TRAIN.SCALES'``;
            a value is a string that is parsed as a Python literal when
            possible and used verbatim otherwise.

    Raises:
        AssertionError: if the list has odd length, a key path does not
            exist in ``__C``, or the new value's type differs from the type
            of the value it replaces.
    """
    from ast import literal_eval

    assert len(cfg_list) % 2 == 0
    for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
        key_list = k.split('.')
        d = __C
        # Walk down the nested dicts to the one that owns the final key.
        for subkey in key_list[:-1]:
            assert subkey in d
            d = d[subkey]
        subkey = key_list[-1]
        assert subkey in d
        try:
            value = literal_eval(v)
        except Exception:
            # v is not a Python literal (e.g. a bare word such as res101):
            # keep it as a plain string. Catching Exception instead of using
            # a bare except lets KeyboardInterrupt/SystemExit propagate.
            value = v
        assert type(value) == type(d[subkey]), \
            'type {} does not match original type {}'.format(
                type(value), type(d[subkey]))
        d[subkey] = value
程序首先判断cfg_list的长度是否为2的倍数(assert len(cfg_list) % 2 == 0)这里是为了确保后续的程序运行顺利,先进行检查操作。
从cfg_list 中读取k和v(for k, v in zip(cfg_list[0::2], cfg_list[1::2]):)。用一个测试程序让大家明白这是一个怎么样的过程。
a=[0,1,2,3,4,5,6,7,8,9]
a[0::2]
Out[9]: [0, 2, 4, 6, 8]
a[1::2]
Out[10]: [1, 3, 5, 7, 9]
将读取出来的键用“.”分割开来,变成一个list(key_list = k.split('.'))。在这里为了便于理解需要观察前面的 __C 这个字典:字典中嵌套着其他的字典,key_list 中依次存储各层字典的键名。下面的循环沿着这些键名逐层进入嵌套的字典:
# Descend one nesting level per key; every intermediate key must exist.
for subkey in key_list[:-1]:
    assert subkey in d
    d = d[subkey]
在此处读取key_list中的最后一个键名(subkey = key_list[-1]),确保最后一层字典中有这个键名(assert subkey in d)。
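literal_eval(v) 会尝试把字符串 v 解析为 Python 字面量(数字、列表、元组、布尔值等)并赋值给 value;如果 v 不是合法的字面量(例如普通字符串 res101),解析会抛出异常,此时直接把字符串 v 本身赋值给 value。最后检查 value 的类型与 __C 中原有值的类型一致,再把 value 写入对应的键。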