tf-faster-rcnn代码理解之trianval_net.py

原始工程代码是通过tf-faster-rcnn\experiments\scripts目录下的train_faster_rcnn.sh调用tf-faster-rcnn\tools\trainval_net.py进行模型训练。为了方便使用pycharm对整个训练工程进行调试,故修改trianval_net.py使之不需要shell脚本引导,可以直接运行。修改之后的代码如下:

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Zheqi He, Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import _init_paths
from model.train_val import get_training_roidb, train_net
from model.config import cfg, cfg_from_file, cfg_from_list, get_output_dir, get_output_tb_dir
from datasets.factory import get_imdb
import datasets.imdb
import argparse
import pprint
import numpy as np
import sys

import tensorflow as tf
from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1

class args:
  """
  Parse input arguments
  """
  cfg_file = '/home/whao/tf-faster-rcnn/experiments/cfgs/vgg16.yml'
  weight = '/home/whao/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt'
  imdb_name = 'voc_2007_trainval'
  imdbval_name = 'voc_2007_test'
  max_iters = 100000
  tag = None
  net = 'vgg16'
  set_cfgs = ['ANCHOR_SCALES', '[8,16,32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'TRAIN.STEPSIZE', '50000']

def combined_roidb(imdb_names):
  """
  Combine multiple roidbs
  """
  def get_roidb(imdb_name):
    imdb = get_imdb(imdb_name)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
    roidb = get_training_roidb(imdb)
    return roidb

  roidbs = [get_roidb(s) for s in imdb_names.split('+')]
  roidb = roidbs[0]
  if len(roidbs) > 1:

    for r in roidbs[1:]:
      roidb.extend(r)
    tmp = get_imdb(imdb_names.split('+')[1])
    imdb = datasets.imdb.imdb(imdb_names, tmp.classes)
  else:
    imdb = get_imdb(imdb_names)
  return imdb, roidb

if __name__ == '__main__':
#  args = parse_args()

  print('Called with args:')

  if args.cfg_file is not None:
    cfg_from_file(args.cfg_file)
  if args.set_cfgs is not None:
    cfg_from_list(args.set_cfgs)

  print('Using config:')
  pprint.pprint(cfg)

  np.random.seed(cfg.RNG_SEED)

  # train set
  imdb, roidb = combined_roidb(args.imdb_name)
  print('{:d} roidb entries'.format(len(roidb)))

  # output directory where the models are saved
  output_dir = get_output_dir(imdb, args.tag)
  print('Output will be saved to `{:s}`'.format(output_dir))

  # tensorboard directory where the summaries are saved during training
  tb_dir = get_output_tb_dir(imdb, args.tag)
  print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

  # also add the validation set, but with no flipping images
  orgflip = cfg.TRAIN.USE_FLIPPED
  cfg.TRAIN.USE_FLIPPED = False
  _, valroidb = combined_roidb(args.imdbval_name)
  print('{:d} validation roidb entries'.format(len(valroidb)))
  cfg.TRAIN.USE_FLIPPED = orgflip

  # load network
  if args.net == 'vgg16':
    net = vgg16(batch_size=cfg.TRAIN.IMS_PER_BATCH)
  else:
    raise NotImplementedError
    
  train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
            pretrained_model=args.weight,
            max_iters=args.max_iters)
以上代码中,定义了args类代替shell传参。首先需要把训练集按照pascal voc的格式处理好,包括文件名与标签个事和Main中的txt的文件。

代码的执行流程是先读取cfg_file所指定的yml文件来配置部分超参量。执行函数为cfg_from_file(args.cfg_file),它把yml中的超参数合并到config.py中定义的__C对象中,它是类EasyDict的对象。

然后,通过cfg_from_list(args.set_cfgs)配置__C对象中的变量。

接下来,开始处理训练集,通过combined_roidb(args.imdb_name)收集训练集,它通过调用lib/datasets/factory.py中的get_imdb()获得数据集,获得类pascal_voc的对象imdb,再设置区域推荐的方式,默认为gt,通过lib/model/train_val.py中的函数get_training_roidb()获得roidb,即每张图片中的区域推荐样本,其为实际为imdb中的一个变量。打印出区域推荐样本的数量

接下来设置训练好的模型和tensorboard文件的存储路径,再获取验证集的数据,前面的训练的数据是经过数据增强的,每张图片都经过旋转,验证集不进行数据增强。

接下来,配置vgg16网络的batch数量,默认是设置为1。

最后调用train_val.py中的train_net()函数开启训练。

未完待续。

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值