faster-rcnn 之训练数据是如何准备的:imdb和roidb的产生

http://blog.csdn.net/zouyu1746430162/article/details/53911555


关于imdb和roidb的生成都是在函数train_rpn的中,所以我们从这个函数开始,逐个跟进看如何执行得到我们需要的imdb和roidb:


[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. def train_rpn(queue=None, imdb_name=None, init_model=None, solver=None,  
  2.               max_iters=None, cfg=None):  
  3.     """Train a Region Proposal Network in a separate training process. 
  4.     """  
  5.   
  6.     # Not using any proposals, just ground-truth boxes  
  7.     cfg.TRAIN.HAS_RPN = True  
  8.     cfg.TRAIN.BBOX_REG = False  # applies only to Fast R-CNN bbox regression  
  9.     cfg.TRAIN.PROPOSAL_METHOD = 'gt'  
  10.     cfg.TRAIN.IMS_PER_BATCH = 1  
  11.     print 'Init model: {}'.format(init_model)  
  12.     print('Using config:')  
  13.     pprint.pprint(cfg)  
  14.   
  15.     import caffe  
  16.     _init_caffe(cfg)  
  17.   
  18.     roidb, imdb = get_roidb(imdb_name) # 调用函数,返回训练数据  
  19.     print 'roidb len: {}'.format(len(roidb))  
  20.     output_dir = get_output_dir(imdb)  
  21.     print 'Output will be saved to `{:s}`'.format(output_dir)  
  22.   
  23.     model_paths = train_net(solver, roidb, output_dir, #传入数据roidb,供训练  
  24.                             pretrained_model=init_model,  
  25.                             max_iters=max_iters)  
  26.     # Cleanup all but the final model  
  27.     for i in model_paths[:-1]:  
  28.         os.remove(i)  
  29.     rpn_model_path = model_paths[-1]  
  30.     # Send final model path through the multiprocessing queue  
  31.     queue.put({'model_path': rpn_model_path})  

所以,进入get_roidb函数:

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. def get_roidb(imdb_name, rpn_file=None):  
  2.     imdb = get_imdb(imdb_name) # 调用该函数,返回imdb  
  3.     print 'Loaded dataset `{:s}` for training'.format(imdb.name)  
  4.     imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)  
  5.     print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)  
  6.     if rpn_file is not None:  
  7.         imdb.config['rpn_file'] = rpn_file  
  8.     roidb = get_training_roidb(imdb) #利用imdb,产生roi_db  
  9.     return roidb, imdb  

所以我们要先看imdb是如何产生的,然后看如何借助imdb产生roidb

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. def get_imdb(name):  
  2.     """Get an imdb (image database) by name."""  
  3.     if not __sets.has_key(name):  
  4.         raise KeyError('Unknown dataset: {}'.format(name))  
  5.     return __sets[name]()  
从上面可见,get_imdb这个函数的实现原理:_sets是一个字典,字典的key是数据集的名称,字典的value是一个lambda表达式(即一个函数指针),
[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. __sets[name]()  
这句话实际上是调用函数,返回数据集imdb,下面看这个函数:
[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. for year in ['2007''2012']:  
  2.     for split in ['train''val''trainval''test']:  
  3.         name = 'voc_{}_{}'.format(year, split)  
  4.         __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))  
所以可以看到,执行的实际上是pascal_voc函数,参数是split 和 year(ps:在train_vpn函数中,name是voc_2007_trainval,所以这里对应的split和year分别是trainval和2007);
很明显,pascal_voc是一个类,这是调用了该类的构造函数,返回的也是该类的一个实例,所以这下我们清楚了imdb实际上就是pascal_voc的一个实例;

那么我们来看这个类的构造函数是如何的,以及输入的图片数据在里面是如何组织的:

该类的构造函数如下:基本上就是设置了imdb的一些属性,比如图片的路径,图片名称的索引,并没有把真实的图片数据放进来

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. class pascal_voc(imdb):  
  2.     def __init__(self, image_set, year, devkit_path=None):  
  3.         imdb.__init__(self'voc_' + year + '_' + image_set)  
  4.         self._year = year # 设置年,2007  
  5.         self._image_set = image_set # trainval  
  6.         self._devkit_path = self._get_default_path() if devkit_path is None \  
  7.                             else devkit_path # 数据集的路径'/home/sloan/py-faster-rcnn-master/data/VOCdevkit2007'  
  8.         self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year) # '/home/sloan/py-faster-rcnn-master/data/VOCdevkit2007/VOC2007'  
  9.         self._classes = ('__background__'# always index 0  
  10.                          'aeroplane''bicycle''bird''boat',  
  11.                          'bottle''bus''car''cat''chair',  
  12.                          'cow''diningtable''dog''horse',  
  13.                          'motorbike''person''pottedplant',  
  14.                          'sheep''sofa''train''tvmonitor'# 21个类别  
  15.         self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) #给每个类别赋予一个对应的整数  
  16.         self._image_ext = '.jpg' # 图片的扩展名  
  17.         self._image_index = self._load_image_set_index() # 把所有图片的名称加载,放在list中,便于索引读取图片  
  18.         # Default to roidb handler  
  19.         self._roidb_handler = self.selective_search_roidb  
  20.         self._salt = str(uuid.uuid4())  
  21.         self._comp_id = 'comp4'  
  22.   
  23.         # PASCAL specific config options  
  24.         self.config = {'cleanup'     : True,  
  25.                        'use_salt'    : True,  
  26.                        'use_diff'    : False,  
  27.                        'matlab_eval' : False,  
  28.                        'rpn_file'    : None,  
  29.                        'min_size'    : 2}  
  30.         # 这两句就是检查前面的路径是否存在合法了,否则后面无法运行  
  31.         assert os.path.exists(self._devkit_path), \  
  32.                 'VOCdevkit path does not exist: {}'.format(self._devkit_path)  
  33.         assert os.path.exists(self._data_path), \  
  34.                 'Path does not exist: {}'.format(self._data_path)  

那么有了imdb之后,roidb又有什么不同呢?为什么实际输入train_rpn的数据是roidb呢?

前面我们已经得到了imdb,但是imdb的成员roidb还是空白,啥都没有,那么roidb是如何生成的,其中又包含了哪些信息呢?

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)  
上面调用的函数,为imdb添加了roidb的数据,我们看看如何添加的,见下面这个函数:

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. def set_proposal_method(self, method):  
  2.     method = eval('self.' + method + '_roidb')  
  3.     self.roidb_handler = method  
这里method传入的是一个str:gt,所以method=eval('self.gt_roidb')

那么关键就是eval函数做了什么操作???分析这个函数分析roidb中每个元素的具体hany
有了roidb后,后面的get_training_roidb(imdb)完成什么功能:将roidb中的元素由5011个,通过水平对称变成10022个;将index这个list的元素相应的也翻一番;

我们看看这个函数:

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. <span style="font-family: Arial, Helvetica, sans-serif; background-color: rgb(255, 255, 255);">函数如下:这个函数首先对imdb中涉及到的图像做了一个水平镜像,使得trainval中的5011张图片,变成了10022张图片;然后调用函数prepare_roidb函数准备数据(ps:我觉得作者这些函数的层层嵌套,又没做多大事情,实在是让结构不那么美观)</span>  

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. def get_training_roidb(imdb):  
  2.     """Returns a roidb (Region of Interest database) for use in training."""  
  3.     if cfg.TRAIN.USE_FLIPPED:  
  4.         print 'Appending horizontally-flipped training examples...'  
  5.         imdb.append_flipped_images()  
  6.         print 'done'  
  7.   
  8.     print 'Preparing training data...'  
  9.     rdl_roidb.prepare_roidb(imdb)  
  10.     print 'done'  
首先我们看看append_flipped_images函数:可以发现,roidb是imdb的一个成员变量,roidb是一个list(每个元素对应一张图片),list中的元素是一个字典,字典中存放了5个key,分别是boxes信息,每个box的class信息,是否是flipped的标志位,重叠信息gt_overlaps,以及seg_areas;分析该函数可知,将box的值按照水平对称,原先roidb中只有5011个元素,经过水平对称后通过append增加到5011*2=10022个;

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. def append_flipped_images(self):  
  2.     num_images = self.num_images  
  3.     widths = self._get_widths()  
  4.     for i in xrange(num_images):  
  5.         boxes = self.roidb[i]['boxes'].copy()  
  6.         oldx1 = boxes[:, 0].copy()  
  7.         oldx2 = boxes[:, 2].copy()  
  8.         boxes[:, 0] = widths[i] - oldx2 - 1  
  9.         boxes[:, 2] = widths[i] - oldx1 - 1 # 新框的xmin和xmax都要更新  
  10.         assert (boxes[:, 2] >= boxes[:, 0]).all()  
  11.         entry = {'boxes' : boxes,  
  12.                  'gt_overlaps' : self.roidb[i]['gt_overlaps'],  
  13.                  'gt_classes' : self.roidb[i]['gt_classes'],  
  14.                  'flipped' : True}  
  15.         self.roidb.append(entry) # 把这个新的框添加到roidb中  
  16.     self._image_index = self._image_index * 2 #将索引的list 复制拼接  
然后就是prepare_roidb函数:

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. def prepare_roidb(imdb):  
  2.     """Enrich the imdb's roidb by adding some derived quantities that 
  3.     are useful for training. This function precomputes the maximum 
  4.     overlap, taken over ground-truth boxes, between each ROI and 
  5.     each ground-truth box. The class with maximum overlap is also 
  6.     recorded. 
  7.     """  
  8.     sizes = [PIL.Image.open(imdb.image_path_at(i)).size  
  9.              for i in xrange(imdb.num_images)]  
  10.     roidb = imdb.roidb  
  11.     for i in xrange(len(imdb.image_index)):  
  12.         roidb[i]['image'] = imdb.image_path_at(i)  
  13.         roidb[i]['width'] = sizes[i][0]  
  14.         roidb[i]['height'] = sizes[i][1]  
  15.         # need gt_overlaps as a dense array for argmax  
  16.         gt_overlaps = roidb[i]['gt_overlaps'].toarray()  
  17.         # max overlap with gt over classes (columns)  
  18.         max_overlaps = gt_overlaps.max(axis=1)  
  19.         # gt class that had the max overlap  
  20.         max_classes = gt_overlaps.argmax(axis=1)  
  21.         roidb[i]['max_classes'] = max_classes  
  22.         roidb[i]['max_overlaps'] = max_overlaps  
  23.         # sanity checks  
  24.         # max overlap of 0 => class should be zero (background)  
  25.         zero_inds = np.where(max_overlaps == 0)[0]  
  26.         assert all(max_classes[zero_inds] == 0)  
  27.         # max overlap > 0 => class should not be zero (must be a fg class)  
  28.         nonzero_inds = np.where(max_overlaps > 0)[0]  
  29.         assert all(max_classes[nonzero_inds] != 0)  
============================================================================================================================

写到这里,我就想吐槽了,以为数据准备好了么,no,上面只是准备好了roidb的相关信息而已;

我表示这个作者搞的太麻烦了,结构不够扁平化,简单的事情用多个函数绕来绕去,受不了了;

真正的数据处理操作是在

class RoIDataLayer(caffe.Layer): 类的

    def forward(self, bottom, top):函数中开始的,这个类在faster-rcnn-root/lib/roi_data_layer/layer.py文件中

blobs = self._get_next_minibatch()这句话产生了我们需要的数据blobs;这个函数又调用了minibatch.py文件中的def get_minibatch(roidb, num_classes):函数;

然后又调用了def _get_image_blob(roidb, scale_inds):函数;在这个函数中,我们终于发现了cv2.imread函数,也就是最终的读取图片到内存的地方:

[python]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. def _get_image_blob(roidb, scale_inds):  
  2.     """Builds an input blob from the images in the roidb at the specified 
  3.     scales. 
  4.     """  
  5.     num_images = len(roidb)  
  6.     processed_ims = []  
  7.     im_scales = []  
  8.     for i in xrange(num_images):  
  9.         im = cv2.imread(roidb[i]['image']) #终于在这里读取图片了  
  10.         if roidb[i]['flipped']:  
  11.             im = im[:, ::-1, :]  
  12.         target_size = cfg.TRAIN.SCALES[scale_inds[i]]  
  13.         im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size,  
  14.                                         cfg.TRAIN.MAX_SIZE)  
  15.         im_scales.append(im_scale)  
  16.         processed_ims.append(im)  
  17.   
  18.     # Create a blob to hold the input images  
  19.     blob = im_list_to_blob(processed_ims)  
  20.   
  21.     return blob, im_scales  

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值