目录
--tools--train_net.py PS:这个是训练的主程序
--tools--train_net.py PS:这个是训练的主程序
参数输入部分
def parse_args():
    """Parse the command-line arguments for training.

    Returns:
        argparse.Namespace: the parsed arguments (for example ``args.gpu_id``).

    If the script is invoked with no arguments at all, the help text is
    printed and the process exits with status 1.
    """
    # NOTE(review): this is an excerpt. The placeholder line further down
    # stands for omitted parser.add_argument(...) calls. Judging by the
    # attributes read later (args.cfg_file, args.set_cfgs, args.imdb_name,
    # args.solver, args.pretrained_model, args.max_iters), those calls
    # presumably define the matching dests -- confirm against the full
    # tools/train_net.py.
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    # `--gpu` is the flag the user types; dest='gpu_id' is the attribute
    # name the program reads back (args.gpu_id).
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]',
                        default=0, type=int)
    parser...............省略..................................
    # Only the program name is present: show usage instead of silently
    # training with defaults.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    return args
#这里是命令行参数的定义：--gpu 是用户在命令行中输入的参数名，gpu_id（即 dest）是程序内部通过 args.gpu_id 访问的属性名
参数修改部分
# Apply configuration overrides in a fixed order:
#   1. a YAML file (cfg_from_file),
#   2. then the set_cfgs list (cfg_from_list),
#   3. then the GPU id from the command line.
# Presumably later sources win for overlapping keys -- confirm in cfg_from_list.
cfg_path = args.cfg_file
if cfg_path is not None:
    cfg_from_file(cfg_path)
override_list = args.set_cfgs
if override_list is not None:
    cfg_from_list(override_list)
cfg.GPU_ID = args.gpu_id
#修改配置有两种方式：一是事先写好 cfg_file（YAML）文件，用 cfg_from_file 导入；二是通过 set_cfgs 参数列表，用 cfg_from_list 直接赋值。GPU 编号则直接写入 cfg.GPU_ID
cfg_file文件:
EXP_DIR: faster_rcnn_end2end
TRAIN:
  HAS_RPN: True
  IMS_PER_BATCH: 1
  BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True
  RPN_POSITIVE_OVERLAP: 0.7
  RPN_BATCHSIZE: 256
  PROPOSAL_METHOD: gt
  BG_THRESH_LO: 0.0
TEST:
  HAS_RPN: True
文件导入程序cfg_from_file():
def _merge_a_into_b(a, b):
    """Merge config dictionary a into config dictionary b, clobbering the
    options in b whenever they are also specified in a.

    Args:
        a: the overriding config. If it is not exactly an ``edict``, the
            call is a no-op.
        b: the config updated in place. Every key of ``a`` must already
            exist in ``b``.

    Raises:
        KeyError: if ``a`` contains a key that ``b`` does not.
        ValueError: if a value's type differs from the existing value in
            ``b``. The exception is an existing numpy array: the new value
            is converted to that array's dtype instead.
    """
    if type(a) is not edict:
        return

    # .items() and `in` behave the same on Python 2 and 3. The original
    # .iteritems() and .has_key() exist only on Python 2.
    for k, v in a.items():
        # a must specify keys that are in b
        if k not in b:
            raise KeyError('{} is not a valid config key'.format(k))

        # the types must match, too
        old_type = type(b[k])
        if old_type is not type(v):
            if isinstance(b[k], np.ndarray):
                # e.g. a list overriding an ndarray default: coerce it to
                # the default's dtype rather than rejecting it.
                v = np.array(v, dtype=b[k].dtype)
            else:
                raise ValueError(('Type mismatch ({} vs. {}) '
                                  'for config key: {}').format(type(b[k]),
                                                               type(v), k))

        # recursively merge dicts
        if type(v) is edict:
            try:
                _merge_a_into_b(a[k], b[k])
            except Exception:
                # Report which nested key failed, then propagate unchanged.
                print('Error under config key: {}'.format(k))
                raise
        else:
            b[k] = v
def cfg_from_file(filename):
    """Load a YAML config file and merge it into the default options.

    Args:
        filename: path to a YAML file. Its keys must already exist in the
            default config ``__C`` (enforced by ``_merge_a_into_b``).

    Raises:
        KeyError, ValueError: propagated from ``_merge_a_into_b`` for
            unknown keys or type mismatches.
    """
    import yaml
    with open(filename, 'r') as f:
        # Use safe_load rather than plain yaml.load(f). Calling yaml.load
        # without a Loader:
        #   - is unsafe on untrusted files,
        #   - warns on PyYAML 5.1+,
        #   - raises TypeError on PyYAML 6.
        yaml_cfg = edict(yaml.safe_load(f))
    _merge_a_into_b(yaml_cfg, __C)
获取数据和训练网络
# Load the training data, pick the output directory, then train.
# Single-argument print(...) behaves the same on Python 2 and 3; the bare
# `print '...'` statement is a SyntaxError on Python 3.
imdb, roidb = combined_roidb(args.imdb_name)  # dataset(s) and their training roidb (incl. ground-truth boxes)
print('{:d} roidb entries'.format(len(roidb)))
output_dir = get_output_dir(imdb)  # where training output is written
print('Output will be saved to `{:s}`'.format(output_dir))
train_net(args.solver, roidb, output_dir,  # train the network
          pretrained_model=args.pretrained_model,
          max_iters=args.max_iters)
获取图片和真实框
def combined_roidb(imdb_names):
    """Load the training roidb for one or more '+'-separated imdb names.

    Args:
        imdb_names: a single imdb name, or several joined with '+'
            (e.g. 'voc_2007_trainval+voc_2012_trainval').

    Returns:
        (imdb, roidb):
            - ``roidb`` is every dataset's training roidb concatenated in
              order.
            - ``imdb`` is the named dataset when only one name is given.
              Otherwise it is a generic ``datasets.imdb.imdb`` named after
              the combined string.
    """
    def get_roidb(imdb_name):
        # Load one dataset, set its proposal method from the config, and
        # build its training roidb.
        imdb = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(imdb.name))
        imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        roidb = get_training_roidb(imdb)
        return roidb

    # e.g. imdb_names == 'voc_2007_trainval+voc_2012_trainval'
    roidbs = [get_roidb(s) for s in imdb_names.split('+')]
    roidb = roidbs[0]
    if len(roidbs) > 1:
        # Concatenate in place onto the first dataset's roidb list.
        for r in roidbs[1:]:
            roidb.extend(r)
        imdb = datasets.imdb.imdb(imdb_names)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, roidb
这里主要涉及图片（数据集）的读取和接口的写法，可供参考。