#!/usr/bin/env python
# train_net.py

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network on a region of interest database."""

# NOTE(review): _init_paths is imported first, before the project packages;
# presumably it extends sys.path so they can be found -- keep this order.
import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
import datasets.imdb
import caffe
import argparse
import pprint
import numpy as np
import sys


def parse_args():
    """Parse the command-line arguments of the training script.

    Prints the usage text and exits with status 1 when the script is called
    without any argument.

    Returns:
        The argparse namespace with the attributes ``gpu_id``, ``solver``,
        ``max_iters``, ``pretrained_model``, ``cfg_file``, ``imdb_name``,
        ``randomize`` and ``set_cfgs``.
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id',
                        help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=40000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='voc_2007_trainval', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    # Everything after --set is collected verbatim and later handed to
    # cfg_from_list.
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args


def combined_roidb(imdb_names):
    """Load one or more datasets and merge their training roidbs.

    An experiment may train on several datasets at once, so their roidbs are
    concatenated into a single list.

    Args:
        imdb_names: dataset name, or several names joined with '+'.

    Returns:
        A tuple ``(imdb, roidb)``. ``roidb`` is the roidb of the first
        dataset, extended in place with the roidbs of the remaining ones.
        ``imdb`` comes from ``get_imdb`` for a single dataset, or is a
        ``datasets.imdb.imdb`` built from the full '+'-joined name when
        several datasets are combined.
    """
    def get_roidb(imdb_name):
        # get_imdb looks the dataset up by name (imported from
        # datasets.factory).
        imdb = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(imdb.name))
        # Select the proposal method from the config. The original annotation
        # says this is selective search by default (config.py) -- TODO confirm.
        imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        # Build the roidb used for training (imported from fast_rcnn.train).
        roidb = get_training_roidb(imdb)
        return roidb

    roidbs = [get_roidb(s) for s in imdb_names.split('+')]
    roidb = roidbs[0]
    if len(roidbs) > 1:
        # Combine: append every further roidb onto the first one.
        for r in roidbs[1:]:
            roidb.extend(r)
        imdb = datasets.imdb.imdb(imdb_names)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, roidb


if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    # cfg_from_file / cfg_from_list are imported from fast_rcnn.config; apply
    # the optional config file first, then the --set overrides.
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    cfg.GPU_ID = args.gpu_id

    print('Using config:')
    pprint.pprint(cfg)

    if not args.randomize:
        # fix the random seeds (numpy and caffe) for reproducibility
        np.random.seed(cfg.RNG_SEED)
        caffe.set_random_seed(cfg.RNG_SEED)

    # set up caffe
    caffe.set_mode_gpu()
    caffe.set_device(args.gpu_id)

    imdb, roidb = combined_roidb(args.imdb_name)
    print('{:d} roidb entries'.format(len(roidb)))

    output_dir = get_output_dir(imdb)
    print('Output will be saved to `{:s}`'.format(output_dir))

    # train_net is imported from fast_rcnn.train
    train_net(args.solver, roidb, output_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)
# train.py

"""Train a Fast R-CNN network."""

import caffe
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
from utils.timer import Timer
import numpy as np
import os

from caffe.proto import caffe_pb2
import google.protobuf as pb2
# Import the submodule explicitly so that pb2.text_format is guaranteed to
# exist; "import google.protobuf" alone does not load it.
import google.protobuf.text_format


class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.

    This wrapper gives us control over the snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """

    def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper.

        Args:
            solver_prototxt: path of the solver prototxt file.
            roidb: training roidb; handed to the network's first layer.
            output_dir: directory the snapshots are written to.
            pretrained_model: optional path of weights to copy into the net.
        """
        self.output_dir = output_dir

        if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
                cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
            # RPN can only use precomputed normalization because there are no
            # fixed statistics to compute a priori
            assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

        if cfg.TRAIN.BBOX_REG:
            # Compute the bounding-box regression targets and keep the
            # returned means and standard deviations for snapshot().
            print('Computing bounding-box regression targets...')
            self.bbox_means, self.bbox_stds = \
                rdl_roidb.add_bbox_regression_targets(roidb)
            print('done')

        # Create the solver and optionally copy pretrained weights.
        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print(('Loading pretrained model '
                   'weights from {:s}').format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Parse the solver prototxt as well, so that fields such as `display`
        # and `snapshot_prefix` can be read later.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        # Layer 0 is expected to be the data layer exposing set_roidb --
        # TODO confirm against the network prototxt.
        self.solver.net.layers[0].set_roidb(roidb)

    def snapshot(self):
        """Take a snapshot of the network after unnormalizing the learned
        bounding-box regression weights. This enables easy use at test-time.

        Only the saved file holds the unnormalized parameters; the in-memory
        net is restored afterwards (also when saving fails) so that training
        continues on the normalized weights.

        Returns:
            Path of the written .caffemodel file.
        """
        net = self.solver.net

        scale_bbox_params = (cfg.TRAIN.BBOX_REG and
                             cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
                             'bbox_pred' in net.params)

        if scale_bbox_params:
            # save original values
            orig_0 = net.params['bbox_pred'][0].data.copy()
            orig_1 = net.params['bbox_pred'][1].data.copy()

        try:
            if scale_bbox_params:
                # scale and shift with bbox reg unnormalization; then save
                # snapshot: multiply by the std, add the mean
                net.params['bbox_pred'][0].data[...] = \
                    (net.params['bbox_pred'][0].data *
                     self.bbox_stds[:, np.newaxis])
                net.params['bbox_pred'][1].data[...] = \
                    (net.params['bbox_pred'][1].data *
                     self.bbox_stds + self.bbox_means)

            # Build the snapshot file name.
            infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                     if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
            filename = (self.solver_param.snapshot_prefix + infix +
                        '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
            filename = os.path.join(self.output_dir, filename)

            net.save(str(filename))
            print('Wrote snapshot to: {:s}'.format(filename))
        finally:
            if scale_bbox_params:
                # restore net to original state, even if saving raised
                net.params['bbox_pred'][0].data[...] = orig_0
                net.params['bbox_pred'][1].data[...] = orig_1
        return filename

    def train_model(self, max_iters):
        """Network training loop.

        Args:
            max_iters: train until the solver iteration reaches this value.

        Returns:
            List with the paths of all snapshots written during the run.
        """
        last_snapshot_iter = -1
        timer = Timer()
        model_paths = []
        # A display value of 0 (protobuf default when the prototxt omits it)
        # would make the modulo below divide by zero; skip the speed report
        # in that case.
        display = self.solver_param.display
        while self.solver.iter < max_iters:
            # Make one SGD update
            timer.tic()
            self.solver.step(1)
            timer.toc()
            if display > 0 and self.solver.iter % (10 * display) == 0:
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshot every cfg.TRAIN.SNAPSHOT_ITERS iterations.
            if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = self.solver.iter
                model_paths.append(self.snapshot())

        # Also snapshot the final state unless it was just saved.
        if last_snapshot_iter != self.solver.iter:
            model_paths.append(self.snapshot())
        return model_paths


def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training.

    Optionally appends horizontally-flipped images to ``imdb`` (when
    cfg.TRAIN.USE_FLIPPED is set), then lets rdl_roidb.prepare_roidb enrich
    the roidb before returning ``imdb.roidb``.
    """
    if cfg.TRAIN.USE_FLIPPED:
        print('Appending horizontally-flipped training examples...')
        # Flipping yields additional roidb entries.
        imdb.append_flipped_images()
        print('done')

    print('Preparing training data...')
    # Per the original annotation this adds descriptive attributes such as
    # max-overlaps / max-classes to the roidb entries -- see prepare_roidb.
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    return imdb.roidb


def filter_roidb(roidb):
    """Remove roidb entries that have no usable RoIs.

    An entry is kept when it has at least one foreground RoI or at least one
    background RoI (either is sufficient).

    Args:
        roidb: list of entries, each with a 'max_overlaps' array.

    Returns:
        A new list containing only the valid entries.
    """

    def is_valid(entry):
        # Valid images have:
        #   (1) At least one foreground RoI OR
        #   (2) At least one background RoI
        overlaps = entry['max_overlaps']
        # find boxes with sufficient overlap
        fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
        # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
        bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                           (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
        # image is only valid if such boxes exist
        valid = len(fg_inds) > 0 or len(bg_inds) > 0
        return valid

    num = len(roidb)
    filtered_roidb = [entry for entry in roidb if is_valid(entry)]
    num_after = len(filtered_roidb)
    print('Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                       num, num_after))
    return filtered_roidb


def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network.

    Filters unusable roidb entries, wraps the solver and runs the training
    loop.

    Returns:
        List with the paths of the snapshots written while training.
    """
    roidb = filter_roidb(roidb)
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    print('Solving...')
    model_paths = sw.train_model(max_iters)
    print('done solving')
    return model_paths