faster rcnn源码理解

转自:http://blog.csdn.net/u014568921/article/details/53188559

理解faster rcnn的源码有几个关键点

1.算法原理、网络结构、训练过程这是基本

2.要弄懂源码里训练数据数据是怎么组织起来的,imdb,roidb,blob很关键,弄清它们的数据结构以及各个阶段是如何产生的

3.一定的python、numpy基础知识



rpn_train.pt

  1. #stage 1训练RPN时用的网络结构  
  2. name: "ZF"  
  3. layer {  
  4.   name: 'input-data'  
  5.   type: 'Python'  
  6.   top: 'data'  
  7.   top: 'im_info'  
  8.   top: 'gt_boxes'  
  9.   python_param {  
  10.     module: 'roi_data_layer.layer'#对应lib/roi_data_layer/layer.py  
  11. #为训练RPN时为网络输入roi,此时为gt box  
  12.     layer: 'RoIDataLayer'  
  13.     param_str: "'num_classes': 21"  
  14.   }  
  15. }  
  16.   
  17. #前面是ZF网,提取特征用,各个阶段共享  
  18. #========= conv1-conv5 ============  
  19.   
  20. layer {  
  21.     name: "conv1"  
  22.     type: "Convolution"  
  23.     bottom: "data"  
  24.     top: "conv1"  
  25.     param { lr_mult: 1.0 }  
  26.     param { lr_mult: 2.0 }  
  27.     convolution_param {  
  28.         num_output: 96  
  29.         kernel_size: 7  
  30.         pad: 3  
  31.         stride: 2  
  32.     }  
  33. }  
  34. layer {  
  35.     name: "relu1"  
  36.     type: "ReLU"  
  37.     bottom: "conv1"  
  38.     top: "conv1"  
  39. }  
  40. layer {  
  41.     name: "norm1"  
  42.     type: "LRN"  
  43.     bottom: "conv1"  
  44.     top: "norm1"  
  45.     lrn_param {  
  46.         local_size: 3  
  47.         alpha: 0.00005  
  48.         beta: 0.75  
  49.         norm_region: WITHIN_CHANNEL  
  50.     engine: CAFFE  
  51.     }  
  52. }  
  53. layer {  
  54.     name: "pool1"  
  55.     type: "Pooling"  
  56.     bottom: "norm1"  
  57.     top: "pool1"  
  58.     pooling_param {  
  59.         kernel_size: 3  
  60.         stride: 2  
  61.         pad: 1  
  62.         pool: MAX  
  63.     }  
  64. }  
  65. layer {  
  66.     name: "conv2"  
  67.     type: "Convolution"  
  68.     bottom: "pool1"  
  69.     top: "conv2"  
  70.     param { lr_mult: 1.0 }  
  71.     param { lr_mult: 2.0 }  
  72.     convolution_param {  
  73.         num_output: 256  
  74.         kernel_size: 5  
  75.         pad: 2  
  76.         stride: 2  
  77.     }  
  78. }  
  79. layer {  
  80.     name: "relu2"  
  81.     type: "ReLU"  
  82.     bottom: "conv2"  
  83.     top: "conv2"  
  84. }  
  85. layer {  
  86.     name: "norm2"  
  87.     type: "LRN"  
  88.     bottom: "conv2"  
  89.     top: "norm2"  
  90.     lrn_param {  
  91.         local_size: 3  
  92.         alpha: 0.00005  
  93.         beta: 0.75  
  94.         norm_region: WITHIN_CHANNEL  
  95.     engine: CAFFE  
  96.     }  
  97. }  
  98. layer {  
  99.     name: "pool2"  
  100.     type: "Pooling"  
  101.     bottom: "norm2"  
  102.     top: "pool2"  
  103.     pooling_param {  
  104.         kernel_size: 3  
  105.         stride: 2  
  106.         pad: 1  
  107.         pool: MAX  
  108.     }  
  109. }  
  110. layer {  
  111.     name: "conv3"  
  112.     type: "Convolution"  
  113.     bottom: "pool2"  
  114.     top: "conv3"  
  115.     param { lr_mult: 1.0 }  
  116.     param { lr_mult: 2.0 }  
  117.     convolution_param {  
  118.         num_output: 384  
  119.         kernel_size: 3  
  120.         pad: 1  
  121.         stride: 1  
  122.     }  
  123. }  
  124. layer {  
  125.     name: "relu3"  
  126.     type: "ReLU"  
  127.     bottom: "conv3"  
  128.     top: "conv3"  
  129. }  
  130. layer {  
  131.     name: "conv4"  
  132.     type: "Convolution"  
  133.     bottom: "conv3"  
  134.     top: "conv4"  
  135.     param { lr_mult: 1.0 }  
  136.     param { lr_mult: 2.0 }  
  137.     convolution_param {  
  138.         num_output: 384  
  139.         kernel_size: 3  
  140.         pad: 1  
  141.         stride: 1  
  142.     }  
  143. }  
  144. layer {  
  145.     name: "relu4"  
  146.     type: "ReLU"  
  147.     bottom: "conv4"  
  148.     top: "conv4"  
  149. }  
  150. layer {  
  151.     name: "conv5"  
  152.     type: "Convolution"  
  153.     bottom: "conv4"  
  154.     top: "conv5"  
  155.     param { lr_mult: 1.0 }  
  156.     param { lr_mult: 2.0 }  
  157.     convolution_param {  
  158.         num_output: 256  
  159.         kernel_size: 3  
  160.         pad: 1  
  161.         stride: 1  
  162.     }  
  163. }  
  164. layer {  
  165.     name: "relu5"  
  166.     type: "ReLU"  
  167.     bottom: "conv5"  
  168.     top: "conv5"  
  169. }  
  170.   
  171. #========= RPN ============  
  172.   
  173. layer {  
  174.   name: "rpn_conv1"  
  175.   type: "Convolution"  
  176.   bottom: "conv5"  
  177.   top: "rpn_conv1"  
  178.   param { lr_mult: 1.0 }  
  179.   param { lr_mult: 2.0 }  
  180.   convolution_param {  
  181.     num_output: 256  
  182.     kernel_size: 3 pad: 1 stride: 1  
  183.     weight_filler { type: "gaussian" std: 0.01 }  
  184.     bias_filler { type: "constant" value: 0 }  
  185.   }  
  186. }  
  187. layer {  
  188.   name: "rpn_relu1"  
  189.   type: "ReLU"  
  190.   bottom: "rpn_conv1"  
  191.   top: "rpn_conv1"  
  192. }  
  193. layer {  
  194.   name: "rpn_cls_score"  
  195.   type: "Convolution"  
  196.   bottom: "rpn_conv1"  
  197.   top: "rpn_cls_score"  
  198.   param { lr_mult: 1.0 }  
  199.   param { lr_mult: 2.0 }  
  200.   convolution_param {  
  201.     num_output: 18   # 2(bg/fg) * 9(anchors)  
  202.     kernel_size: 1 pad: 0 stride: 1  
  203.     weight_filler { type: "gaussian" std: 0.01 }  
  204.     bias_filler { type: "constant" value: 0 }  
  205.   }  
  206. }  
  207. layer {  
  208.   name: "rpn_bbox_pred"  
  209.   type: "Convolution"  
  210.   bottom: "rpn_conv1"  
  211.   top: "rpn_bbox_pred"  
  212.   param { lr_mult: 1.0 }  
  213.   param { lr_mult: 2.0 }  
  214.   convolution_param {  
  215.     num_output: 36   # 4 * 9(anchors)  
  216.     kernel_size: 1 pad: 0 stride: 1  
  217.     weight_filler { type: "gaussian" std: 0.01 }  
  218.     bias_filler { type: "constant" value: 0 }  
  219.   }  
  220. }  
  221. layer {  
  222.    bottom: "rpn_cls_score"  
  223.    top: "rpn_cls_score_reshape"  
  224.    name: "rpn_cls_score_reshape"  
  225.    type: "Reshape"  
  226.    reshape_param { shape { dim: 0 dim: 2 dim: -1 dim: 0 } }  
  227. }  
  228. layer {  
  229.   name: 'rpn-data'  
  230.   type: 'Python'  
  231.   bottom: 'rpn_cls_score'  
  232.   bottom: 'gt_boxes'  
  233.   bottom: 'im_info'  
  234.   bottom: 'data'  
  235.   top: 'rpn_labels'  
  236.   top: 'rpn_bbox_targets'  
  237.   top: 'rpn_bbox_inside_weights'  
  238.   top: 'rpn_bbox_outside_weights'  
  239.   python_param {  
  240.     module: 'rpn.anchor_target_layer'#对应文件lib/rpn/anchor_target_layer.py  
  241. #用于在原图上产生anchor,结合gt box训练rpn做box cls和box reg  
  242.     layer: 'AnchorTargetLayer'  
  243.     param_str: "'feat_stride': 16"  
  244.   }  
  245. }  
  246. layer {  
  247.   name: "rpn_loss_cls"  
  248.   type: "SoftmaxWithLoss"  
  249.   bottom: "rpn_cls_score_reshape"  
  250.   bottom: "rpn_labels"  
  251.   propagate_down: 1  
  252.   propagate_down: 0  
  253.   top: "rpn_cls_loss"  
  254.   loss_weight: 1  
  255.   loss_param {  
  256.     ignore_label: -1  
  257.     normalize: true  
  258.   }  
  259. }  
  260. layer {  
  261.   name: "rpn_loss_bbox"  
  262.   type: "SmoothL1Loss"  
  263.   bottom: "rpn_bbox_pred"  
  264.   bottom: "rpn_bbox_targets"  
  265.   bottom: "rpn_bbox_inside_weights"  
  266.   bottom: "rpn_bbox_outside_weights"  
  267.   top: "rpn_loss_bbox"  
  268.   loss_weight: 1  
  269.   smooth_l1_loss_param { sigma: 3.0 }  
  270. }  
  271.   
  272. #========= RCNN ============  
  273. # Dummy layers so that initial parameters are saved into the output net  
  274.   
  275. layer {  
  276.   name: "dummy_roi_pool_conv5"  
  277.   type: "DummyData"  
  278.   top: "dummy_roi_pool_conv5"  
  279.   dummy_data_param {  
  280.     shape { dim: 1 dim: 9216 }  
  281.     data_filler { type: "gaussian" std: 0.01 }  
  282.   }  
  283. }  
  284. layer {  
  285.   name: "fc6"  
  286.   type: "InnerProduct"  
  287.   bottom: "dummy_roi_pool_conv5"  
  288.   top: "fc6"  
  289.   param { lr_mult: 0 decay_mult: 0 }  
  290.   param { lr_mult: 0 decay_mult: 0 }  
  291.   inner_product_param {  
  292.     num_output: 4096  
  293.   }  
  294. }  
  295. layer {  
  296.   name: "relu6"  
  297.   type: "ReLU"  
  298.   bottom: "fc6"  
  299.   top: "fc6"  
  300. }  
  301. layer {  
  302.   name: "fc7"  
  303.   type: "InnerProduct"  
  304.   bottom: "fc6"  
  305.   top: "fc7"  
  306.   param { lr_mult: 0 decay_mult: 0 }  
  307.   param { lr_mult: 0 decay_mult: 0 }  
  308.   inner_product_param {  
  309.     num_output: 4096  
  310.   }  
  311. }  
  312. layer {  
  313.   name: "silence_fc7"  
  314.   type: "Silence"  
  315.   bottom: "fc7"  
  316. }  


上面需要注意的是rpn_cls_score层为每个位置的9个anchor做的只是bg/fg的二分类,而不管具体是fg的话属于那一类别,rpn阶段完成这个任务就够了,后面fast rcnn可以对region proposal进行细分和位置精修




roi_data_layer/layer.py


  1. #coding:utf-8  
  2. # --------------------------------------------------------  
  3. # Fast R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick  
  7. # --------------------------------------------------------  
  8.   
  9. """The data layer used during training to train a Fast R-CNN network. 
  10.  
  11. RoIDataLayer implements a Caffe Python layer. 
  12. """  
  13.   
  14. import caffe  
  15. from fast_rcnn.config import cfg  
  16. from roi_data_layer.minibatch import get_minibatch  
  17. import numpy as np  
  18. import yaml  
  19. from multiprocessing import Process, Queue  
  20.   
  21.   
  22. #为网络输入roi  
  23. class RoIDataLayer(caffe.Layer):  
  24.     """Fast R-CNN data layer used for training."""  
  25.   
  26.     def _shuffle_roidb_inds(self):  
  27.         """Randomly permute the training roidb."""  
  28.         if cfg.TRAIN.ASPECT_GROUPING:  
  29.             widths = np.array([r['width'for r in self._roidb])  
  30.             heights = np.array([r['height'for r in self._roidb])  
  31.             horz = (widths >= heights)  
  32.             vert = np.logical_not(horz)  
  33.             horz_inds = np.where(horz)[0]  
  34.             vert_inds = np.where(vert)[0]  
  35.             inds = np.hstack((  
  36.                 np.random.permutation(horz_inds),  
  37.                 np.random.permutation(vert_inds)))  
  38.             inds = np.reshape(inds, (-12))  
  39.             row_perm = np.random.permutation(np.arange(inds.shape[0]))  
  40.             inds = np.reshape(inds[row_perm, :], (-1,))  
  41.             self._perm = inds  
  42.         else:  
  43.             self._perm = np.random.permutation(np.arange(len(self._roidb)))  
  44.         self._cur = 0  
  45. #得到下一个batch训练用的图像的index,默认一次两张图片  
  46.     def _get_next_minibatch_inds(self):  
  47.         """Return the roidb indices for the next minibatch."""  
  48. #如果所有图片都用完了,打乱顺序,roidb由每张图片的rois集合构成  
  49.         if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):  
  50.             self._shuffle_roidb_inds()  
  51. #从_cur记录的位置开始选择cfg.TRAIN.IMS_PER_BATCH张图片作为训练用  
  52.         db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]  
  53.         self._cur += cfg.TRAIN.IMS_PER_BATCH  
  54.         return db_inds  
  55. #取得训练用的blob  
  56.     def _get_next_minibatch(self):  
  57.         """Return the blobs to be used for the next minibatch. 
  58.  
  59.         If cfg.TRAIN.USE_PREFETCH is True, then blobs will be computed in a 
  60.         separate process and made available through self._blob_queue. 
  61.         """  
  62.         if cfg.TRAIN.USE_PREFETCH:  
  63.             return self._blob_queue.get()  
  64.         else:  
  65.             db_inds = self._get_next_minibatch_inds()  
  66.             minibatch_db = [self._roidb[i] for i in db_inds]  
  67. #函数在lib/roi_data_layer/minibatch.py里实现  
  68.             return get_minibatch(minibatch_db, self._num_classes)  
  69.   
  70.     def set_roidb(self, roidb):  
  71.         """Set the roidb to be used by this layer during training."""  
  72.         self._roidb = roidb  
  73.         self._shuffle_roidb_inds()  
  74.         if cfg.TRAIN.USE_PREFETCH:  
  75.             self._blob_queue = Queue(10)  
  76.             self._prefetch_process = BlobFetcher(self._blob_queue,  
  77.                                                  self._roidb,  
  78.                                                  self._num_classes)  
  79.             self._prefetch_process.start()  
  80.             # Terminate the child process when the parent exists  
  81.             def cleanup():  
  82.                 print 'Terminating BlobFetcher'  
  83.                 self._prefetch_process.terminate()  
  84.                 self._prefetch_process.join()  
  85.             import atexit  
  86.             atexit.register(cleanup)  
  87. #该层初始化时调用  
  88.     def setup(self, bottom, top):  
  89.         """Setup the RoIDataLayer."""  
  90.   
  91.         # parse the layer parameter string, which must be valid YAML  
  92.         layer_params = yaml.load(self.param_str_)  
  93.   
  94.         self._num_classes = layer_params['num_classes']  
  95.   
  96.         self._name_to_top_map = {}  
  97.   
  98.         # data blob: holds a batch of N images, each with 3 channels  
  99.         idx = 0  
  100.         top[idx].reshape(cfg.TRAIN.IMS_PER_BATCH, 3,  
  101.             max(cfg.TRAIN.SCALES), cfg.TRAIN.MAX_SIZE)  
  102.         self._name_to_top_map['data'] = idx  
  103.         idx += 1  
  104. #如果要训练RPN网,roi是gt box  
  105.         if cfg.TRAIN.HAS_RPN:  
  106.             top[idx].reshape(13)  
  107.             self._name_to_top_map['im_info'] = idx  
  108.             idx += 1  
  109.   
  110.             top[idx].reshape(14)  
  111.             self._name_to_top_map['gt_boxes'] = idx  
  112.             idx += 1  
  113. #如果是训练fast rcnn则roi是之前RPN提取的region proposal  
  114.         else# not using RPN  
  115.             # rois blob: holds R regions of interest, each is a 5-tuple  
  116.             # (n, x1, y1, x2, y2) specifying an image batch index n and a  
  117.             # rectangle (x1, y1, x2, y2)  
  118.             top[idx].reshape(15)  
  119.             self._name_to_top_map['rois'] = idx  
  120.             idx += 1  
  121.   
  122.             # labels blob: R categorical labels in [0, ..., K] for K foreground  
  123.             # classes plus background  
  124.             top[idx].reshape(1)  
  125.             self._name_to_top_map['labels'] = idx  
  126.             idx += 1  
  127.   
  128.             if cfg.TRAIN.BBOX_REG:  
  129.                 # bbox_targets blob: R bounding-box regression targets with 4  
  130.                 # targets per class  
  131.                 top[idx].reshape(1self._num_classes * 4)  
  132.                 self._name_to_top_map['bbox_targets'] = idx  
  133.                 idx += 1  
  134.   
  135.                 # bbox_inside_weights blob: At most 4 targets per roi are active;  
  136.                 # thisbinary vector sepcifies the subset of active targets  
  137.                 top[idx].reshape(1self._num_classes * 4)  
  138.                 self._name_to_top_map['bbox_inside_weights'] = idx  
  139.                 idx += 1  
  140.   
  141.                 top[idx].reshape(1self._num_classes * 4)  
  142.                 self._name_to_top_map['bbox_outside_weights'] = idx  
  143.                 idx += 1  
  144.   
  145.         print 'RoiDataLayer: name_to_top:'self._name_to_top_map  
  146.         assert len(top) == len(self._name_to_top_map)  
  147. #作为输入前向计算  
  148.     def forward(self, bottom, top):  
  149.         """Get blobs and copy them into this layer's top blob vector."""  
  150.         blobs = self._get_next_minibatch()  
  151.   
  152.         for blob_name, blob in blobs.iteritems():  
  153.             top_ind = self._name_to_top_map[blob_name]  
  154.             # Reshape net's input blobs  
  155.             top[top_ind].reshape(*(blob.shape))  
  156.             # Copy data into net's input blobs  
  157.             top[top_ind].data[...] = blob.astype(np.float32, copy=False)  
  158. #不用反向传播  
  159.     def backward(self, top, propagate_down, bottom):  
  160.         """This layer does not propagate gradients."""  
  161.         pass  
  162.   
  163.     def reshape(self, bottom, top):  
  164.         """Reshaping happens during the call to forward."""  
  165.         pass  
  166.   
  167. class BlobFetcher(Process):  
  168.     """Experimental class for prefetching blobs in a separate process."""  
  169.     def __init__(self, queue, roidb, num_classes):  
  170.         super(BlobFetcher, self).__init__()  
  171.         self._queue = queue  
  172.         self._roidb = roidb  
  173.         self._num_classes = num_classes  
  174.         self._perm = None  
  175.         self._cur = 0  
  176.         self._shuffle_roidb_inds()  
  177.         # fix the random seed for reproducibility  
  178.         np.random.seed(cfg.RNG_SEED)  
  179.   
  180.     def _shuffle_roidb_inds(self):  
  181.         """Randomly permute the training roidb."""  
  182.         # TODO(rbg): remove duplicated code  
  183.         self._perm = np.random.permutation(np.arange(len(self._roidb)))  
  184.         self._cur = 0  
  185.   
  186.     def _get_next_minibatch_inds(self):  
  187.         """Return the roidb indices for the next minibatch."""  
  188.         # TODO(rbg): remove duplicated code  
  189.         if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):  
  190.             self._shuffle_roidb_inds()  
  191.   
  192.         db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]  
  193.         self._cur += cfg.TRAIN.IMS_PER_BATCH  
  194.         return db_inds  
  195.   
  196.     def run(self):  
  197.         print 'BlobFetcher started'  
  198.         while True:  
  199.             db_inds = self._get_next_minibatch_inds()  
  200.             minibatch_db = [self._roidb[i] for i in db_inds]  
  201.             blobs = get_minibatch(minibatch_db, self._num_classes)  
  202.             self._queue.put(blobs)  




其中用到了lib/roi_data_layer/minibatch.py里的函数getminibatch


  1. #coding:utf-8  
  2. # --------------------------------------------------------  
  3. # Fast R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick  
  7. # --------------------------------------------------------  
  8.   
  9. """Compute minibatch blobs for training a Fast R-CNN network."""  
  10.   
  11. import numpy as np  
  12. import numpy.random as npr  
  13. import cv2  
  14. from fast_rcnn.config import cfg  
  15. from utils.blob import prep_im_for_blob, im_list_to_blob  
  16.   
  17.   
  18.   
  19. #采样产生训练用的rois的blob,可以直接作为caffe的输入  
  20. def get_minibatch(roidb, num_classes):  
  21.     """Given a roidb, construct a minibatch sampled from it."""  
  22.     num_images = len(roidb)  
  23. #从预设的训练尺度里随机抽样用作此次产生的batch里用的roi的尺度  
  24.     # Sample random scales to use for each image in this batch  
  25.     random_scale_inds = npr.randint(0, high=len(cfg.TRAIN.SCALES),  
  26.                                     size=num_images)  
  27. #BATCH_SIZE为一个minibatch里训练用的roi的数量  
  28.     assert(cfg.TRAIN.BATCH_SIZE % num_images == 0), \  
  29.         'num_images ({}) must divide BATCH_SIZE ({})'. \  
  30.         format(num_images, cfg.TRAIN.BATCH_SIZE)  
  31. #每张图片上应该抽样得到的roi的数量  
  32.     rois_per_image = cfg.TRAIN.BATCH_SIZE / num_images  
  33. #前景roi的数量  
  34.     fg_rois_per_image = np.round(cfg.TRAIN.FG_FRACTION * rois_per_image)  
  35. #产生caffe能用的blob  
  36.     # Get the input image blob, formatted for caffe  
  37. #_get_image_blob的实现在本文件的后面  
  38.     im_blob, im_scales = _get_image_blob(roidb, random_scale_inds)  
  39.   
  40.     blobs = {'data': im_blob}  
  41. #训练RPN时  
  42.     if cfg.TRAIN.HAS_RPN:  
  43.         assert len(im_scales) == 1"Single batch only"  
  44.         assert len(roidb) == 1"Single batch only"  
  45.         # gt boxes: (x1, y1, x2, y2, cls)  
  46. #属于前景的roi的真实类别  
  47.         gt_inds = np.where(roidb[0]['gt_classes'] != 0)[0]  
  48.         gt_boxes = np.empty((len(gt_inds), 5), dtype=np.float32)  
  49. #gt_boxes[i]类似于(x1,y1,x2,y2,cls)  
  50.         gt_boxes[:, 0:4] = roidb[0]['boxes'][gt_inds, :] * im_scales[0]  
  51.         gt_boxes[:, 4] = roidb[0]['gt_classes'][gt_inds]  
  52.         blobs['gt_boxes'] = gt_boxes  
  53.         blobs['im_info'] = np.array(  
  54.             [[im_blob.shape[2], im_blob.shape[3], im_scales[0]]],  
  55.             dtype=np.float32)  
  56. #训练fast rcnn时  
  57.     else# not using RPN  
  58.         # Now, build the region of interest and label blobs  
  59.         rois_blob = np.zeros((05), dtype=np.float32)  
  60.         labels_blob = np.zeros((0), dtype=np.float32)  
  61.         bbox_targets_blob = np.zeros((04 * num_classes), dtype=np.float32)  
  62.         bbox_inside_blob = np.zeros(bbox_targets_blob.shape, dtype=np.float32)  
  63.         # all_overlaps = []  
  64.         for im_i in xrange(num_images):  
  65. #_sample_rois实现在下面,实现从每张图片的rois里采样  
  66.             labels, overlaps, im_rois, bbox_targets, bbox_inside_weights \  
  67.                 = _sample_rois(roidb[im_i], fg_rois_per_image, rois_per_image,  
  68.                                num_classes)  
  69.   
  70.             # Add to RoIs blob  
  71.             rois = _project_im_rois(im_rois, im_scales[im_i])  
  72.             batch_ind = im_i * np.ones((rois.shape[0], 1))  
  73.             rois_blob_this_image = np.hstack((batch_ind, rois))  
  74.             rois_blob = np.vstack((rois_blob, rois_blob_this_image))  
  75.   
  76.             # Add to labels, bbox targets, and bbox loss blobs  
  77.             labels_blob = np.hstack((labels_blob, labels))  
  78.             bbox_targets_blob = np.vstack((bbox_targets_blob, bbox_targets))  
  79.             bbox_inside_blob = np.vstack((bbox_inside_blob, bbox_inside_weights))  
  80.             # all_overlaps = np.hstack((all_overlaps, overlaps))  
  81.   
  82.         # For debug visualizations  
  83.         # _vis_minibatch(im_blob, rois_blob, labels_blob, all_overlaps)  
  84.   
  85.         blobs['rois'] = rois_blob  
  86.         blobs['labels'] = labels_blob  
  87.   
  88.         if cfg.TRAIN.BBOX_REG:  
  89.             blobs['bbox_targets'] = bbox_targets_blob  
  90.             blobs['bbox_inside_weights'] = bbox_inside_blob  
  91.             blobs['bbox_outside_weights'] = \  
  92.                 np.array(bbox_inside_blob > 0).astype(np.float32)  
  93.   
  94.     return blobs  
  95.   
  96. #从一张图片的rois里采样得到roi  
  97. def _sample_rois(roidb, fg_rois_per_image, rois_per_image, num_classes):  
  98.     """Generate a random sample of RoIs comprising foreground and background 
  99.     examples. 
  100.     """  
  101.     # label = class RoI has max overlap with  
  102.     labels = roidb['max_classes']  
  103.     overlaps = roidb['max_overlaps']  
  104.     rois = roidb['boxes']  
  105.   
  106.     # Select foreground RoIs as those with >= FG_THRESH overlap  
  107.     fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]  
  108.     # Guard against the case when an image has fewer than fg_rois_per_image  
  109.     # foreground RoIs  
  110. #fg_rois_per_this_image取fg_rois_per_this_image和fg_inds.size的较小的一个  
  111.     fg_rois_per_this_image = np.minimum(fg_rois_per_image, fg_inds.size)  
  112.     # Sample foreground regions without replacement  
  113.     if fg_inds.size > 0:  
  114.         fg_inds = npr.choice(  
  115.                 fg_inds, size=fg_rois_per_this_image, replace=False)  
  116.   
  117.     # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)  
  118.     bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &  
  119.                        (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]  
  120.     # Compute number of background RoIs to take from this image (guarding  
  121.     # against there being fewer than desired)  
  122.     bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image  
  123.     bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,  
  124.                                         bg_inds.size)  
  125. #这里如果正负样本数量相差太大会出问题,此时应该做正负样本平衡,这里没有做  
  126.     # Sample foreground regions without replacement  
  127.     if bg_inds.size > 0:  
  128.         bg_inds = npr.choice(  
  129.                 bg_inds, size=bg_rois_per_this_image, replace=False)  
  130.   
  131.     # The indices that we're selecting (both fg and bg)  
  132.     keep_inds = np.append(fg_inds, bg_inds)  
  133.     # Select sampled values from various arrays:  
  134.     labels = labels[keep_inds]  
  135.     # Clamp labels for the background RoIs to 0  
  136. #设定背景roi的label为0  
  137.     labels[fg_rois_per_this_image:] = 0  
  138.     overlaps = overlaps[keep_inds]  
  139.     rois = rois[keep_inds]  
  140.   
  141.     bbox_targets, bbox_inside_weights = _get_bbox_regression_labels(  
  142.             roidb['bbox_targets'][keep_inds, :], num_classes)  
  143.   
  144.     return labels, overlaps, rois, bbox_targets, bbox_inside_weights  
  145.   
  146. def _get_image_blob(roidb, scale_inds):  
  147.     """Builds an input blob from the images in the roidb at the specified 
  148.     scales. 
  149.     """  
  150.     num_images = len(roidb)  
  151.     processed_ims = []  
  152.     im_scales = []  
  153.     for i in xrange(num_images):  
  154. #读取roi所在的图像  
  155.         im = cv2.imread(roidb[i]['image'])  
  156. #判断该roi是否是由水平翻转得到的  
  157.         if roidb[i]['flipped']:  
  158. #实现水平翻转  
  159.             im = im[:, ::-1, :]  
  160. #得到尺度  
  161.         target_size = cfg.TRAIN.SCALES[scale_inds[i]]  
  162.         im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size,  
  163.                                         cfg.TRAIN.MAX_SIZE)  
  164.         im_scales.append(im_scale)  
  165.         processed_ims.append(im)  
  166. #在lib/util/blob.py里实现  
  167.     # Create a blob to hold the input images  
  168.     blob = im_list_to_blob(processed_ims)  
  169.   
  170.     return blob, im_scales  
  171.   
  172. def _project_im_rois(im_rois, im_scale_factor):  
  173.     """Project image RoIs into the rescaled training image."""  
  174.     rois = im_rois * im_scale_factor  
  175.     return rois  
  176.   
  177. def _get_bbox_regression_labels(bbox_target_data, num_classes):  
  178.     """Bounding-box regression targets are stored in a compact form in the 
  179.     roidb. 
  180.  
  181.     This function expands those targets into the 4-of-4*K representation used 
  182.     by the network (i.e. only one class has non-zero targets). The loss weights 
  183.     are similarly expanded. 
  184.  
  185.     Returns: 
  186.         bbox_target_data (ndarray): N x 4K blob of regression targets 
  187.         bbox_inside_weights (ndarray): N x 4K blob of loss weights 
  188.     """  
  189.     clss = bbox_target_data[:, 0]  
  190.     bbox_targets = np.zeros((clss.size, 4 * num_classes), dtype=np.float32)  
  191.     bbox_inside_weights = np.zeros(bbox_targets.shape, dtype=np.float32)  
  192.     inds = np.where(clss > 0)[0]  
  193.     for ind in inds:  
  194.         cls = clss[ind]  
  195.         start = 4 * cls  
  196.         end = start + 4  
  197.         bbox_targets[ind, start:end] = bbox_target_data[ind, 1:]  
  198.         bbox_inside_weights[ind, start:end] = cfg.TRAIN.BBOX_INSIDE_WEIGHTS  
  199.     return bbox_targets, bbox_inside_weights  
  200.   
  201. def _vis_minibatch(im_blob, rois_blob, labels_blob, overlaps):  
  202.     """Visualize a mini-batch for debugging."""  
  203.     import matplotlib.pyplot as plt  
  204.     for i in xrange(rois_blob.shape[0]):  
  205.         rois = rois_blob[i, :]  
  206.         im_ind = rois[0]  
  207.         roi = rois[1:]  
  208.         im = im_blob[im_ind, :, :, :].transpose((120)).copy()  
  209.         im += cfg.PIXEL_MEANS  
  210.         im = im[:, :, (210)]  
  211.         im = im.astype(np.uint8)  
  212.         cls = labels_blob[i]  
  213.         plt.imshow(im)  
  214.         print 'class: 'cls' overlap: ', overlaps[i]  
  215.         plt.gca().add_patch(  
  216.             plt.Rectangle((roi[0], roi[1]), roi[2] - roi[0],  
  217.                           roi[3] - roi[1], fill=False,  
  218.                           edgecolor='r', linewidth=3)  
  219.             )  
  220.         plt.show()  



lib/utils/bolb.py

  1. # --------------------------------------------------------  
  2. # Fast R-CNN  
  3. # Copyright (c) 2015 Microsoft  
  4. # Licensed under The MIT License [see LICENSE for details]  
  5. # Written by Ross Girshick  
  6. # --------------------------------------------------------  
  7.   
  8. """Blob helper functions."""  
  9.   
  10. import numpy as np  
  11. import cv2  
  12.   
  13. def im_list_to_blob(ims):  
  14.     """Convert a list of images into a network input. 
  15.  
  16.     Assumes images are already prepared (means subtracted, BGR order, ...). 
  17.     """  
  18.     max_shape = np.array([im.shape for im in ims]).max(axis=0)  
  19.     num_images = len(ims)  
  20.     blob = np.zeros((num_images, max_shape[0], max_shape[1], 3),  
  21.                     dtype=np.float32)  
  22.     for i in xrange(num_images):  
  23.         im = ims[i]  
  24.         blob[i, 0:im.shape[0], 0:im.shape[1], :] = im  
  25.     # Move channels (axis 3) to axis 1  
  26.     # Axis order will become: (batch elem, channel, height, width)  
  27.     channel_swap = (0312)  
  28.     blob = blob.transpose(channel_swap)  
  29.     return blob  
  30.   
  31. def prep_im_for_blob(im, pixel_means, target_size, max_size):  
  32.     """Mean subtract and scale an image for use in a blob."""  
  33.     im = im.astype(np.float32, copy=False)  
  34.     im -= pixel_means  
  35.     im_shape = im.shape  
  36.     im_size_min = np.min(im_shape[0:2])  
  37.     im_size_max = np.max(im_shape[0:2])  
  38.     im_scale = float(target_size) / float(im_size_min)  
  39.     # Prevent the biggest axis from being more than MAX_SIZE  
  40.     if np.round(im_scale * im_size_max) > max_size:  
  41.         im_scale = float(max_size) / float(im_size_max)  
  42.     im = cv2.resize(im, NoneNone, fx=im_scale, fy=im_scale,  
  43.                     interpolation=cv2.INTER_LINEAR)  
  44.   
  45.     return im, im_scale  


lib/rpn/anchor_target_layer.py



  1. #coding:utf-8  
  2. # --------------------------------------------------------  
  3. # Faster R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick and Sean Bell  
  7. # --------------------------------------------------------  
  8.   
  9. import os  
  10. import caffe  
  11. import yaml  
  12. from fast_rcnn.config import cfg  
  13. import numpy as np  
  14. import numpy.random as npr  
  15. from generate_anchors import generate_anchors  
  16. from utils.cython_bbox import bbox_overlaps  
  17. from fast_rcnn.bbox_transform import bbox_transform  
  18.   
  19. DEBUG = False  
  20.   
  21. class AnchorTargetLayer(caffe.Layer):  
  22.     """ 
  23.     Assign anchors to ground-truth targets. Produces anchor classification 
  24.     labels and bounding-box regression targets. 
  25.     """  
  26.   
  27.     def setup(self, bottom, top):  
  28.         layer_params = yaml.load(self.param_str_)  
  29. #设定anchor的三个尺度  
  30.         anchor_scales = layer_params.get('scales', (81632))  
  31. #以(8.5,8.5)为中心产生9个基准anchor  
  32.         self._anchors = generate_anchors(scales=np.array(anchor_scales))  
  33.         self._num_anchors = self._anchors.shape[0]  
  34. #其余的anchor以feat_stride为步长上下滑动产生,config.py里feat_stride设为16,为什么是16,  
  35. #因为不管是VGG还是ZF,conv5之后的scale是原图的1/16,这样产生的achor基本均匀分布在整个原图  
  36.         self._feat_stride = layer_params['feat_stride']  
  37.   
  38.         if DEBUG:  
  39.             print 'anchors:'  
  40.             print self._anchors  
  41.             print 'anchor shapes:'  
  42.             print np.hstack((  
  43.                 self._anchors[:, 2::4] - self._anchors[:, 0::4],  
  44.                 self._anchors[:, 3::4] - self._anchors[:, 1::4],  
  45.             ))  
  46.             self._counts = cfg.EPS  
  47.             self._sums = np.zeros((14))  
  48.             self._squared_sums = np.zeros((14))  
  49.             self._fg_sum = 0  
  50.             self._bg_sum = 0  
  51.             self._count = 0  
  52.   
  53.         # allow boxes to sit over the edge by a small amount  
  54.         self._allowed_border = layer_params.get('allowed_border'0)  
  55. #获得featuremap的宽高  
  56.         height, width = bottom[0].data.shape[-2:]  
  57.         if DEBUG:  
  58.             print 'AnchorTargetLayer: height', height, 'width', width  
  59.   
  60.         A = self._num_anchors  
  61.         # labels  
  62.         top[0].reshape(11, A * height, width)  
  63.         # bbox_targets  
  64.         top[1].reshape(1, A * 4, height, width)  
  65.         # bbox_inside_weights  
  66.         top[2].reshape(1, A * 4, height, width)  
  67.         # bbox_outside_weights  
  68.         top[3].reshape(1, A * 4, height, width)  
  69.   
  70.     def forward(self, bottom, top):  
  71.         # Algorithm:  
  72.         #  
  73.         # for each (H, W) location i  
  74.         #   generate 9 anchor boxes centered on cell i  
  75.         #   apply predicted bbox deltas at cell i to each of the 9 anchors  
  76.         # filter out-of-image anchors  
  77.         # measure GT overlap  
  78.   
  79.         assert bottom[0].data.shape[0] == 1, \  
  80.             'Only single item batches are supported'  
  81.   
  82.         # map of shape (..., H, W)  
  83.         height, width = bottom[0].data.shape[-2:]  
  84.         # GT boxes (x1, y1, x2, y2, label)  
  85.         gt_boxes = bottom[1].data  
  86.         # im_info  
  87.         im_info = bottom[2].data[0, :]  
  88.   
  89.         if DEBUG:  
  90.             print ''  
  91.             print 'im_size: ({}, {})'.format(im_info[0], im_info[1])  
  92.             print 'scale: {}'.format(im_info[2])  
  93.             print 'height, width: ({}, {})'.format(height, width)  
  94.             print 'rpn: gt_boxes.shape', gt_boxes.shape  
  95.             print 'rpn: gt_boxes', gt_boxes  
  96.   
  97.         # 1. Generate proposals from bbox deltas and shifted anchors  
  98.         shift_x = np.arange(0, width) * self._feat_stride  
  99.         shift_y = np.arange(0, height) * self._feat_stride  
  100.         shift_x, shift_y = np.meshgrid(shift_x, shift_y)  
  101.         shifts = np.vstack((shift_x.ravel(), shift_y.ravel(),  
  102.                             shift_x.ravel(), shift_y.ravel())).transpose()  
  103.         # add A anchors (1, A, 4) to  
  104.         # cell K shifts (K, 1, 4) to get  
  105.         # shift anchors (K, A, 4)  
  106.         # reshape to (K*A, 4) shifted anchors  
  107.         A = self._num_anchors  
  108.         K = shifts.shape[0]  
  109.         all_anchors = (self._anchors.reshape((1, A, 4)) +  
  110.                        shifts.reshape((1, K, 4)).transpose((102)))  
  111.         all_anchors = all_anchors.reshape((K * A, 4))  
  112.         total_anchors = int(K * A)  
  113.   
  114.         # only keep anchors inside the image  
  115.         inds_inside = np.where(  
  116.             (all_anchors[:, 0] >= -self._allowed_border) &  
  117.             (all_anchors[:, 1] >= -self._allowed_border) &  
  118.             (all_anchors[:, 2] < im_info[1] + self._allowed_border) &  # width  
  119.             (all_anchors[:, 3] < im_info[0] + self._allowed_border)    # height  
  120.         )[0]  
  121.   
  122.         if DEBUG:  
  123.             print 'total_anchors', total_anchors  
  124.             print 'inds_inside', len(inds_inside)  
  125. #裁掉大小超出图片的anchor,inds_inside是在图像内部的anchor的索引数组  
  126.         # keep only inside anchors  
  127.         anchors = all_anchors[inds_inside, :]  
  128.         if DEBUG:  
  129.             print 'anchors.shape', anchors.shape  
  130.   
  131.         # label: 1 is positive, 0 is negative, -1 is dont care  
  132.         labels = np.empty((len(inds_inside), ), dtype=np.float32)  
  133.         labels.fill(-1)  
  134.   
  135.         # overlaps between the anchors and the gt boxes  
  136.         # overlaps (ex, gt)  
  137.         overlaps = bbox_overlaps(  
  138.             np.ascontiguousarray(anchors, dtype=np.float),  
  139.             np.ascontiguousarray(gt_boxes, dtype=np.float))  
  140.         argmax_overlaps = overlaps.argmax(axis=1)  
  141.         max_overlaps = overlaps[np.arange(len(inds_inside)), argmax_overlaps]  
  142.         gt_argmax_overlaps = overlaps.argmax(axis=0)  
  143.         gt_max_overlaps = overlaps[gt_argmax_overlaps,  
  144.                                    np.arange(overlaps.shape[1])]  
  145.         gt_argmax_overlaps = np.where(overlaps == gt_max_overlaps)[0]  
  146.   
  147.         if not cfg.TRAIN.RPN_CLOBBER_POSITIVES:  
  148.             # assign bg labels first so that positive labels can clobber them  
  149.             labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0  
  150.   
  151.         # fg label: for each gt, anchor with highest overlap  
  152.         labels[gt_argmax_overlaps] = 1  
  153.   
  154.         # fg label: above threshold IOU  
  155.         labels[max_overlaps >= cfg.TRAIN.RPN_POSITIVE_OVERLAP] = 1  
  156.   
  157.         if cfg.TRAIN.RPN_CLOBBER_POSITIVES:  
  158.             # assign bg labels last so that negative labels can clobber positives  
  159.             labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0  
  160. #采样正负anchor,如果正负样本数量不均衡,需要保持正负样本的比例基本为1:1,太悬殊  
  161. #会使得算法漏检严重,下面的算法没有实现保持正负样本均衡  
  162.         # subsample positive labels if we have too many  
  163.         num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCHSIZE)  
  164.         fg_inds = np.where(labels == 1)[0]  
  165.         if len(fg_inds) > num_fg:  
  166.             disable_inds = npr.choice(  
  167.                 fg_inds, size=(len(fg_inds) - num_fg), replace=False)  
  168.             labels[disable_inds] = -1  
  169.   
  170.         # subsample negative labels if we have too many  
  171.         num_bg = cfg.TRAIN.RPN_BATCHSIZE - np.sum(labels == 1)  
  172.         bg_inds = np.where(labels == 0)[0]  
  173.         if len(bg_inds) > num_bg:  
  174.             disable_inds = npr.choice(  
  175.                 bg_inds, size=(len(bg_inds) - num_bg), replace=False)  
  176.             labels[disable_inds] = -1  
  177.             #print "was %s inds, disabling %s, now %s inds" % (  
  178.                 #len(bg_inds), len(disable_inds), np.sum(labels == 0))  
  179.   
  180.         bbox_targets = np.zeros((len(inds_inside), 4), dtype=np.float32)  
  181.         bbox_targets = _compute_targets(anchors, gt_boxes[argmax_overlaps, :])  
  182.   
  183.         bbox_inside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32)  
  184.         bbox_inside_weights[labels == 1, :] = np.array(cfg.TRAIN.RPN_BBOX_INSIDE_WEIGHTS)  
  185.   
  186.         bbox_outside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32)  
  187.         if cfg.TRAIN.RPN_POSITIVE_WEIGHT < 0:  
  188.             # uniform weighting of examples (given non-uniform sampling)  
  189.             num_examples = np.sum(labels >= 0)  
  190.             positive_weights = np.ones((14)) * 1.0 / num_examples  
  191.             negative_weights = np.ones((14)) * 1.0 / num_examples  
  192.         else:  
  193.             assert ((cfg.TRAIN.RPN_POSITIVE_WEIGHT > 0) &  
  194.                     (cfg.TRAIN.RPN_POSITIVE_WEIGHT < 1))  
  195.             positive_weights = (cfg.TRAIN.RPN_POSITIVE_WEIGHT /  
  196.                                 np.sum(labels == 1))  
  197.             negative_weights = ((1.0 - cfg.TRAIN.RPN_POSITIVE_WEIGHT) /  
  198.                                 np.sum(labels == 0))  
  199.         bbox_outside_weights[labels == 1, :] = positive_weights  
  200.         bbox_outside_weights[labels == 0, :] = negative_weights  
  201.   
  202.         if DEBUG:  
  203.             self._sums += bbox_targets[labels == 1, :].sum(axis=0)  
  204.             self._squared_sums += (bbox_targets[labels == 1, :] ** 2).sum(axis=0)  
  205.             self._counts += np.sum(labels == 1)  
  206.             means = self._sums / self._counts  
  207.             stds = np.sqrt(self._squared_sums / self._counts - means ** 2)  
  208.             print 'means:'  
  209.             print means  
  210.             print 'stdevs:'  
  211.             print stds  
  212.   
  213.         # map up to original set of anchors  
  214.         labels = _unmap(labels, total_anchors, inds_inside, fill=-1)  
  215.         bbox_targets = _unmap(bbox_targets, total_anchors, inds_inside, fill=0)  
  216.         bbox_inside_weights = _unmap(bbox_inside_weights, total_anchors, inds_inside, fill=0)  
  217.         bbox_outside_weights = _unmap(bbox_outside_weights, total_anchors, inds_inside, fill=0)  
  218.   
  219.         if DEBUG:  
  220.             print 'rpn: max max_overlap', np.max(max_overlaps)  
  221.             print 'rpn: num_positive', np.sum(labels == 1)  
  222.             print 'rpn: num_negative', np.sum(labels == 0)  
  223.             self._fg_sum += np.sum(labels == 1)  
  224.             self._bg_sum += np.sum(labels == 0)  
  225.             self._count += 1  
  226.             print 'rpn: num_positive avg'self._fg_sum / self._count  
  227.             print 'rpn: num_negative avg'self._bg_sum / self._count  
  228.   
  229.         # labels  
  230.         labels = labels.reshape((1, height, width, A)).transpose(0312)  
  231.         labels = labels.reshape((11, A * height, width))  
  232.         top[0].reshape(*labels.shape)  
  233.         top[0].data[...] = labels  
  234.   
  235.         # bbox_targets  
  236.         bbox_targets = bbox_targets \  
  237.             .reshape((1, height, width, A * 4)).transpose(0312)  
  238.         top[1].reshape(*bbox_targets.shape)  
  239.         top[1].data[...] = bbox_targets  
  240.   
  241.         # bbox_inside_weights  
  242.         bbox_inside_weights = bbox_inside_weights \  
  243.             .reshape((1, height, width, A * 4)).transpose(0312)  
  244.         assert bbox_inside_weights.shape[2] == height  
  245.         assert bbox_inside_weights.shape[3] == width  
  246.         top[2].reshape(*bbox_inside_weights.shape)  
  247.         top[2].data[...] = bbox_inside_weights  
  248.   
  249.         # bbox_outside_weights  
  250.         bbox_outside_weights = bbox_outside_weights \  
  251.             .reshape((1, height, width, A * 4)).transpose(0312)  
  252.         assert bbox_outside_weights.shape[2] == height  
  253.         assert bbox_outside_weights.shape[3] == width  
  254.         top[3].reshape(*bbox_outside_weights.shape)  
  255.         top[3].data[...] = bbox_outside_weights  
  256.   
  257.     def backward(self, top, propagate_down, bottom):  
  258.         """This layer does not propagate gradients."""  
  259.         pass  
  260.   
  261.     def reshape(self, bottom, top):  
  262.         """Reshaping happens during the call to forward."""  
  263.         pass  
  264.   
  265.   
  266. def _unmap(data, count, inds, fill=0):  
  267.     """ Unmap a subset of item (data) back to the original set of items (of 
  268.     size count) """  
  269.     if len(data.shape) == 1:  
  270.         ret = np.empty((count, ), dtype=np.float32)  
  271.         ret.fill(fill)  
  272.         ret[inds] = data  
  273.     else:  
  274.         ret = np.empty((count, ) + data.shape[1:], dtype=np.float32)  
  275.         ret.fill(fill)  
  276.         ret[inds, :] = data  
  277.     return ret  
  278.   
  279.   
  280. def _compute_targets(ex_rois, gt_rois):  
  281.     """Compute bounding-box regression targets for an image."""  
  282.   
  283.     assert ex_rois.shape[0] == gt_rois.shape[0]  
  284.     assert ex_rois.shape[1] == 4  
  285.     assert gt_rois.shape[1] == 5  
  286.   
  287.     return bbox_transform(ex_rois, gt_rois[:, :4]).astype(np.float32, copy=False)  




用到了lib/rpn/generate_anchors.py里的函数

  1. #coding:utf-8  
  2. # --------------------------------------------------------  
  3. # Faster R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick and Sean Bell  
  7. # --------------------------------------------------------  
  8.   
  9.   
  10. import numpy as np  
  11.   
  12.   
  13. #下面是产生的9个anchor的坐标,每个box为(xmin,ymin,xmax,ymax),每个box的中心都是(8.5,8.5),所以会有负值  
  14. # Verify that we compute the same anchors as Shaoqing's matlab implementation:  
  15. #  
  16. #    >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat  
  17. #    >> anchors  
  18. #  
  19. #    anchors =  
  20. #  
  21. #       -83   -39   100    56  
  22. #      -175   -87   192   104  
  23. #      -359  -183   376   200  
  24. #       -55   -55    72    72  
  25. #      -119  -119   136   136  
  26. #      -247  -247   264   264  
  27. #       -35   -79    52    96  
  28. #       -79  -167    96   184  
  29. #      -167  -343   184   360  
  30.   
  31. #array([[ -83.,  -39.,  100.,   56.],  
  32. #       [-175.,  -87.,  192.,  104.],  
  33. #       [-359., -183.,  376.,  200.],  
  34. #       [ -55.,  -55.,   72.,   72.],  
  35. #       [-119., -119.,  136.,  136.],  
  36. #       [-247., -247.,  264.,  264.],  
  37. #       [ -35.,  -79.,   52.,   96.],  
  38. #       [ -79., -167.,   96.,  184.],  
  39. #       [-167., -343.,  184.,  360.]])  
  40.   
  41. def generate_anchors(base_size=16, ratios=[0.512],  
  42.                      scales=2**np.arange(36)):  
  43.     """ 
  44.     Generate anchor (reference) windows by enumerating aspect ratios X 
  45.     scales wrt a reference (0, 0, 15, 15) window. 
  46.     """  
  47. #base_anchor的大小为(0,0,15,15),其他anchor在此基础上变换产生  
  48.     base_anchor = np.array([11, base_size, base_size]) - 1  
  49. #产生不同长宽比的anchor,面积一样,中心一样  
  50.     ratio_anchors = _ratio_enum(base_anchor, ratios)  
  51.     anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)  
  52.                          for i in xrange(ratio_anchors.shape[0])])  
  53.     return anchors  
  54.   
  55. def _whctrs(anchor):  
  56.     """ 
  57.     Return width, height, x center, and y center for an anchor (window). 
  58.     """  
  59.   
  60.     w = anchor[2] - anchor[0] + 1  
  61.     h = anchor[3] - anchor[1] + 1  
  62.     x_ctr = anchor[0] + 0.5 * (w - 1)  
  63.     y_ctr = anchor[1] + 0.5 * (h - 1)  
  64.     return w, h, x_ctr, y_ctr  
  65.   
  66. def _mkanchors(ws, hs, x_ctr, y_ctr):  
  67.     """ 
  68.     Given a vector of widths (ws) and heights (hs) around a center 
  69.     (x_ctr, y_ctr), output a set of anchors (windows). 
  70.     """  
  71.   
  72.     ws = ws[:, np.newaxis]  
  73.     hs = hs[:, np.newaxis]  
  74.     anchors = np.hstack((x_ctr - 0.5 * (ws - 1),  
  75.                          y_ctr - 0.5 * (hs - 1),  
  76.                          x_ctr + 0.5 * (ws - 1),  
  77.                          y_ctr + 0.5 * (hs - 1)))  
  78.     return anchors  
  79.   
  80. def _ratio_enum(anchor, ratios):  
  81.     """ 
  82.     Enumerate a set of anchors for each aspect ratio wrt an anchor. 
  83.     """  
  84.   
  85.     w, h, x_ctr, y_ctr = _whctrs(anchor)  
  86.     size = w * h  
  87.     size_ratios = size / ratios  
  88.     ws = np.round(np.sqrt(size_ratios))  
  89.     hs = np.round(ws * ratios)  
  90.     anchors = _mkanchors(ws, hs, x_ctr, y_ctr)  
  91.     return anchors  
  92.   
  93. #产生不同面积大小的anchor,长宽比不变,长宽均变为原来的scale倍  
  94. def _scale_enum(anchor, scales):  
  95.     """ 
  96.     Enumerate a set of anchors for each scale wrt an anchor. 
  97.     """  
  98.   
  99.     w, h, x_ctr, y_ctr = _whctrs(anchor)  
  100.     ws = w * scales  
  101.     hs = h * scales  
  102.     anchors = _mkanchors(ws, hs, x_ctr, y_ctr)  
  103.     return anchors  
  104.   
  105. if __name__ == '__main__':  
  106.     import time  
  107.     t = time.time()  
  108.     a = generate_anchors()  
  109.     print time.time() - t  
  110.     print a  
  111.     from IPython import embed; embed()  



rpn_test.pt

  1. #用RPN产生region proposal时的网络结构,这个网络只用前向计算  
  2. name: "ZF"  
  3.   
  4. input: "data"  
  5. input_shape {  
  6.   dim: 1  
  7.   dim: 3  
  8.   dim: 224  
  9.   dim: 224  
  10. }  
  11.   
  12. input: "im_info"  
  13. input_shape {  
  14.   dim: 1  
  15.   dim: 3  
  16. }  
  17. #前面是ZF网,特征提取用,共享  
  18. # ------------------------ layer 1 -----------------------------  
  19. layer {  
  20.     name: "conv1"  
  21.     type: "Convolution"  
  22.     bottom: "data"  
  23.     top: "conv1"  
  24.     convolution_param {  
  25.         num_output: 96  
  26.         kernel_size: 7  
  27.         pad: 3  
  28.         stride: 2  
  29.     }  
  30. }  
  31. layer {  
  32.     name: "relu1"  
  33.     type: "ReLU"  
  34.     bottom: "conv1"  
  35.     top: "conv1"  
  36. }  
  37. layer {  
  38.     name: "norm1"  
  39.     type: "LRN"  
  40.     bottom: "conv1"  
  41.     top: "norm1"  
  42.     lrn_param {  
  43.         local_size: 3  
  44.         alpha: 0.00005  
  45.         beta: 0.75  
  46.         norm_region: WITHIN_CHANNEL  
  47.     engine: CAFFE  
  48.     }  
  49. }  
  50. layer {  
  51.     name: "pool1"  
  52.     type: "Pooling"  
  53.     bottom: "norm1"  
  54.     top: "pool1"  
  55.     pooling_param {  
  56.         kernel_size: 3  
  57.         stride: 2  
  58.         pad: 1  
  59.         pool: MAX  
  60.     }  
  61. }  
  62. layer {  
  63.     name: "conv2"  
  64.     type: "Convolution"  
  65.     bottom: "pool1"  
  66.     top: "conv2"  
  67.     convolution_param {  
  68.         num_output: 256  
  69.         kernel_size: 5  
  70.         pad: 2  
  71.         stride: 2  
  72.     }  
  73. }  
  74. layer {  
  75.     name: "relu2"  
  76.     type: "ReLU"  
  77.     bottom: "conv2"  
  78.     top: "conv2"  
  79. }  
  80.   
  81. layer {  
  82.     name: "norm2"  
  83.     type: "LRN"  
  84.     bottom: "conv2"  
  85.     top: "norm2"  
  86.     lrn_param {  
  87.         local_size: 3  
  88.         alpha: 0.00005  
  89.         beta: 0.75  
  90.         norm_region: WITHIN_CHANNEL  
  91.     engine: CAFFE  
  92.     }  
  93. }  
  94. layer {  
  95.     name: "pool2"  
  96.     type: "Pooling"  
  97.     bottom: "norm2"  
  98.     top: "pool2"  
  99.     pooling_param {  
  100.         kernel_size: 3  
  101.         stride: 2  
  102.         pad: 1  
  103.         pool: MAX  
  104.     }  
  105. }  
  106. layer {  
  107.     name: "conv3"  
  108.     type: "Convolution"  
  109.     bottom: "pool2"  
  110.     top: "conv3"  
  111.     convolution_param {  
  112.         num_output: 384  
  113.         kernel_size: 3  
  114.         pad: 1  
  115.         stride: 1  
  116.     }  
  117. }  
  118. layer {  
  119.     name: "relu3"  
  120.     type: "ReLU"  
  121.     bottom: "conv3"  
  122.     top: "conv3"  
  123. }  
  124. layer {  
  125.     name: "conv4"  
  126.     type: "Convolution"  
  127.     bottom: "conv3"  
  128.     top: "conv4"  
  129.     convolution_param {  
  130.         num_output: 384  
  131.         kernel_size: 3  
  132.         pad: 1  
  133.         stride: 1  
  134.     }  
  135. }  
  136. layer {  
  137.     name: "relu4"  
  138.     type: "ReLU"  
  139.     bottom: "conv4"  
  140.     top: "conv4"  
  141. }  
  142. layer {  
  143.     name: "conv5"  
  144.     type: "Convolution"  
  145.     bottom: "conv4"  
  146.     top: "conv5"  
  147.     convolution_param {  
  148.         num_output: 256#经过最后一层,产生256个特征图  
  149.         kernel_size: 3  
  150.         pad: 1  
  151.         stride: 1  
  152.     }  
  153. }  
  154. layer {  
  155.     name: "relu5"  
  156.     type: "ReLU"  
  157.     bottom: "conv5"  
  158.     top: "conv5"  
  159. }  
  160.   
  161. #-----------------------layer +-------------------------  
  162. #RPN在conv5上滑动窗口,256*3*3*256卷积核,预测每个位置9个anchor是否属于前景,  
  163. #如果属于前景,box的修正位置  
  164. layer {  
  165.   name: "rpn_conv1"  
  166.   type: "Convolution"  
  167.   bottom: "conv5"  
  168.   top: "rpn_conv1"  
  169.   convolution_param {  
  170.     num_output: 256  
  171.     kernel_size: 3 pad: 1 stride: 1  
  172.   }  
  173. }  
  174. layer {  
  175.   name: "rpn_relu1"  
  176.   type: "ReLU"  
  177.   bottom: "rpn_conv1"  
  178.   top: "rpn_conv1"  
  179. }  
  180. layer {  
  181.   name: "rpn_cls_score"  
  182.   type: "Convolution"  
  183.   bottom: "rpn_conv1"  
  184.   top: "rpn_cls_score"  
  185.   convolution_param {  
  186.     num_output: 18   # 2(bg/fg) * 9(anchors)#输出预测每个位置9个anchor,属于bg或fg  
  187.     kernel_size: 1 pad: 0 stride: 1  
  188.   }  
  189. }  
  190. layer {  
  191.   name: "rpn_bbox_pred"  
  192.   type: "Convolution"  
  193.   bottom: "rpn_conv1"  
  194.   top: "rpn_bbox_pred"  
  195.   convolution_param {  
  196.     num_output: 36   # 4 * 9(anchors)#输出预测9个anchor的修正坐标  
  197.     kernel_size: 1 pad: 0 stride: 1  
  198.   }  
  199. }  
  200. layer {  
  201.    bottom: "rpn_cls_score"  
  202.    top: "rpn_cls_score_reshape"  
  203.    name: "rpn_cls_score_reshape"  
  204.    type: "Reshape"  
  205.    reshape_param { shape { dim: 0 dim: 2 dim: -1 dim: 0 } }  
  206. }  
  207.   
  208. #-----------------------output------------------------  
  209. layer {  
  210.   name: "rpn_cls_prob"  
  211.   type: "Softmax"  
  212.   bottom: "rpn_cls_score_reshape"  
  213.   top: "rpn_cls_prob"  
  214. }  
  215. layer {  
  216.   name: 'rpn_cls_prob_reshape'  
  217.   type: 'Reshape'  
  218.   bottom: 'rpn_cls_prob'  
  219.   top: 'rpn_cls_prob_reshape'  
  220.   reshape_param { shape { dim: 0 dim: 18 dim: -1 dim: 0 } }  
  221. }  
  222. layer {  
  223.   name: 'proposal'  
  224.   type: 'Python'  
  225.   bottom: 'rpn_cls_prob_reshape'  
  226.   bottom: 'rpn_bbox_pred'  
  227.   bottom: 'im_info'  
  228.   top: 'rois'  
  229.   top: 'scores'  
  230.   python_param {  
  231.     module: 'rpn.proposal_layer'#对应lib/rpn/proposal_layer.py  
  232.     layer: 'ProposalLayer'  
  233.     param_str: "'feat_stride': 16"  
  234.   }  
  235. }  


lib/rpn/proposal_layer.py

这一层用来由RPN产生region proposal


  1. #coding:utf-8  
  2. # --------------------------------------------------------  
  3. # Faster R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick and Sean Bell  
  7. # --------------------------------------------------------  
  8.   
  9. import caffe  
  10. import numpy as np  
  11. import yaml  
  12. from fast_rcnn.config import cfg  
  13. from generate_anchors import generate_anchors  
  14. from fast_rcnn.bbox_transform import bbox_transform_inv, clip_boxes  
  15. from fast_rcnn.nms_wrapper import nms  
  16.   
  17. DEBUG = False  
  18.   
  19. class ProposalLayer(caffe.Layer):  
  20.     """ 
  21.     Outputs object detection proposals by applying estimated bounding-box 
  22.     transformations to a set of regular boxes (called "anchors"). 
  23.     """  
  24.   
  25.     def setup(self, bottom, top):  
  26.         # parse the layer parameter string, which must be valid YAML  
  27.         layer_params = yaml.load(self.param_str_)  
  28. #16,提取特征后的feature map的大小是原来的1/16  
  29.         self._feat_stride = layer_params['feat_stride']  
  30.         anchor_scales = layer_params.get('scales', (81632))  
  31. #产生anchors  
  32.         self._anchors = generate_anchors(scales=np.array(anchor_scales))  
  33.         self._num_anchors = self._anchors.shape[0]  
  34.   
  35.         if DEBUG:  
  36.             print 'feat_stride: {}'.format(self._feat_stride)  
  37.             print 'anchors:'  
  38.             print self._anchors  
  39.   
  40.         # rois blob: holds R regions of interest, each is a 5-tuple  
  41.         # (n, x1, y1, x2, y2) specifying an image batch index n and a  
  42.         # rectangle (x1, y1, x2, y2)  
  43.         top[0].reshape(15)  
  44.   
  45.         # scores blob: holds scores for R regions of interest  
  46.         if len(top) > 1:  
  47.             top[1].reshape(1111)  
  48. #英文解释得很清楚  
  49.     def forward(self, bottom, top):  
  50.         # Algorithm:  
  51.         #  
  52.         # for each (H, W) location i  
  53.         #1.generate A anchor boxes centered on cell i  
  54.         #2.apply predicted bbox deltas at cell i to each of the A anchors  
  55.         #3.clip predicted boxes to image  
  56.         #4.remove predicted boxes with either height or width < threshold  
  57.         #5.sort all (proposal, score) pairs by score from highest to lowest  
  58.         #6.take top pre_nms_topN proposals before NMS  
  59.         #7.apply NMS with threshold 0.7 to remaining proposals  
  60.         #8.take after_nms_topN proposals after NMS  
  61.         #9.return the top proposals (-> RoIs top, scores top)  
  62.   
  63.         assert bottom[0].data.shape[0] == 1, \  
  64.             'Only single item batches are supported'  
  65.   
  66.         cfg_key = str(self.phase) # either 'TRAIN' or 'TEST'  
  67.         pre_nms_topN  = cfg[cfg_key].RPN_PRE_NMS_TOP_N  
  68.         post_nms_topN = cfg[cfg_key].RPN_POST_NMS_TOP_N  
  69.         nms_thresh    = cfg[cfg_key].RPN_NMS_THRESH  
  70.         min_size      = cfg[cfg_key].RPN_MIN_SIZE  
  71.   
  72.         # the first set of _num_anchors channels are bg probs  
  73.         # the second set are the fg probs, which we want  
  74.         scores = bottom[0].data[:, self._num_anchors:, :, :]  
  75.         bbox_deltas = bottom[1].data  
  76.         im_info = bottom[2].data[0, :]  
  77.   
  78.         if DEBUG:  
  79.             print 'im_size: ({}, {})'.format(im_info[0], im_info[1])  
  80.             print 'scale: {}'.format(im_info[2])  
  81.   
  82.         # 1. Generate proposals from bbox deltas and shifted anchors  
  83.         height, width = scores.shape[-2:]  
  84.   
  85.         if DEBUG:  
  86.             print 'score map size: {}'.format(scores.shape)  
  87.   
  88.         # Enumerate all shifts  
  89.         shift_x = np.arange(0, width) * self._feat_stride  
  90.         shift_y = np.arange(0, height) * self._feat_stride  
  91.         shift_x, shift_y = np.meshgrid(shift_x, shift_y)  
  92.         shifts = np.vstack((shift_x.ravel(), shift_y.ravel(),  
  93.                             shift_x.ravel(), shift_y.ravel())).transpose()  
  94.   
  95.         # Enumerate all shifted anchors:  
  96.         #  
  97.         # add A anchors (1, A, 4) to  
  98.         # cell K shifts (K, 1, 4) to get  
  99.         # shift anchors (K, A, 4)  
  100.         # reshape to (K*A, 4) shifted anchors  
  101.         A = self._num_anchors  
  102.         K = shifts.shape[0]  
  103.         anchors = self._anchors.reshape((1, A, 4)) + \  
  104.                   shifts.reshape((1, K, 4)).transpose((102))  
  105.         anchors = anchors.reshape((K * A, 4))  
  106.   
  107.         # Transpose and reshape predicted bbox transformations to get them  
  108.         # into the same order as the anchors:  
  109.         #  
  110.         # bbox deltas will be (1, 4 * A, H, W) format  
  111.         # transpose to (1, H, W, 4 * A)  
  112.         # reshape to (1 * H * W * A, 4) where rows are ordered by (h, w, a)  
  113.         # in slowest to fastest order  
  114.         bbox_deltas = bbox_deltas.transpose((0231)).reshape((-14))  
  115.   
  116.         # Same story for the scores:  
  117.         #  
  118.         # scores are (1, A, H, W) format  
  119.         # transpose to (1, H, W, A)  
  120.         # reshape to (1 * H * W * A, 1) where rows are ordered by (h, w, a)  
  121.         scores = scores.transpose((0231)).reshape((-11))  
  122.   
  123.         # Convert anchors into proposals via bbox transformations  
  124.         proposals = bbox_transform_inv(anchors, bbox_deltas)  
  125.   
  126.         # 2. clip predicted boxes to image  
  127.         proposals = clip_boxes(proposals, im_info[:2])  
  128. #过滤掉width或height小于RPN_MIN_SIZE的proposal  
  129.         # 3. remove predicted boxes with either height or width < threshold  
  130.         # (NOTE: convert min_size to input image scale stored in im_info[2])  
  131.         keep = _filter_boxes(proposals, min_size * im_info[2])  
  132.         proposals = proposals[keep, :]  
  133.         scores = scores[keep]  
  134.   
  135.         # 4. sort all (proposal, score) pairs by score from highest to lowest  
  136.         # 5. take top pre_nms_topN (e.g. 6000)  
  137.         order = scores.ravel().argsort()[::-1]  
  138.         if pre_nms_topN > 0:  
  139.             order = order[:pre_nms_topN]  
  140.         proposals = proposals[order, :]  
  141.         scores = scores[order]  
  142.   
  143.         # 6. apply nms (e.g. threshold = 0.7)  
  144.         # 7. take after_nms_topN (e.g. 300)  
  145.         # 8. return the top proposals (-> RoIs top)  
  146.         keep = nms(np.hstack((proposals, scores)), nms_thresh)  
  147.         if post_nms_topN > 0:  
  148.             keep = keep[:post_nms_topN]  
  149.         proposals = proposals[keep, :]  
  150.         scores = scores[keep]  
  151.   
  152.         # Output rois blob  
  153.         # Our RPN implementation only supports a single input image, so all  
  154.         # batch inds are 0  
  155.         batch_inds = np.zeros((proposals.shape[0], 1), dtype=np.float32)  
  156.         blob = np.hstack((batch_inds, proposals.astype(np.float32, copy=False)))  
  157.         top[0].reshape(*(blob.shape))  
  158.         top[0].data[...] = blob  
  159.   
  160.         # [Optional] output scores blob  
  161.         if len(top) > 1:  
  162.             top[1].reshape(*(scores.shape))  
  163.             top[1].data[...] = scores  
  164.   
  165.     def backward(self, top, propagate_down, bottom):  
  166.         """This layer does not propagate gradients."""  
  167.         pass  
  168.   
  169.     def reshape(self, bottom, top):  
  170.         """Reshaping happens during the call to forward."""  
  171.         pass  
  172.   
  173. def _filter_boxes(boxes, min_size):  
  174.     """Remove all boxes with any side smaller than min_size."""  
  175.     ws = boxes[:, 2] - boxes[:, 0] + 1  
  176.     hs = boxes[:, 3] - boxes[:, 1] + 1  
  177.     keep = np.where((ws >= min_size) & (hs >= min_size))[0]  
  178.     return keep  





fast_rcnn_train.pt

  1. #stage 1训练fast rcnn网络,输入是rpn提取的roi以及gt box  
  2. name: "ZF"  
  3. layer {  
  4.   name: 'data'  
  5.   type: 'Python'  
  6.   top: 'data'  
  7.   top: 'rois'  
  8.   top: 'labels'  
  9.   top: 'bbox_targets'  
  10.   top: 'bbox_inside_weights'  
  11.   top: 'bbox_outside_weights'  
  12.   python_param {  
  13.     module: 'roi_data_layer.layer'#对应lib/roi_data_layer/layer.py  
  14. #为训练fast rcnn时为网络输入roi,此时为roi是region proposal  
  15.     layer: 'RoIDataLayer'  
  16.     param_str: "'num_classes': 21"  
  17.   }  
  18. }  
  19.   
  20. #ZF网,特征提取用,共享  
  21. #========= conv1-conv5 ============  
  22.   
  23. layer {  
  24.     name: "conv1"  
  25.     type: "Convolution"  
  26.     bottom: "data"  
  27.     top: "conv1"  
  28.     param { lr_mult: 1.0 }  
  29.     param { lr_mult: 2.0 }  
  30.     convolution_param {  
  31.         num_output: 96  
  32.         kernel_size: 7  
  33.         pad: 3  
  34.         stride: 2  
  35.     }  
  36. }  
  37. layer {  
  38.     name: "relu1"  
  39.     type: "ReLU"  
  40.     bottom: "conv1"  
  41.     top: "conv1"  
  42. }  
  43. layer {  
  44.     name: "norm1"  
  45.     type: "LRN"  
  46.     bottom: "conv1"  
  47.     top: "norm1"  
  48.     lrn_param {  
  49.         local_size: 3  
  50.         alpha: 0.00005  
  51.         beta: 0.75  
  52.         norm_region: WITHIN_CHANNEL  
  53.     engine: CAFFE  
  54.     }  
  55. }  
  56. layer {  
  57.     name: "pool1"  
  58.     type: "Pooling"  
  59.     bottom: "norm1"  
  60.     top: "pool1"  
  61.     pooling_param {  
  62.         kernel_size: 3  
  63.         stride: 2  
  64.         pad: 1  
  65.         pool: MAX  
  66.     }  
  67. }  
  68. layer {  
  69.     name: "conv2"  
  70.     type: "Convolution"  
  71.     bottom: "pool1"  
  72.     top: "conv2"  
  73.     param { lr_mult: 1.0 }  
  74.     param { lr_mult: 2.0 }  
  75.     convolution_param {  
  76.         num_output: 256  
  77.         kernel_size: 5  
  78.         pad: 2  
  79.         stride: 2  
  80.     }  
  81. }  
  82. layer {  
  83.     name: "relu2"  
  84.     type: "ReLU"  
  85.     bottom: "conv2"  
  86.     top: "conv2"  
  87. }  
  88. layer {  
  89.     name: "norm2"  
  90.     type: "LRN"  
  91.     bottom: "conv2"  
  92.     top: "norm2"  
  93.     lrn_param {  
  94.         local_size: 3  
  95.         alpha: 0.00005  
  96.         beta: 0.75  
  97.         norm_region: WITHIN_CHANNEL  
  98.     engine: CAFFE  
  99.     }  
  100. }  
  101. layer {  
  102.     name: "pool2"  
  103.     type: "Pooling"  
  104.     bottom: "norm2"  
  105.     top: "pool2"  
  106.     pooling_param {  
  107.         kernel_size: 3  
  108.         stride: 2  
  109.         pad: 1  
  110.         pool: MAX  
  111.     }  
  112. }  
  113. layer {  
  114.     name: "conv3"  
  115.     type: "Convolution"  
  116.     bottom: "pool2"  
  117.     top: "conv3"  
  118.     param { lr_mult: 1.0 }  
  119.     param { lr_mult: 2.0 }  
  120.     convolution_param {  
  121.         num_output: 384  
  122.         kernel_size: 3  
  123.         pad: 1  
  124.         stride: 1  
  125.     }  
  126. }  
  127. layer {  
  128.     name: "relu3"  
  129.     type: "ReLU"  
  130.     bottom: "conv3"  
  131.     top: "conv3"  
  132. }  
  133. layer {  
  134.     name: "conv4"  
  135.     type: "Convolution"  
  136.     bottom: "conv3"  
  137.     top: "conv4"  
  138.     param { lr_mult: 1.0 }  
  139.     param { lr_mult: 2.0 }  
  140.     convolution_param {  
  141.         num_output: 384  
  142.         kernel_size: 3  
  143.         pad: 1  
  144.         stride: 1  
  145.     }  
  146. }  
  147. layer {  
  148.     name: "relu4"  
  149.     type: "ReLU"  
  150.     bottom: "conv4"  
  151.     top: "conv4"  
  152. }  
  153. layer {  
  154.     name: "conv5"  
  155.     type: "Convolution"  
  156.     bottom: "conv4"  
  157.     top: "conv5"  
  158.     param { lr_mult: 1.0 }  
  159.     param { lr_mult: 2.0 }  
  160.     convolution_param {  
  161.         num_output: 256  
  162.         kernel_size: 3  
  163.         pad: 1  
  164.         stride: 1  
  165.     }  
  166. }  
  167. layer {  
  168.     name: "relu5"  
  169.     type: "ReLU"  
  170.     bottom: "conv5"  
  171.     top: "conv5"  
  172. }  
  173.   
  174. #========= RCNN ============  
  175.   
  176. layer {  
  177.   name: "roi_pool_conv5"  
  178.   type: "ROIPooling"#这个层在caffe-fast-rcnn里实现  
  179.   bottom: "conv5"  
  180.   bottom: "rois"  
  181.   top: "roi_pool_conv5"  
  182.   roi_pooling_param {#每个roi做max pooling后的大小为6*6  
  183.     pooled_w: 6  
  184.     pooled_h: 6  
  185.     spatial_scale: 0.0625 # 1/16  
  186.   }  
  187. }  
  188. layer {  
  189.   name: "fc6"  
  190.   type: "InnerProduct"  
  191.   bottom: "roi_pool_conv5"  
  192.   top: "fc6"  
  193.   param { lr_mult: 1.0 }  
  194.   param { lr_mult: 2.0 }  
  195.   inner_product_param {  
  196.     num_output: 4096  
  197.   }  
  198. }  
  199. layer {  
  200.   name: "relu6"  
  201.   type: "ReLU"  
  202.   bottom: "fc6"  
  203.   top: "fc6"  
  204. }  
  205. layer {  
  206.   name: "drop6"  
  207.   type: "Dropout"  
  208.   bottom: "fc6"  
  209.   top: "fc6"  
  210.   dropout_param {  
  211.     dropout_ratio: 0.5  
  212.     scale_train: false  
  213.   }  
  214. }  
  215. layer {  
  216.   name: "fc7"  
  217.   type: "InnerProduct"  
  218.   bottom: "fc6"  
  219.   top: "fc7"  
  220.   param { lr_mult: 1.0 }  
  221.   param { lr_mult: 2.0 }  
  222.   inner_product_param {  
  223.     num_output: 4096  
  224.   }  
  225. }  
  226. layer {  
  227.   name: "relu7"  
  228.   type: "ReLU"  
  229.   bottom: "fc7"  
  230.   top: "fc7"  
  231. }  
  232. layer {  
  233.   name: "drop7"  
  234.   type: "Dropout"  
  235.   bottom: "fc7"  
  236.   top: "fc7"  
  237.   dropout_param {  
  238.     dropout_ratio: 0.5  
  239.     scale_train: false  
  240.   }  
  241. }  
  242. layer {  
  243.   name: "cls_score"  
  244.   type: "InnerProduct"  
  245.   bottom: "fc7"  
  246.   top: "cls_score"  
  247.   param { lr_mult: 1.0 }  
  248.   param { lr_mult: 2.0 }  
  249.   inner_product_param {  
  250.     num_output: 21  
  251.     weight_filler {  
  252.       type: "gaussian"  
  253.       std: 0.01  
  254.     }  
  255.     bias_filler {  
  256.       type: "constant"  
  257.       value: 0  
  258.     }  
  259.   }  
  260. }  
  261. layer {  
  262.   name: "bbox_pred"  
  263.   type: "InnerProduct"  
  264.   bottom: "fc7"  
  265.   top: "bbox_pred"  
  266.   param { lr_mult: 1.0 }  
  267.   param { lr_mult: 2.0 }  
  268.   inner_product_param {  
  269.     num_output: 84  
  270.     weight_filler {  
  271.       type: "gaussian"  
  272.       std: 0.001  
  273.     }  
  274.     bias_filler {  
  275.       type: "constant"  
  276.       value: 0  
  277.     }  
  278.   }  
  279. }  
  280. layer {  
  281.   name: "loss_cls"  
  282.   type: "SoftmaxWithLoss"  
  283.   bottom: "cls_score"  
  284.   bottom: "labels"  
  285.   propagate_down: 1  
  286.   propagate_down: 0  
  287.   top: "cls_loss"  
  288.   loss_weight: 1  
  289.   loss_param {  
  290.     ignore_label: -1  
  291.     normalize: true  
  292.   }  
  293. }  
  294. layer {  
  295.   name: "loss_bbox"  
  296.   type: "SmoothL1Loss"  
  297.   bottom: "bbox_pred"  
  298.   bottom: "bbox_targets"  
  299.   bottom: "bbox_inside_weights"  
  300.   bottom: "bbox_outside_weights"  
  301.   top: "bbox_loss"  
  302.   loss_weight: 1  
  303. }  
  304.   
  305. #========= RPN ============  
  306. # Dummy layers so that initial parameters are saved into the output net  
  307.   
  308. layer {  
  309.   name: "rpn_conv1"  
  310.   type: "Convolution"  
  311.   bottom: "conv5"  
  312.   top: "rpn_conv1"  
  313.   param { lr_mult: 0 decay_mult: 0 }  
  314.   param { lr_mult: 0 decay_mult: 0 }  
  315.   convolution_param {  
  316.     num_output: 256  
  317.     kernel_size: 3 pad: 1 stride: 1  
  318.     weight_filler { type: "gaussian" std: 0.01 }  
  319.     bias_filler { type: "constant" value: 0 }  
  320.   }  
  321. }  
  322. layer {  
  323.   name: "rpn_relu1"  
  324.   type: "ReLU"  
  325.   bottom: "rpn_conv1"  
  326.   top: "rpn_conv1"  
  327. }  
  328. layer {  
  329.   name: "rpn_cls_score"  
  330.   type: "Convolution"  
  331.   bottom: "rpn_conv1"  
  332.   top: "rpn_cls_score"  
  333.   param { lr_mult: 0 decay_mult: 0 }  
  334.   param { lr_mult: 0 decay_mult: 0 }  
  335.   convolution_param {  
  336.     num_output: 18   # 2(bg/fg) * 9(anchors)  
  337.     kernel_size: 1 pad: 0 stride: 1  
  338.     weight_filler { type: "gaussian" std: 0.01 }  
  339.     bias_filler { type: "constant" value: 0 }  
  340.   }  
  341. }  
  342. layer {  
  343.   name: "rpn_bbox_pred"  
  344.   type: "Convolution"  
  345.   bottom: "rpn_conv1"  
  346.   top: "rpn_bbox_pred"  
  347.   param { lr_mult: 0 decay_mult: 0 }  
  348.   param { lr_mult: 0 decay_mult: 0 }  
  349.   convolution_param {  
  350.     num_output: 36   # 4 * 9(anchors)  
  351.     kernel_size: 1 pad: 0 stride: 1  
  352.     weight_filler { type: "gaussian" std: 0.01 }  
  353.     bias_filler { type: "constant" value: 0 }  
  354.   }  
  355. }  
  356. layer {  
  357.   name: "silence_rpn_cls_score"  
  358.   type: "Silence"  
  359.   bottom: "rpn_cls_score"  
  360. }  
  361. layer {  
  362.   name: "silence_rpn_bbox_pred"  
  363.   type: "Silence"  
  364.   bottom: "rpn_bbox_pred"  
  365. }  


其中roi pooling layer在 caffe/src/layers/roi_pooling_layer.cpp里实现



  1. // ------------------------------------------------------------------  
  2. // Fast R-CNN  
  3. // Copyright (c) 2015 Microsoft  
  4. // Licensed under The MIT License [see fast-rcnn/LICENSE for details]  
  5. // Written by Ross Girshick  
  6. // ------------------------------------------------------------------  
  7.   
  8. #include <cfloat>  
  9.   
  10. #include "caffe/fast_rcnn_layers.hpp"  
  11.   
  12. using std::max;  
  13. using std::min;  
  14. using std::floor;  
  15. using std::ceil;  
  16.   
  17. namespace caffe {  
  18.   
  19. template <typename Dtype>  
  20. void ROIPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  21.       const vector<Blob<Dtype>*>& top) {  
  22.   ROIPoolingParameter roi_pool_param = this->layer_param_.roi_pooling_param();  
  23.   CHECK_GT(roi_pool_param.pooled_h(), 0)  
  24.       << "pooled_h must be > 0";  
  25.   CHECK_GT(roi_pool_param.pooled_w(), 0)  
  26.       << "pooled_w must be > 0";  
  27.   pooled_height_ = roi_pool_param.pooled_h();  
  28.   pooled_width_ = roi_pool_param.pooled_w();  
  29.   spatial_scale_ = roi_pool_param.spatial_scale();  
  30.   LOG(INFO) << "Spatial scale: " << spatial_scale_;  
  31. }  
  32.   
  33. template <typename Dtype>  
  34. void ROIPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,  
  35.       const vector<Blob<Dtype>*>& top) {  
  36.   channels_ = bottom[0]->channels();  
  37.   height_ = bottom[0]->height();  
  38.   width_ = bottom[0]->width();  
  39.   top[0]->Reshape(bottom[1]->num(), channels_, pooled_height_,  
  40.       pooled_width_);  
  41.   max_idx_.Reshape(bottom[1]->num(), channels_, pooled_height_,  
  42.       pooled_width_);  
  43. }  
  44.   
  45. template <typename Dtype>  
  46. void ROIPoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  47.       const vector<Blob<Dtype>*>& top) {  
  48.   const Dtype* bottom_data = bottom[0]->cpu_data();  
  49.   const Dtype* bottom_rois = bottom[1]->cpu_data();  
  50.   // Number of ROIs  
  51.   int num_rois = bottom[1]->num();  
  52.   int batch_size = bottom[0]->num();  
  53.   int top_count = top[0]->count();  
  54.   Dtype* top_data = top[0]->mutable_cpu_data();  
  55.   caffe_set(top_count, Dtype(-FLT_MAX), top_data);  
  56.   int* argmax_data = max_idx_.mutable_cpu_data();  
  57.   caffe_set(top_count, -1, argmax_data);  
  58.   
  59.   // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R  
  60.   for (int n = 0; n < num_rois; ++n) {  
  61.     int roi_batch_ind = bottom_rois[0];  
  62.     int roi_start_w = round(bottom_rois[1] * spatial_scale_);  
  63.     int roi_start_h = round(bottom_rois[2] * spatial_scale_);  
  64.     int roi_end_w = round(bottom_rois[3] * spatial_scale_);  
  65.     int roi_end_h = round(bottom_rois[4] * spatial_scale_);  
  66.     CHECK_GE(roi_batch_ind, 0);  
  67.     CHECK_LT(roi_batch_ind, batch_size);  
  68.   
  69.     int roi_height = max(roi_end_h - roi_start_h + 1, 1);  
  70.     int roi_width = max(roi_end_w - roi_start_w + 1, 1);  
  71.     const Dtype bin_size_h = static_cast<Dtype>(roi_height)  
  72.                              / static_cast<Dtype>(pooled_height_);  
  73.     const Dtype bin_size_w = static_cast<Dtype>(roi_width)  
  74.                              / static_cast<Dtype>(pooled_width_);  
  75.   
  76.     const Dtype* batch_data = bottom_data + bottom[0]->offset(roi_batch_ind);  
  77.   
  78.     for (int c = 0; c < channels_; ++c) {  
  79.       for (int ph = 0; ph < pooled_height_; ++ph) {  
  80.         for (int pw = 0; pw < pooled_width_; ++pw) {  
  81.           // Compute pooling region for this output unit:  
  82.           //  start (included) = floor(ph * roi_height / pooled_height_)  
  83.           //  end (excluded) = ceil((ph + 1) * roi_height / pooled_height_)  
  84.           int hstart = static_cast<int>(floor(static_cast<Dtype>(ph)  
  85.                                               * bin_size_h));  
  86.           int wstart = static_cast<int>(floor(static_cast<Dtype>(pw)  
  87.                                               * bin_size_w));  
  88.           int hend = static_cast<int>(ceil(static_cast<Dtype>(ph + 1)  
  89.                                            * bin_size_h));  
  90.           int wend = static_cast<int>(ceil(static_cast<Dtype>(pw + 1)  
  91.                                            * bin_size_w));  
  92.   
  93.           hstart = min(max(hstart + roi_start_h, 0), height_);  
  94.           hend = min(max(hend + roi_start_h, 0), height_);  
  95.           wstart = min(max(wstart + roi_start_w, 0), width_);  
  96.           wend = min(max(wend + roi_start_w, 0), width_);  
  97.   
  98.           bool is_empty = (hend <= hstart) || (wend <= wstart);  
  99.   
  100.           const int pool_index = ph * pooled_width_ + pw;  
  101.           if (is_empty) {  
  102.             top_data[pool_index] = 0;  
  103.             argmax_data[pool_index] = -1;  
  104.           }  
  105.   
  106.           for (int h = hstart; h < hend; ++h) {  
  107.             for (int w = wstart; w < wend; ++w) {  
  108.               const int index = h * width_ + w;  
  109.               if (batch_data[index] > top_data[pool_index]) {  
  110.                 top_data[pool_index] = batch_data[index];  
  111.                 argmax_data[pool_index] = index;  
  112.               }  
  113.             }  
  114.           }  
  115.         }  
  116.       }  
  117.       // Increment all data pointers by one channel  
  118.       batch_data += bottom[0]->offset(0, 1);  
  119.       top_data += top[0]->offset(0, 1);  
  120.       argmax_data += max_idx_.offset(0, 1);  
  121.     }  
  122.     // Increment ROI data pointer  
  123.     bottom_rois += bottom[1]->offset(1);  
  124.   }  
  125. }  
  126.   
  127. template <typename Dtype>  
  128. void ROIPoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,  
  129.       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {  
  130.   NOT_IMPLEMENTED;  
  131. }  
  132.   
  133.   
  134. #ifdef CPU_ONLY  
  135. STUB_GPU(ROIPoolingLayer);  
  136. #endif  
  137.   
  138. INSTANTIATE_CLASS(ROIPoolingLayer);  
  139. REGISTER_LAYER_CLASS(ROIPooling);  
  140.   
  141. }  // namespace caffe  





大致结构看明白了来看具体训练流程


首先看tools/train_faster_rcnn_alt_opt.py


  1. #coding:utf-8  
  2. #!/usr/bin/env python  
  3.   
  4. # --------------------------------------------------------  
  5. # Faster R-CNN  
  6. # Copyright (c) 2015 Microsoft  
  7. # Licensed under The MIT License [see LICENSE for details]  
  8. # Written by Ross Girshick  
  9. # --------------------------------------------------------  
  10.   
  11. """Train a Faster R-CNN network using alternating optimization. 
  12. This tool implements the alternating optimization algorithm described in our 
  13. NIPS 2015 paper ("Faster R-CNN: Towards Real-time Object Detection with Region 
  14. Proposal Networks." Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun.) 
  15. """  
  16.   
  17. import _init_paths  
  18. from fast_rcnn.train import get_training_roidb, train_net  
  19. from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list, get_output_dir  
  20. from datasets.factory import get_imdb  
  21. from rpn.generate import imdb_proposals  
  22. import argparse  
  23. import pprint  
  24. import numpy as np  
  25. import sys, os  
  26. import multiprocessing as mp  
  27. import cPickle  
  28. import shutil  
  29.   
  30. def parse_args():  
  31.     """ 
  32.     Parse input arguments 
  33.     """  
  34.     parser = argparse.ArgumentParser(description='Train a Faster R-CNN network')  
  35.     #训练时设置使用哪个GPU  
  36.     parser.add_argument('--gpu', dest='gpu_id',  
  37.                         help='GPU device id to use [0]',  
  38.                         default=0, type=int)  
  39.     #设置训练时使用哪种网络模型  
  40.     parser.add_argument('--net_name', dest='net_name',  
  41.                         help='network name (e.g., "ZF")',  
  42.                         default=None, type=str)  
  43.     #指定预训练的模型来初始化网络  
  44.     parser.add_argument('--weights', dest='pretrained_model',  
  45.                         help='initialize with pretrained model weights',  
  46.                         default=None, type=str)  
  47.     #加载配置文件  
  48.     parser.add_argument('--cfg', dest='cfg_file',  
  49.                         help='optional config file',  
  50.                         default=None, type=str)  
  51.     #训练使用的数据集  
  52.     parser.add_argument('--imdb', dest='imdb_name',  
  53.                         help='dataset to train on',  
  54.                         default='voc_2007_trainval', type=str)  
  55.     parser.add_argument('--set', dest='set_cfgs',  
  56.                         help='set config keys', default=None,  
  57.                         nargs=argparse.REMAINDER)  
  58.   
  59.     if len(sys.argv) == 1:  
  60.         parser.print_help()  
  61.         sys.exit(1)  
  62.   
  63.     args = parser.parse_args()  
  64.     return args  
  65.   
  66. def get_roidb(imdb_name, rpn_file=None):  
  67. #得到图像集(image database)的名字,如pascalvoc——2007——trainval  
  68.     imdb = get_imdb(imdb_name)  
  69.     print 'Loaded dataset `{:s}` for training'.format(imdb.name)  
  70. #设置网络得到proposal的方法,有selective search和RPN、gt,selective search已弃用  
  71.     imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)  
  72.     print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)  
  73. #判断之前是否已经有RPN网络提取得到的region proposal文件  
  74.     if rpn_file is not None:  
  75.         imdb.config['rpn_file'] = rpn_file  
  76.     roidb = get_training_roidb(imdb)  
  77.     return roidb, imdb  
  78.   
  79. def get_solvers(net_name):  
  80.     # Faster R-CNN Alternating Optimization  
  81.     n = 'faster_rcnn_alt_opt'  
  82.     # Solver for each training stage  
  83.     solvers = [[net_name, n, 'stage1_rpn_solver60k80k.pt'],  
  84.                [net_name, n, 'stage1_fast_rcnn_solver30k40k.pt'],  
  85.                [net_name, n, 'stage2_rpn_solver60k80k.pt'],  
  86.                [net_name, n, 'stage2_fast_rcnn_solver30k40k.pt']]  
  87.     solvers = [os.path.join(cfg.MODELS_DIR, *s) for s in solvers]  
  88.     # Iterations for each training stage  
  89. #每一轮训练的最大迭代次数,建议测试时都设置为100  
  90.     max_iters = [80000400008000040000]  
  91.     # max_iters = [100, 100, 100, 100]  
  92.     # Test prototxt for the RPN  
  93.     rpn_test_prototxt = os.path.join(  
  94.         cfg.MODELS_DIR, net_name, n, 'rpn_test.pt')  
  95.     return solvers, max_iters, rpn_test_prototxt  
  96.   
  97. # ------------------------------------------------------------------------------  
  98. # Pycaffe doesn't reliably free GPU memory when instantiated nets are discarded  
  99. # (e.g. "del net" in Python code). To work around this issue, each training  
  100. # stage is executed in a separate process using multiprocessing.Process.  
  101. # ------------------------------------------------------------------------------  
  102.   
  103. def _init_caffe(cfg):  
  104.     """Initialize pycaffe in a training process. 
  105.     """  
  106.   
  107.     import caffe  
  108.     # fix the random seeds (numpy and caffe) for reproducibility  
  109.     np.random.seed(cfg.RNG_SEED)  
  110.     caffe.set_random_seed(cfg.RNG_SEED)  
  111.     # set up caffe  
  112.     caffe.set_mode_gpu()  
  113.     caffe.set_device(cfg.GPU_ID)  
  114.   
  115. #训练RPN  
  116. def train_rpn(queue=None, imdb_name=None, init_model=None, solver=None,  
  117.               max_iters=None, cfg=None):  
  118.     """Train a Region Proposal Network in a separate training process. 
  119.     """  
  120.   
  121.     # Not using any proposals, just ground-truth boxes  
  122.     cfg.TRAIN.HAS_RPN = True  
  123.     cfg.TRAIN.BBOX_REG = False  # applies only to Fast R-CNN bbox regression  
  124. #训练RPN时使用ground-truth  
  125.     cfg.TRAIN.PROPOSAL_METHOD = 'gt'  
  126. #每次训练RPN只用一张图片  
  127.     cfg.TRAIN.IMS_PER_BATCH = 1  
  128.     print 'Init model: {}'.format(init_model)  
  129.     print('Using config:')  
  130.     pprint.pprint(cfg)  
  131.   
  132.     import caffe  
  133.     _init_caffe(cfg)  
  134.   
  135.     roidb, imdb = get_roidb(imdb_name)  
  136.     print 'roidb len: {}'.format(len(roidb))  
  137.     output_dir = get_output_dir(imdb)  
  138.     print 'Output will be saved to `{:s}`'.format(output_dir)  
  139. #开始训练RPN网络  
  140.     model_paths = train_net(solver, roidb, output_dir,  
  141.                             pretrained_model=init_model,  
  142.                             max_iters=max_iters)  
  143. #只保留最后得到的网络模型  
  144.     # Cleanup all but the final model  
  145.     for i in model_paths[:-1]:  
  146.         os.remove(i)  
  147.     rpn_model_path = model_paths[-1]  
  148.     # Send final model path through the multiprocessing queue  
  149.     queue.put({'model_path': rpn_model_path})  
  150.   
  151.   
  152. #用训练完的RPN产生region proposal并存到磁盘上  
  153. def rpn_generate(queue=None, imdb_name=None, rpn_model_path=None, cfg=None,  
  154.                  rpn_test_prototxt=None):  
  155.     """Use a trained RPN to generate proposals. 
  156.     """  
  157.   
  158.     cfg.TEST.RPN_PRE_NMS_TOP_N = -1     # no pre NMS filtering  
  159.     cfg.TEST.RPN_POST_NMS_TOP_N = 2000  # limit top boxes after NMS  
  160.     print 'RPN model: {}'.format(rpn_model_path)  
  161.     print('Using config:')  
  162.     pprint.pprint(cfg)  
  163.   
  164.     import caffe  
  165.     _init_caffe(cfg)  
  166.   
  167.     # NOTE: the matlab implementation computes proposals on flipped images, too.  
  168.     # We compute them on the image once and then flip the already computed  
  169.     # proposals. This might cause a minor loss in mAP (less proposal jittering).  
  170.     imdb = get_imdb(imdb_name)  
  171.     print 'Loaded dataset `{:s}` for proposal generation'.format(imdb.name)  
  172.   
  173.     # Load RPN and configure output directory  
  174.     rpn_net = caffe.Net(rpn_test_prototxt, rpn_model_path, caffe.TEST)  
  175.     output_dir = get_output_dir(imdb)  
  176.     print 'Output will be saved to `{:s}`'.format(output_dir)  
  177.     # Generate proposals on the imdb  
  178.     rpn_proposals = imdb_proposals(rpn_net, imdb)  
  179.     # Write proposals to disk and send the proposal file path through the  
  180.     # multiprocessing queue  
  181.     rpn_net_name = os.path.splitext(os.path.basename(rpn_model_path))[0]  
  182.     rpn_proposals_path = os.path.join(  
  183.         output_dir, rpn_net_name + '_proposals.pkl')  
  184.     with open(rpn_proposals_path, 'wb') as f:  
  185.         cPickle.dump(rpn_proposals, f, cPickle.HIGHEST_PROTOCOL)  
  186.     print 'Wrote RPN proposals to {}'.format(rpn_proposals_path)  
  187.     queue.put({'proposal_path': rpn_proposals_path})  
  188.   
  189. #训练fast-rcnn  
  190. def train_fast_rcnn(queue=None, imdb_name=None, init_model=None, solver=None,  
  191.                     max_iters=None, cfg=None, rpn_file=None):  
  192.     """Train a Fast R-CNN using proposals generated by an RPN. 
  193.     """  
  194. #conv5后面现在接的是fast-rcnn  
  195.     cfg.TRAIN.HAS_RPN = False           # not generating prosals on-the-fly  
  196. #roidb由刚刚训练完的RPN产生  
  197.     cfg.TRAIN.PROPOSAL_METHOD = 'rpn'   # use pre-computed RPN proposals instead  
  198. #每次训练fast-rcnn使用两张图片  
  199.     cfg.TRAIN.IMS_PER_BATCH = 2  
  200.     print 'Init model: {}'.format(init_model)  
  201.     print 'RPN proposals: {}'.format(rpn_file)  
  202.     print('Using config:')  
  203.     pprint.pprint(cfg)  
  204.   
  205.     import caffe  
  206.     _init_caffe(cfg)  
  207.   
  208.     roidb, imdb = get_roidb(imdb_name, rpn_file=rpn_file)  
  209.     output_dir = get_output_dir(imdb)  
  210.     print 'Output will be saved to `{:s}`'.format(output_dir)  
  211.     # Train Fast R-CNN  
  212.     model_paths = train_net(solver, roidb, output_dir,  
  213.                             pretrained_model=init_model,  
  214.                             max_iters=max_iters)  
  215.     # Cleanup all but the final model  
  216.     for i in model_paths[:-1]:  
  217.         os.remove(i)  
  218.     fast_rcnn_model_path = model_paths[-1]  
  219.     # Send Fast R-CNN model path over the multiprocessing queue  
  220.     queue.put({'model_path': fast_rcnn_model_path})  
  221.   
  222. if __name__ == '__main__':  
  223.     args = parse_args()  
  224.   
  225.     print('Called with args:')  
  226.     print(args)  
  227.   
  228.     if args.cfg_file is not None:  
  229.         cfg_from_file(args.cfg_file)  
  230.     if args.set_cfgs is not None:  
  231.         cfg_from_list(args.set_cfgs)  
  232.     cfg.GPU_ID = args.gpu_id  
  233.   
  234.     # --------------------------------------------------------------------------  
  235.     # Pycaffe doesn't reliably free GPU memory when instantiated nets are  
  236.     # discarded (e.g. "del net" in Python code). To work around this issue, each  
  237.     # training stage is executed in a separate process using  
  238.     # multiprocessing.Process.  
  239.     # --------------------------------------------------------------------------  
  240.   
  241.     # queue for communicated results between processes  
  242.     mp_queue = mp.Queue()  
  243.     # solves, iters, etc. for each training stage  
  244.     solvers, max_iters, rpn_test_prototxt = get_solvers(args.net_name)  
  245.   
  246.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  247.     print 'Stage 1 RPN, init from ImageNet model'  
  248.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  249.   
  250.     cfg.TRAIN.SNAPSHOT_INFIX = 'stage1'  
  251.     mp_kwargs = dict(  
  252.             queue=mp_queue,  
  253.             imdb_name=args.imdb_name,  
  254.             init_model=args.pretrained_model,  
  255.             solver=solvers[0],  
  256.             max_iters=max_iters[0],  
  257.             cfg=cfg)  
  258.     p = mp.Process(target=train_rpn, kwargs=mp_kwargs)  
  259.     p.start()  
  260.     rpn_stage1_out = mp_queue.get()  
  261.     p.join()  
  262.   
  263.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  264.     print 'Stage 1 RPN, generate proposals'  
  265.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  266.   
  267.     mp_kwargs = dict(  
  268.             queue=mp_queue,  
  269.             imdb_name=args.imdb_name,  
  270.             rpn_model_path=str(rpn_stage1_out['model_path']),  
  271.             cfg=cfg,  
  272.             rpn_test_prototxt=rpn_test_prototxt)  
  273.     p = mp.Process(target=rpn_generate, kwargs=mp_kwargs)  
  274.     p.start()  
  275.     rpn_stage1_out['proposal_path'] = mp_queue.get()['proposal_path']  
  276.     p.join()  
  277.   
  278.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  279.     print 'Stage 1 Fast R-CNN using RPN proposals, init from ImageNet model'  
  280.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  281.   
  282.     cfg.TRAIN.SNAPSHOT_INFIX = 'stage1'  
  283.     mp_kwargs = dict(  
  284.             queue=mp_queue,  
  285.             imdb_name=args.imdb_name,  
  286.             init_model=args.pretrained_model,  
  287.             solver=solvers[1],  
  288.             max_iters=max_iters[1],  
  289.             cfg=cfg,  
  290.             rpn_file=rpn_stage1_out['proposal_path'])  
  291.     p = mp.Process(target=train_fast_rcnn, kwargs=mp_kwargs)  
  292.     p.start()  
  293.     fast_rcnn_stage1_out = mp_queue.get()  
  294.     p.join()  
  295.   
  296.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  297.     print 'Stage 2 RPN, init from stage 1 Fast R-CNN model'  
  298.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  299.   
  300.     cfg.TRAIN.SNAPSHOT_INFIX = 'stage2'  
  301.     mp_kwargs = dict(  
  302.             queue=mp_queue,  
  303.             imdb_name=args.imdb_name,  
  304.             init_model=str(fast_rcnn_stage1_out['model_path']),  
  305.             solver=solvers[2],  
  306.             max_iters=max_iters[2],  
  307.             cfg=cfg)  
  308.     p = mp.Process(target=train_rpn, kwargs=mp_kwargs)  
  309.     p.start()  
  310.     rpn_stage2_out = mp_queue.get()  
  311.     p.join()  
  312.   
  313.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  314.     print 'Stage 2 RPN, generate proposals'  
  315.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  316.   
  317.     mp_kwargs = dict(  
  318.             queue=mp_queue,  
  319.             imdb_name=args.imdb_name,  
  320.             rpn_model_path=str(rpn_stage2_out['model_path']),  
  321.             cfg=cfg,  
  322.             rpn_test_prototxt=rpn_test_prototxt)  
  323.     p = mp.Process(target=rpn_generate, kwargs=mp_kwargs)  
  324.     p.start()  
  325.     rpn_stage2_out['proposal_path'] = mp_queue.get()['proposal_path']  
  326.     p.join()  
  327.   
  328.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  329.     print 'Stage 2 Fast R-CNN, init from stage 2 RPN R-CNN model'  
  330.     print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  331.   
  332.     cfg.TRAIN.SNAPSHOT_INFIX = 'stage2'  
  333.     mp_kwargs = dict(  
  334.             queue=mp_queue,  
  335.             imdb_name=args.imdb_name,  
  336.             init_model=str(rpn_stage2_out['model_path']),  
  337.             solver=solvers[3],  
  338.             max_iters=max_iters[3],  
  339.             cfg=cfg,  
  340.             rpn_file=rpn_stage2_out['proposal_path'])  
  341.     p = mp.Process(target=train_fast_rcnn, kwargs=mp_kwargs)  
  342.     p.start()  
  343.     fast_rcnn_stage2_out = mp_queue.get()  
  344.     p.join()  
  345.   
  346.     # Create final model (just a copy of the last stage)  
  347.     final_path = os.path.join(  
  348.             os.path.dirname(fast_rcnn_stage2_out['model_path']),  
  349.             args.net_name + '_faster_rcnn_final.caffemodel')  
  350.     print 'cp {} -> {}'.format(  
  351.             fast_rcnn_stage2_out['model_path'], final_path)  
  352.     shutil.copy(fast_rcnn_stage2_out['model_path'], final_path)  
  353.     print 'Final model: {}'.format(final_path)  


lib/rpn/generate.py利用rpn网络前向计算产生proposal

  1. #coding:utf-8  
  2. # --------------------------------------------------------  
  3. # Faster R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick  
  7. # --------------------------------------------------------  
  8.   
  9. from fast_rcnn.config import cfg  
  10. from utils.blob import im_list_to_blob  
  11. from utils.timer import Timer  
  12. import numpy as np  
  13. import cv2  
  14.   
  15. def _vis_proposals(im, dets, thresh=0.5):  
  16.     """Draw detected bounding boxes."""  
  17.     inds = np.where(dets[:, -1] >= thresh)[0]  
  18.     if len(inds) == 0:  
  19.         return  
  20.   
  21.     class_name = 'obj'  
  22.     im = im[:, :, (210)]  
  23.     fig, ax = plt.subplots(figsize=(1212))  
  24.     ax.imshow(im, aspect='equal')  
  25.     for i in inds:  
  26.         bbox = dets[i, :4]  
  27.         score = dets[i, -1]  
  28.   
  29.         ax.add_patch(  
  30.             plt.Rectangle((bbox[0], bbox[1]),  
  31.                           bbox[2] - bbox[0],  
  32.                           bbox[3] - bbox[1], fill=False,  
  33.                           edgecolor='red', linewidth=3.5)  
  34.             )  
  35.         ax.text(bbox[0], bbox[1] - 2,  
  36.                 '{:s} {:.3f}'.format(class_name, score),  
  37.                 bbox=dict(facecolor='blue', alpha=0.5),  
  38.                 fontsize=14, color='white')  
  39.   
  40.     ax.set_title(('{} detections with '  
  41.                   'p({} | box) >= {:.1f}').format(class_name, class_name,  
  42.                                                   thresh),  
  43.                   fontsize=14)  
  44.     plt.axis('off')  
  45.     plt.tight_layout()  
  46.     plt.draw()  
  47.   
  48. def _get_image_blob(im):  
  49.     """Converts an image into a network input. 
  50.  
  51.     Arguments: 
  52.         im (ndarray): a color image in BGR order 
  53.  
  54.     Returns: 
  55.         blob (ndarray): a data blob holding an image pyramid 
  56.         im_scale_factors (list): list of image scales (relative to im) used 
  57.             in the image pyramid 
  58.     """  
  59.     im_orig = im.astype(np.float32, copy=True)  
  60.     im_orig -= cfg.PIXEL_MEANS  
  61.   
  62.     im_shape = im_orig.shape  
  63.     im_size_min = np.min(im_shape[0:2])  
  64.     im_size_max = np.max(im_shape[0:2])  
  65.   
  66.     processed_ims = []  
  67.   
  68.     assert len(cfg.TEST.SCALES) == 1  
  69.     target_size = cfg.TEST.SCALES[0]  
  70.   
  71.     im_scale = float(target_size) / float(im_size_min)  
  72.     # Prevent the biggest axis from being more than MAX_SIZE  
  73.     if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:  
  74.         im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)  
  75.     im = cv2.resize(im_orig, NoneNone, fx=im_scale, fy=im_scale,  
  76.                     interpolation=cv2.INTER_LINEAR)  
  77.     im_info = np.hstack((im.shape[:2], im_scale))[np.newaxis, :]  
  78.     processed_ims.append(im)  
  79.   
  80.     # Create a blob to hold the input images  
  81.     blob = im_list_to_blob(processed_ims)  
  82.   
  83.     return blob, im_info  
  84. #在一张图片上RPN前向计算产生region proposal  
  85. def im_proposals(net, im):  
  86.     """Generate RPN proposals on a single image."""  
  87.     blobs = {}  
  88.     blobs['data'], blobs['im_info'] = _get_image_blob(im)  
  89.     net.blobs['data'].reshape(*(blobs['data'].shape))  
  90.     net.blobs['im_info'].reshape(*(blobs['im_info'].shape))  
  91.     blobs_out = net.forward(  
  92.             data=blobs['data'].astype(np.float32, copy=False),  
  93.             im_info=blobs['im_info'].astype(np.float32, copy=False))  
  94.   
  95.     scale = blobs['im_info'][02]  
  96.     #boxes是列表,是所有roi box的坐标  
  97.     boxes = blobs_out['rois'][:, 1:].copy() / scale  
  98.     scores = blobs_out['scores'].copy()  
  99.     return boxes, scores  
  100.   
  101. #对imdb中所有的图像计算Region Proposal  
  102. def imdb_proposals(net, imdb):  
  103.     """Generate RPN proposals on all images in an imdb."""  
  104.   
  105.     _t = Timer()  
  106.     imdb_boxes = [[] for _ in xrange(imdb.num_images)]  
  107.     for i in xrange(imdb.num_images):  
  108.         im = cv2.imread(imdb.image_path_at(i))  
  109.         _t.tic()  
  110.         imdb_boxes[i], scores = im_proposals(net, im)  
  111.         _t.toc()  
  112.         print 'im_proposals: {:d}/{:d} {:.3f}s' \  
  113.               .format(i + 1, imdb.num_images, _t.average_time)  
  114.         if 0:  
  115.             dets = np.hstack((imdb_boxes[i], scores))  
  116.             # from IPython import embed; embed()  
  117.             _vis_proposals(im, dets[:3, :], thresh=0.9)  
  118.             plt.show()  
  119.   
  120.     return imdb_boxes  



lib/fast_rcnn/train.py


  1. #coding:utf-8  
  2. # --------------------------------------------------------  
  3. # Fast R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick  
  7. # --------------------------------------------------------  
  8.   
  9. """Train a Fast R-CNN network."""  
  10.   
  11. import caffe  
  12. from fast_rcnn.config import cfg  
  13. import roi_data_layer.roidb as rdl_roidb  
  14. from utils.timer import Timer  
  15. import numpy as np  
  16. import os  
  17.   
  18. from caffe.proto import caffe_pb2  
  19. import google.protobuf as pb2  
  20.   
  21. class SolverWrapper(object):  
  22.     """A simple wrapper around Caffe's solver. 
  23.     This wrapper gives us control over he snapshotting process, which we 
  24.     use to unnormalize the learned bounding-box regression weights. 
  25.     """  
  26.   
  27.     def __init__(self, solver_prototxt, roidb, output_dir,  
  28.                  pretrained_model=None):  
  29.         """Initialize the SolverWrapper."""  
  30.         self.output_dir = output_dir  
  31.   
  32.         if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and  
  33.             cfg.TRAIN.BBOX_NORMALIZE_TARGETS):  
  34.             # RPN can only use precomputed normalization because there are no  
  35.             # fixed statistics to compute a priori  
  36.             assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED  
  37.   
  38.         if cfg.TRAIN.BBOX_REG:  
  39.             print 'Computing bounding-box regression targets...'  
  40.             #bbox_stds是什么  
  41.             self.bbox_means, self.bbox_stds = \  
  42.                     rdl_roidb.add_bbox_regression_targets(roidb)  
  43.             print 'done'  
  44.   
  45.         self.solver = caffe.SGDSolver(solver_prototxt)  
  46.         #加载在ImageNet上训练得到的预训练模型  
  47.         if pretrained_model is not None:  
  48.             print ('Loading pretrained model '  
  49.                    'weights from {:s}').format(pretrained_model)  
  50.             self.solver.net.copy_from(pretrained_model)  
  51.         #解析得到训练时的参数,学习率等  
  52.         self.solver_param = caffe_pb2.SolverParameter()  
  53.         with open(solver_prototxt, 'rt') as f:  
  54.             pb2.text_format.Merge(f.read(), self.solver_param)  
  55.         #设置输入  
  56.         self.solver.net.layers[0].set_roidb(roidb)  
  57. #迭代达到10000次、20000次。。。时存结果  
  58.     def snapshot(self):  
  59.         """Take a snapshot of the network after unnormalizing the learned 
  60.         bounding-box regression weights. This enables easy use at test-time. 
  61.         """  
  62.         net = self.solver.net  
  63.   
  64.         scale_bbox_params = (cfg.TRAIN.BBOX_REG and  
  65.                              cfg.TRAIN.BBOX_NORMALIZE_TARGETS and  
  66.                              net.params.has_key('bbox_pred'))  
  67.   
  68.         if scale_bbox_params:  
  69.             # save original values  
  70.             orig_0 = net.params['bbox_pred'][0].data.copy()  
  71.             orig_1 = net.params['bbox_pred'][1].data.copy()  
  72.   
  73.             # scale and shift with bbox reg unnormalization; then save snapshot  
  74.             net.params['bbox_pred'][0].data[...] = \  
  75.                     (net.params['bbox_pred'][0].data *  
  76.                      self.bbox_stds[:, np.newaxis])  
  77.             net.params['bbox_pred'][1].data[...] = \  
  78.                     (net.params['bbox_pred'][1].data *  
  79.                      self.bbox_stds + self.bbox_means)  
  80.   
  81.         infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX  
  82.                  if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')  
  83.         filename = (self.solver_param.snapshot_prefix + infix +  
  84.                     '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')  
  85.         filename = os.path.join(self.output_dir, filename)  
  86.   
  87.         net.save(str(filename))  
  88.         print 'Wrote snapshot to: {:s}'.format(filename)  
  89.   
  90.         if scale_bbox_params:  
  91.             # restore net to original state  
  92.             net.params['bbox_pred'][0].data[...] = orig_0  
  93.             net.params['bbox_pred'][1].data[...] = orig_1  
  94.         return filename  
  95. #迭代一次  
  96.     def train_model(self, max_iters):  
  97.         """Network training loop."""  
  98.         last_snapshot_iter = -1  
  99.         timer = Timer()  
  100.         model_paths = []  
  101.         while self.solver.iter < max_iters:  
  102.             # Make one SGD update  
  103.             timer.tic()  
  104.             self.solver.step(1)  
  105.             timer.toc()  
  106.             if self.solver.iter % (10 * self.solver_param.display) == 0:  
  107.                 print 'speed: {:.3f}s / iter'.format(timer.average_time)  
  108.   
  109.             if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:  
  110.                 last_snapshot_iter = self.solver.iter  
  111.                 model_paths.append(self.snapshot())  
  112.   
  113.         if last_snapshot_iter != self.solver.iter:  
  114.             model_paths.append(self.snapshot())  
  115.         return model_paths  
  116.   
  117. def get_training_roidb(imdb):  
  118.     """Returns a roidb (Region of Interest database) for use in training."""  
  119. #如果设置使用水平翻转的图像  
  120.     if cfg.TRAIN.USE_FLIPPED:  
  121.         print 'Appending horizontally-flipped training examples...'  
  122. #把原来image database里所有的图像水平翻转加入到imdb里  
  123.         imdb.append_flipped_images()  
  124.         print 'done'  
  125.   
  126.     print 'Preparing training data...'  
  127.     rdl_roidb.prepare_roidb(imdb)  
  128.     print 'done'  
  129.   
  130.     return imdb.roidb  
  131.   
  132. #过滤产生符合条件的roidb  
  133. def filter_roidb(roidb):  
  134.     """Remove roidb entries that have no usable RoIs."""  
  135.   
  136.     def is_valid(entry):  
  137. #满足roidb中至少有一个前景或背景的roidb才有效  
  138.         # Valid images have:  
  139.         #   (1) At least one foreground RoI OR  
  140.         #   (2) At least one background RoI  
  141.         overlaps = entry['max_overlaps']  
  142.         # find boxes with sufficient overlap  
  143.         fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]  
  144.         # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)  
  145.         bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &  
  146.                            (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]  
  147.         # image is only valid if such boxes exist  
  148.         valid = len(fg_inds) > 0 or len(bg_inds) > 0  
  149.         return valid  
  150.   
  151.     num = len(roidb)  
  152.     filtered_roidb = [entry for entry in roidb if is_valid(entry)]  
  153.     num_after = len(filtered_roidb)  
  154.     print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,  
  155.                                                        num, num_after)  
  156.     return filtered_roidb  
  157.   
  158. def train_net(solver_prototxt, roidb, output_dir,  
  159.               pretrained_model=None, max_iters=40000):  
  160.     """Train a Fast R-CNN network."""  
  161.   
  162.     roidb = filter_roidb(roidb)  
  163.     sw = SolverWrapper(solver_prototxt, roidb, output_dir,  
  164.                        pretrained_model=pretrained_model)  
  165.   
  166.     print 'Solving...'  
  167.     model_paths = sw.train_model(max_iters)  
  168.     print 'done solving'  
  169.     return model_paths  



lib/roi_data_layer/roidb.py

roidb是一个重要的数据结构,roidb是一个列表,里面的每个元素是字典,对应一张图片的所有roi信息
{'image':imageindex,'width':w,'height':h,'gt_overlaps':二维array,每张图片上所有roi与各个类别的gt的overlap,'max_classes':max_cls,每个roi属于那一类别的大,'max_overlaps':每个roi与gt最大重叠率的大小}

  1. #coding:utf-8  
  2. # --------------------------------------------------------  
  3. # Fast R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick  
  7. # --------------------------------------------------------  
  8.   
  9. """Transform a roidb into a trainable roidb by adding a bunch of metadata."""  
  10.   
  11. import numpy as np  
  12. from fast_rcnn.config import cfg  
  13. from fast_rcnn.bbox_transform import bbox_transform  
  14. from utils.cython_bbox import bbox_overlaps  
  15. import PIL  
  16.   
  17.   
  18. #准备roidb  
  19. def prepare_roidb(imdb):  
  20.     """Enrich the imdb's roidb by adding some derived quantities that 
  21.     are useful for training. This function precomputes the maximum 
  22.     overlap, taken over ground-truth boxes, between each ROI and 
  23.     each ground-truth box. The class with maximum overlap is also 
  24.     recorded. 
  25.     """  
  26. #得到每幅图像的宽和高  
  27.     sizes = [PIL.Image.open(imdb.image_path_at(i)).size  
  28.              for i in xrange(imdb.num_images)]  
  29.     roidb = imdb.roidb  
  30. #roidb是一个列表,里面的每个元素是一个字典,对应一张图片的所有roi信息  
  31.     for i in xrange(len(imdb.image_index)):  
  32.         #字典{'image':imageindex,'width':w,'height':h,'gt_overlaps':二维array,每张图片上所有roi与各个类别的gt的overlap  
  33.         #'max_classes':max_cls,每个roi属于那一类别的最大  
  34.         roidb[i]['image'] = imdb.image_path_at(i)  
  35.         roidb[i]['width'] = sizes[i][0]  
  36.         roidb[i]['height'] = sizes[i][1]  
  37.         # need gt_overlaps as a dense array for argmax  
  38.         gt_overlaps = roidb[i]['gt_overlaps'].toarray()  
  39.         # max overlap with gt over classes (columns)  
  40. #传递进来的roidb尚未经过下面的取最大值的操作  
  41. #下面得到每个roi与ground-truth的bbox的最大IoU值  
  42.         max_overlaps = gt_overlaps.max(axis=1)  
  43.         # gt class that had the max overlap  
  44. #与哪个类别有最大IoU  
  45.         max_classes = gt_overlaps.argmax(axis=1)  
  46.         roidb[i]['max_classes'] = max_classes  
  47.         roidb[i]['max_overlaps'] = max_overlaps  
  48.         # sanity checks  
  49.         # max overlap of 0 => class should be zero (background)  
  50. #确保max overlap=0的box都属于背景  
  51.         zero_inds = np.where(max_overlaps == 0)[0]  
  52.         assert all(max_classes[zero_inds] == 0)  
  53.         # max overlap > 0 => class should not be zero (must be a fg class)  
  54.         nonzero_inds = np.where(max_overlaps > 0)[0]  
  55.         assert all(max_classes[nonzero_inds] != 0)  
  56.   
  57. def add_bbox_regression_targets(roidb):  
  58.     """Add information needed to train bounding-box regressors."""  
  59.     assert len(roidb) > 0  
  60.     assert 'max_classes' in roidb[0], 'Did you call prepare_roidb first?'  
  61.   
  62.     num_images = len(roidb)  
  63.     # Infer number of classes from the number of columns in gt_overlaps  
  64.     num_classes = roidb[0]['gt_overlaps'].shape[1]  
  65.     for im_i in xrange(num_images):  
  66.         rois = roidb[im_i]['boxes']  
  67.         max_overlaps = roidb[im_i]['max_overlaps']  
  68.         max_classes = roidb[im_i]['max_classes']  
  69.         roidb[im_i]['bbox_targets'] = \  
  70.                 _compute_targets(rois, max_overlaps, max_classes)  
  71.   
  72.     if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:  
  73.         # Use fixed / precomputed "means" and "stds" instead of empirical values  
  74.         means = np.tile(  
  75.                 np.array(cfg.TRAIN.BBOX_NORMALIZE_MEANS), (num_classes, 1))  
  76.         stds = np.tile(  
  77.                 np.array(cfg.TRAIN.BBOX_NORMALIZE_STDS), (num_classes, 1))  
  78.     else:  
  79.         # Compute values needed for means and stds  
  80.         # var(x) = E(x^2) - E(x)^2  
  81.         class_counts = np.zeros((num_classes, 1)) + cfg.EPS  
  82.         sums = np.zeros((num_classes, 4))  
  83.         squared_sums = np.zeros((num_classes, 4))  
  84.         for im_i in xrange(num_images):  
  85.             targets = roidb[im_i]['bbox_targets']  
  86.             for cls in xrange(1, num_classes):  
  87.                 cls_inds = np.where(targets[:, 0] == cls)[0]  
  88.                 if cls_inds.size > 0:  
  89.                     class_counts[cls] += cls_inds.size  
  90.                     sums[cls, :] += targets[cls_inds, 1:].sum(axis=0)  
  91.                     squared_sums[cls, :] += \  
  92.                             (targets[cls_inds, 1:] ** 2).sum(axis=0)  
  93.   
  94.         means = sums / class_counts  
  95.         stds = np.sqrt(squared_sums / class_counts - means ** 2)  
  96.   
  97.     print 'bbox target means:'  
  98.     print means  
  99.     print means[1:, :].mean(axis=0# ignore bg class  
  100.     print 'bbox target stdevs:'  
  101.     print stds  
  102.     print stds[1:, :].mean(axis=0# ignore bg class  
  103.   
  104.     # Normalize targets  
  105.     if cfg.TRAIN.BBOX_NORMALIZE_TARGETS:  
  106.         print "Normalizing targets"  
  107.         for im_i in xrange(num_images):  
  108.             targets = roidb[im_i]['bbox_targets']  
  109.             for cls in xrange(1, num_classes):  
  110.                 cls_inds = np.where(targets[:, 0] == cls)[0]  
  111.                 roidb[im_i]['bbox_targets'][cls_inds, 1:] -= means[cls, :]  
  112.                 roidb[im_i]['bbox_targets'][cls_inds, 1:] /= stds[cls, :]  
  113.     else:  
  114.         print "NOT normalizing targets"  
  115.   
  116.     # These values will be needed for making predictions  
  117.     # (the predicts will need to be unnormalized and uncentered)  
  118.     return means.ravel(), stds.ravel()  
  119.   
  120. #计算bbox回归时用到的回归目标值,包括region proposal相对grouynd-truth的bbox的  
  121. #坐标偏移量和长宽比例,这四个目标值都经过了归一化处理  
  122. def _compute_targets(rois, overlaps, labels):  
  123.     """Compute bounding-box regression targets for an image."""  
  124.     # Indices of ground-truth ROIs  
  125.     gt_inds = np.where(overlaps == 1)[0]  
  126.     if len(gt_inds) == 0:  
  127.         # Bail if the image has no ground-truth ROIs  
  128. #如果roidb全部是背景,返回0矩阵  
  129.         return np.zeros((rois.shape[0], 5), dtype=np.float32)  
  130.     # Indices of examples for which we try to make predictions  
  131.     ex_inds = np.where(overlaps >= cfg.TRAIN.BBOX_THRESH)[0]  
  132.   
  133.     # Get IoU overlap between each ex ROI and gt ROI  
  134.     ex_gt_overlaps = bbox_overlaps(  
  135.         np.ascontiguousarray(rois[ex_inds, :], dtype=np.float),  
  136.         np.ascontiguousarray(rois[gt_inds, :], dtype=np.float))  
  137.   
  138.     # Find which gt ROI each ex ROI has max overlap with:  
  139.     # this will be the ex ROI's gt target  
  140.     gt_assignment = ex_gt_overlaps.argmax(axis=1)  
  141.     gt_rois = rois[gt_inds[gt_assignment], :]  
  142.     ex_rois = rois[ex_inds, :]  
  143.   
  144.     targets = np.zeros((rois.shape[0], 5), dtype=np.float32)  
  145. #矩阵第一列是类别  
  146.     targets[ex_inds, 0] = labels[ex_inds]  
  147. #后面四列是边框回归目标值  
  148.     targets[ex_inds, 1:] = bbox_transform(ex_rois, gt_rois)  
  149.     return targets  



lib/datasets/imdb.py


  1. #coding:utf-8  
  2. # --------------------------------------------------------  
  3. # Fast R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick  
  7. # --------------------------------------------------------  
  8.   
  9. import os  
  10. import os.path as osp  
  11. import PIL  
  12. from utils.cython_bbox import bbox_overlaps  
  13. import numpy as np  
  14. import scipy.sparse  
  15. from fast_rcnn.config import cfg  
  16.   
  17. class imdb(object):  
  18.     """Image database."""  
  19.   
  20.     def __init__(self, name):  
  21. #imdb的一些属性     
  22.         self._name = name  
  23.         self._num_classes = 0  
  24.         self._classes = []  
  25.         self._image_index = []  
  26.         self._obj_proposer = 'selective_search'  
  27.         self._roidb = None  
  28.         self._roidb_handler = self.default_roidb  
  29.         # Use this dict for storing dataset specific config options  
  30.         self.config = {}  
  31.  
  32.     @property  
  33.     def name(self):  
  34.         return self._name  
  35.  
  36.     @property  
  37.     def num_classes(self):  
  38.         return len(self._classes)  
  39.  
  40.     @property  
  41.     def classes(self):  
  42.         return self._classes  
  43.  
  44.     @property  
  45.     def image_index(self):  
  46.         return self._image_index  
  47.  
  48.     @property  
  49.     def roidb_handler(self):  
  50.         return self._roidb_handler  
  51.  
  52.     @roidb_handler.setter  
  53.     def roidb_handler(self, val):  
  54.         self._roidb_handler = val  
  55.   
  56.     def set_proposal_method(self, method):  
  57.         method = eval('self.' + method + '_roidb')  
  58.         self.roidb_handler = method  
  59.  
  60.     @property  
  61.     def roidb(self):  
  62.         # A roidb is a list of dictionaries, each with the following keys:  
  63.         #   boxes  
  64.         #   gt_overlaps  
  65.         #   gt_classes  
  66.         #   flipped  
  67.         if self._roidb is not None:  
  68.             return self._roidb  
  69.         self._roidb = self.roidb_handler()  
  70.         return self._roidb  
  71.  
  72.     @property  
  73.     def cache_path(self):  
  74.         cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache'))  
  75.         if not os.path.exists(cache_path):  
  76.             os.makedirs(cache_path)  
  77.         return cache_path  
  78.  
  79.     @property  
  80.     def num_images(self):  
  81.       return len(self.image_index)  
  82.   
  83.     def image_path_at(self, i):  
  84.         raise NotImplementedError  
  85.   
  86.     def default_roidb(self):  
  87.         raise NotImplementedError  
  88.   
  89.     def evaluate_detections(self, all_boxes, output_dir=None):  
  90.         """ 
  91.         all_boxes is a list of length number-of-classes. 
  92.         Each list element is a list of length number-of-images. 
  93.         Each of those list elements is either an empty list [] 
  94.         or a numpy array of detection. 
  95.  
  96.         all_boxes[class][image] = [] or np.array of shape #dets x 5 
  97.         """  
  98.         raise NotImplementedError  
  99.   
  100.     def _get_widths(self):  
  101.       return [PIL.Image.open(self.image_path_at(i)).size[0]  
  102.               for i in xrange(self.num_images)]  
  103. #对所有原始图像进行水平翻转  
  104.     def append_flipped_images(self):  
  105.         num_images = self.num_images  
  106. #得到所有图像的宽度存入list  
  107.         widths = self._get_widths()  
  108.         for i in xrange(num_images):  
  109. #复制每张图中所有的box坐标,这个boxes是一个列表,类似[(x1min,y1min,x1max,y1max),]  
  110.             boxes = self.roidb[i]['boxes'].copy()  
  111.             oldx1 = boxes[:, 0].copy()  
  112.             oldx2 = boxes[:, 2].copy()  
  113. #水平翻转只用对横坐标做变换,容易得到x'=width-x  
  114.             boxes[:, 0] = widths[i] - oldx2 - 1  
  115.             boxes[:, 2] = widths[i] - oldx1 - 1  
  116.             assert (boxes[:, 2] >= boxes[:, 0]).all()  
  117. #entry是一个字典,存了boxes坐标,和ground-truth的重叠率,类别,是否由水平翻转得到,  
  118. #重叠率和类别不会变,直接复制i  
  119.             entry = {'boxes' : boxes,  
  120.                      'gt_overlaps' : self.roidb[i]['gt_overlaps'],  
  121.                      'gt_classes' : self.roidb[i]['gt_classes'],  
  122.                      'flipped' : True}  
  123. #把水平翻转得到的数据加入roidb  
  124.             self.roidb.append(entry)  
  125. #数量变为原来的2倍  
  126.         self._image_index = self._image_index * 2  
  127.   
  128.     def evaluate_recall(self, candidate_boxes=None, thresholds=None,  
  129.                         area='all', limit=None):  
  130.         """Evaluate detection proposal recall metrics. 
  131.  
  132.         Returns: 
  133.             results: dictionary of results with keys 
  134.                 'ar': average recall 
  135.                 'recalls': vector recalls at each IoU overlap threshold 
  136.                 'thresholds': vector of IoU overlap thresholds 
  137.                 'gt_overlaps': vector of all ground-truth overlaps 
  138.         """  
  139.         # Record max overlap value for each gt box  
  140.         # Return vector of overlap values  
  141.         areas = { 'all'0'small'1'medium'2'large'3,  
  142.                   '96-128'4'128-256'5'256-512'6'512-inf'7}  
  143.         area_ranges = [ [0**21e5**2],    # all  
  144.                         [0**232**2],     # small  
  145.                         [32**296**2],    # medium  
  146.                         [96**21e5**2],   # large  
  147.                         [96**2128**2],   # 96-128  
  148.                         [128**2256**2],  # 128-256  
  149.                         [256**2512**2],  # 256-512  
  150.                         [512**21e5**2],  # 512-inf  
  151.                       ]  
  152.         assert areas.has_key(area), 'unknown area range: {}'.format(area)  
  153.         area_range = area_ranges[areas[area]]  
  154.         gt_overlaps = np.zeros(0)  
  155.         num_pos = 0  
  156.         for i in xrange(self.num_images):  
  157.             # Checking for max_overlaps == 1 avoids including crowd annotations  
  158.             # (...pretty hacking :/)  
  159.             max_gt_overlaps = self.roidb[i]['gt_overlaps'].toarray().max(axis=1)  
  160.             gt_inds = np.where((self.roidb[i]['gt_classes'] > 0) &  
  161.                                (max_gt_overlaps == 1))[0]  
  162.             gt_boxes = self.roidb[i]['boxes'][gt_inds, :]  
  163.             gt_areas = self.roidb[i]['seg_areas'][gt_inds]  
  164.             valid_gt_inds = np.where((gt_areas >= area_range[0]) &  
  165.                                      (gt_areas <= area_range[1]))[0]  
  166.             gt_boxes = gt_boxes[valid_gt_inds, :]  
  167.             num_pos += len(valid_gt_inds)  
  168.   
  169.             if candidate_boxes is None:  
  170.                 # If candidate_boxes is not supplied, the default is to use the  
  171.                 # non-ground-truth boxes from this roidb  
  172.                 non_gt_inds = np.where(self.roidb[i]['gt_classes'] == 0)[0]  
  173.                 boxes = self.roidb[i]['boxes'][non_gt_inds, :]  
  174.             else:  
  175.                 boxes = candidate_boxes[i]  
  176.             if boxes.shape[0] == 0:  
  177.                 continue  
  178.             if limit is not None and boxes.shape[0] > limit:  
  179.                 boxes = boxes[:limit, :]  
  180.   
  181.             overlaps = bbox_overlaps(boxes.astype(np.float),  
  182.                                      gt_boxes.astype(np.float))  
  183.   
  184.             _gt_overlaps = np.zeros((gt_boxes.shape[0]))  
  185.             for j in xrange(gt_boxes.shape[0]):  
  186.                 # find which proposal box maximally covers each gt box  
  187.                 argmax_overlaps = overlaps.argmax(axis=0)  
  188.                 # and get the iou amount of coverage for each gt box  
  189.                 max_overlaps = overlaps.max(axis=0)  
  190.                 # find which gt box is 'best' covered (i.e. 'best' = most iou)  
  191.                 gt_ind = max_overlaps.argmax()  
  192.                 gt_ovr = max_overlaps.max()  
  193.                 assert(gt_ovr >= 0)  
  194.                 # find the proposal box that covers the best covered gt box  
  195.                 box_ind = argmax_overlaps[gt_ind]  
  196.                 # record the iou coverage of this gt box  
  197.                 _gt_overlaps[j] = overlaps[box_ind, gt_ind]  
  198.                 assert(_gt_overlaps[j] == gt_ovr)  
  199.                 # mark the proposal box and the gt box as used  
  200.                 overlaps[box_ind, :] = -1  
  201.                 overlaps[:, gt_ind] = -1  
  202.             # append recorded iou coverage level  
  203.             gt_overlaps = np.hstack((gt_overlaps, _gt_overlaps))  
  204.   
  205.         gt_overlaps = np.sort(gt_overlaps)  
  206.         if thresholds is None:  
  207.             step = 0.05  
  208.             thresholds = np.arange(0.50.95 + 1e-5, step)  
  209.         recalls = np.zeros_like(thresholds)  
  210.         # compute recall for each iou threshold  
  211.         for i, t in enumerate(thresholds):  
  212.             recalls[i] = (gt_overlaps >= t).sum() / float(num_pos)  
  213.         # ar = 2 * np.trapz(recalls, thresholds)  
  214.         ar = recalls.mean()  
  215.         return {'ar': ar, 'recalls': recalls, 'thresholds': thresholds,  
  216.                 'gt_overlaps': gt_overlaps}  
  217.   
  218.     def create_roidb_from_box_list(self, box_list, gt_roidb):  
  219. #断言box_list的数目和图像数目一样,这里box_list[i]里存的是相应第i张图像里所有的bbox的坐标  
  220.         assert len(box_list) == self.num_images, \  
  221.                 'Number of boxes must match number of ground-truth images'  
  222. #重要,roidb是一个列表,列表中的每一个元素是字典,存储了每张图象的boxes,gt_classes,'gt_overlaps',是否flipped信息  
  223.         roidb = []  
  224.         for i in xrange(self.num_images):  
  225.             boxes = box_list[i]  
  226.             num_boxes = boxes.shape[0]  
  227. #计算每个box和每一类目标的重叠率  
  228.             overlaps = np.zeros((num_boxes, self.num_classes), dtype=np.float32)  
  229.   
  230.             if gt_roidb is not None and gt_roidb[i]['boxes'].size > 0:  
  231. #取得ground-truth里bbox的坐标  
  232.                 gt_boxes = gt_roidb[i]['boxes']  
  233. #取得每个bbox对应的类别  
  234.                 gt_classes = gt_roidb[i]['gt_classes']  
  235. #计算roidb的bbox与ground-truth的bbox的重叠率  
  236.                 gt_overlaps = bbox_overlaps(boxes.astype(np.float),  
  237.                                             gt_boxes.astype(np.float))  
  238. #与那一类的重叠率最大  
  239.                 argmaxes = gt_overlaps.argmax(axis=1)  
  240.                 maxes = gt_overlaps.max(axis=1)  
  241.                 I = np.where(maxes > 0)[0]  
  242.                 overlaps[I, gt_classes[argmaxes[I]]] = maxes[I]  
  243.   
  244.             overlaps = scipy.sparse.csr_matrix(overlaps)  
  245.             roidb.append({  
  246.                 'boxes' : boxes,  
  247.                 'gt_classes' : np.zeros((num_boxes,), dtype=np.int32),  
  248.                 'gt_overlaps' : overlaps,  
  249.                 'flipped' : False,  
  250.                 'seg_areas' : np.zeros((num_boxes,), dtype=np.float32),  
  251.             })  
  252.         return roidb  
  253.  
  254.     @staticmethod  
  255.     def merge_roidbs(a, b):  
  256.         assert len(a) == len(b)  
  257.         for i in xrange(len(a)):  
  258.             a[i]['boxes'] = np.vstack((a[i]['boxes'], b[i]['boxes']))  
  259.             a[i]['gt_classes'] = np.hstack((a[i]['gt_classes'],  
  260.                                             b[i]['gt_classes']))  
  261.             a[i]['gt_overlaps'] = scipy.sparse.vstack([a[i]['gt_overlaps'],  
  262.                                                        b[i]['gt_overlaps']])  
  263.             a[i]['seg_areas'] = np.hstack((a[i]['seg_areas'],  
  264.                                            b[i]['seg_areas']))  
  265.         return a  
  266.   
  267.     def competition_mode(self, on):  
  268.         """Turn competition mode on or off."""  
  269.         pass  


lib/datasets/pascal_voc.py


  1. #CODING:UTF-8  
  2. # --------------------------------------------------------  
  3. # Fast R-CNN  
  4. # Copyright (c) 2015 Microsoft  
  5. # Licensed under The MIT License [see LICENSE for details]  
  6. # Written by Ross Girshick  
  7. # --------------------------------------------------------  
  8.   
  9. import os  
  10. from datasets.imdb import imdb  
  11. import datasets.ds_utils as ds_utils  
  12. import xml.etree.ElementTree as ET  
  13. import numpy as np  
  14. import scipy.sparse  
  15. import scipy.io as sio  
  16. import utils.cython_bbox  
  17. import cPickle  
  18. import subprocess  
  19. import uuid  
  20. from voc_eval import voc_eval  
  21. from fast_rcnn.config import cfg  
  22.   
  23. class pascal_voc(imdb):  
  24.     def __init__(self, image_set, year, devkit_path=None):  
  25.         imdb.__init__(self'voc_' + year + '_' + image_set)  
  26.         self._year = year  
  27.         self._image_set = image_set  
  28.         self._devkit_path = self._get_default_path() if devkit_path is None \  
  29.                             else devkit_path  
  30.         self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)  
  31.         self._classes = ('__background__'# always index 0  
  32.                          'aeroplane''bicycle''bird''boat',  
  33.                          'bottle''bus''car''cat''chair',  
  34.                          'cow''diningtable''dog''horse',  
  35.                          'motorbike''person''pottedplant',  
  36.                          'sheep''sofa''train''tvmonitor')  
  37.         self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))  
  38.         self._image_ext = '.jpg'  
  39.         self._image_index = self._load_image_set_index()  
  40.         # Default to roidb handler  
  41.         self._roidb_handler = self.selective_search_roidb  
  42.         self._salt = str(uuid.uuid4())  
  43.         self._comp_id = 'comp4'  
  44.   
  45.         # PASCAL specific config options  
  46.         self.config = {'cleanup'     : True,  
  47.                        'use_salt'    : True,  
  48.                        'use_diff'    : False,  
  49.                        'matlab_eval' : False,  
  50.                        'rpn_file'    : None,  
  51.                        'min_size'    : 2}  
  52.   
  53.         assert os.path.exists(self._devkit_path), \  
  54.                 'VOCdevkit path does not exist: {}'.format(self._devkit_path)  
  55.         assert os.path.exists(self._data_path), \  
  56.                 'Path does not exist: {}'.format(self._data_path)  
  57.   
  58.     def image_path_at(self, i):  
  59.         """ 
  60.         Return the absolute path to image i in the image sequence. 
  61.         """  
  62.         return self.image_path_from_index(self._image_index[i])  
  63.   
  64.     def image_path_from_index(self, index):  
  65.         """ 
  66.         Construct an image path from the image's "index" identifier. 
  67.         """  
  68.         image_path = os.path.join(self._data_path, 'JPEGImages',  
  69.                                   index + self._image_ext)  
  70.         assert os.path.exists(image_path), \  
  71.                 'Path does not exist: {}'.format(image_path)  
  72.         return image_path  
  73.   
  74.     def _load_image_set_index(self):  
  75.         """ 
  76.         Load the indexes listed in this dataset's image set file. 
  77.         """  
  78.         # Example path to image set file:  
  79.         # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt  
  80.         image_set_file = os.path.join(self._data_path, 'ImageSets''Main',  
  81.                                       self._image_set + '.txt')  
  82.         assert os.path.exists(image_set_file), \  
  83.                 'Path does not exist: {}'.format(image_set_file)  
  84.         with open(image_set_file) as f:  
  85.             image_index = [x.strip() for x in f.readlines()]  
  86.         return image_index  
  87.   
  88.     def _get_default_path(self):  
  89.         """ 
  90.         Return the default path where PASCAL VOC is expected to be installed. 
  91.         """  
  92.         return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)  
  93.   
  94.     def gt_roidb(self):  
  95.         """ 
  96.         Return the database of ground-truth regions of interest. 
  97.  
  98.         This function loads/saves from/to a cache file to speed up future calls. 
  99.         """  
  100.         cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')  
  101.         if os.path.exists(cache_file):  
  102.             with open(cache_file, 'rb') as fid:  
  103.                 roidb = cPickle.load(fid)  
  104.             print '{} gt roidb loaded from {}'.format(self.name, cache_file)  
  105.             return roidb  
  106.   
  107.         gt_roidb = [self._load_pascal_annotation(index)  
  108.                     for index in self.image_index]  
  109.         with open(cache_file, 'wb') as fid:  
  110.             cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)  
  111.         print 'wrote gt roidb to {}'.format(cache_file)  
  112.   
  113.         return gt_roidb  
  114.   
  115.     def selective_search_roidb(self):  
  116.         """ 
  117.         Return the database of selective search regions of interest. 
  118.         Ground-truth ROIs are also included. 
  119.  
  120.         This function loads/saves from/to a cache file to speed up future calls. 
  121.         """  
  122.         cache_file = os.path.join(self.cache_path,  
  123.                                   self.name + '_selective_search_roidb.pkl')  
  124.   
  125.         if os.path.exists(cache_file):  
  126.             with open(cache_file, 'rb') as fid:  
  127.                 roidb = cPickle.load(fid)  
  128.             print '{} ss roidb loaded from {}'.format(self.name, cache_file)  
  129.             return roidb  
  130.   
  131.         if int(self._year) == 2007 or self._image_set != 'test':  
  132.             gt_roidb = self.gt_roidb()  
  133.             ss_roidb = self._load_selective_search_roidb(gt_roidb)  
  134.             roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)  
  135.         else:  
  136.             roidb = self._load_selective_search_roidb(None)  
  137.         with open(cache_file, 'wb') as fid:  
  138.             cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)  
  139.         print 'wrote ss roidb to {}'.format(cache_file)  
  140.   
  141.         return roidb  
  142.   
  143.     def rpn_roidb(self):  
  144.         if int(self._year) == 2007 or self._image_set != 'test':  
  145.             gt_roidb = self.gt_roidb()  
  146.             rpn_roidb = self._load_rpn_roidb(gt_roidb)  
  147.             roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)  
  148.         else:  
  149.             roidb = self._load_rpn_roidb(None)  
  150.   
  151.         return roidb  
  152.   
  153.     def _load_rpn_roidb(self, gt_roidb):  
  154.         filename = self.config['rpn_file']  
  155.         print 'loading {}'.format(filename)  
  156.         assert os.path.exists(filename), \  
  157.                'rpn data not found at: {}'.format(filename)  
  158. #得到rpn产生的box  
  159.         with open(filename, 'rb') as f:  
  160.             box_list = cPickle.load(f)  
  161. #由box_list产生roidb  
  162.         return self.create_roidb_from_box_list(box_list, gt_roidb)  
  163.   
  164.     def _load_selective_search_roidb(self, gt_roidb):  
  165.         filename = os.path.abspath(os.path.join(cfg.DATA_DIR,  
  166.                                                 'selective_search_data',  
  167.                                                 self.name + '.mat'))  
  168.         assert os.path.exists(filename), \  
  169.                'Selective search data not found at: {}'.format(filename)  
  170.         raw_data = sio.loadmat(filename)['boxes'].ravel()  
  171.   
  172.         box_list = []  
  173.         for i in xrange(raw_data.shape[0]):  
  174.             boxes = raw_data[i][:, (1032)] - 1  
  175.             keep = ds_utils.unique_boxes(boxes)  
  176.             boxes = boxes[keep, :]  
  177.             keep = ds_utils.filter_small_boxes(boxes, self.config['min_size'])  
  178.             boxes = boxes[keep, :]  
  179.             box_list.append(boxes)  
  180.   
  181.         return self.create_roidb_from_box_list(box_list, gt_roidb)  
  182.   
  183.     def _load_pascal_annotation(self, index):  
  184.         """ 
  185.         Load image and bounding boxes info from XML file in the PASCAL VOC 
  186.         format. 
  187.         """  
  188. #xml文件名  
  189.         filename = os.path.join(self._data_path, 'Annotations', index + '.xml')  
  190. #解析xml  
  191.         tree = ET.parse(filename)  
  192. #找到所有object属性项  
  193.         objs = tree.findall('object')  
  194.         if not self.config['use_diff']:  
  195.             # Exclude the samples labeled as difficult  
  196.             non_diff_objs = [  
  197.                 obj for obj in objs if int(obj.find('difficult').text) == 0]  
  198.             # if len(non_diff_objs) != len(objs):  
  199.             #     print 'Removed {} difficult objects'.format(  
  200.             #         len(objs) - len(non_diff_objs))  
  201.             objs = non_diff_objs  
  202.         num_objs = len(objs)  
  203. #boxes存储这张图片上所有bbox的坐标  
  204.         boxes = np.zeros((num_objs, 4), dtype=np.uint16)  
  205. #gt_classes存储每个bbox的类别  
  206.         gt_classes = np.zeros((num_objs), dtype=np.int32)  
  207.         overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)  
  208.         # "Seg" area for pascal is just the box area  
  209.         seg_areas = np.zeros((num_objs), dtype=np.float32)  
  210.   
  211.         # Load object bounding boxes into a data frame.  
  212.         for ix, obj in enumerate(objs):  
  213.             bbox = obj.find('bndbox')  
  214.             # Make pixel indexes 0-based  
  215.             x1 = float(bbox.find('xmin').text) - 1  
  216.             y1 = float(bbox.find('ymin').text) - 1  
  217.             x2 = float(bbox.find('xmax').text) - 1  
  218.             y2 = float(bbox.find('ymax').text) - 1  
  219. #从类别名字映射到ID  
  220.             cls = self._class_to_ind[obj.find('name').text.lower().strip()]  
  221.             boxes[ix, :] = [x1, y1, x2, y2]  
  222.             gt_classes[ix] = cls  
  223. #因为是groud-truth,所以重叠率设置为1  
  224.             overlaps[ix, cls] = 1.0  
  225.             seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)  
  226.   
  227.         overlaps = scipy.sparse.csr_matrix(overlaps)  
  228. #返回一个字典  
  229.         return {'boxes' : boxes,  
  230.                 'gt_classes': gt_classes,  
  231.                 'gt_overlaps' : overlaps,  
  232.                 'flipped' : False,  
  233.                 'seg_areas' : seg_areas}  
  234.   
  235.     def _get_comp_id(self):  
  236.         comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']  
  237.             else self._comp_id)  
  238.         return comp_id  
  239.   
  240.     def _get_voc_results_file_template(self):  
  241.         # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt  
  242.         filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'  
  243.         path = os.path.join(  
  244.             self._devkit_path,  
  245.             'results',  
  246.             'VOC' + self._year,  
  247.             'Main',  
  248.             filename)  
  249.         return path  
  250.   
  251.     def _write_voc_results_file(self, all_boxes):  
  252.         for cls_ind, cls in enumerate(self.classes):  
  253.             if cls == '__background__':  
  254.                 continue  
  255.             print 'Writing {} VOC results file'.format(cls)  
  256.             filename = self._get_voc_results_file_template().format(cls)  
  257.             with open(filename, 'wt') as f:  
  258.                 for im_ind, index in enumerate(self.image_index):  
  259.                     dets = all_boxes[cls_ind][im_ind]  
  260.                     if dets == []:  
  261.                         continue  
  262.                     # the VOCdevkit expects 1-based indices  
  263.                     for k in xrange(dets.shape[0]):  
  264.                         f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.  
  265.                                 format(index, dets[k, -1],  
  266.                                        dets[k, 0] + 1, dets[k, 1] + 1,  
  267.                                        dets[k, 2] + 1, dets[k, 3] + 1))  
  268.   
  269.     def _do_python_eval(self, output_dir = 'output'):  
  270.         annopath = os.path.join(  
  271.             self._devkit_path,  
  272.             'VOC' + self._year,  
  273.             'Annotations',  
  274.             '{:s}.xml')  
  275.         imagesetfile = os.path.join(  
  276.             self._devkit_path,  
  277.             'VOC' + self._year,  
  278.             'ImageSets',  
  279.             'Main',  
  280.             self._image_set + '.txt')  
  281.         cachedir = os.path.join(self._devkit_path, 'annotations_cache')  
  282.         aps = []  
  283.         # The PASCAL VOC metric changed in 2010  
  284.         use_07_metric = True if int(self._year) < 2010 else False  
  285.         print 'VOC07 metric? ' + ('Yes' if use_07_metric else 'No')  
  286.         if not os.path.isdir(output_dir):  
  287.             os.mkdir(output_dir)  
  288.         for i, cls in enumerate(self._classes):  
  289.             if cls == '__background__':  
  290.                 continue  
  291.             filename = self._get_voc_results_file_template().format(cls)  
  292.             rec, prec, ap = voc_eval(  
  293.                 filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,  
  294.                 use_07_metric=use_07_metric)  
  295.             aps += [ap]  
  296.             print('AP for {} = {:.4f}'.format(cls, ap))  
  297.             with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f:  
  298.                 cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)  
  299.         print('Mean AP = {:.4f}'.format(np.mean(aps)))  
  300.         print('~~~~~~~~')  
  301.         print('Results:')  
  302.         for ap in aps:  
  303.             print('{:.3f}'.format(ap))  
  304.         print('{:.3f}'.format(np.mean(aps)))  
  305.         print('~~~~~~~~')  
  306.         print('')  
  307.         print('--------------------------------------------------------------')  
  308.         print('Results computed with the **unofficial** Python eval code.')  
  309.         print('Results should be very close to the official MATLAB eval code.')  
  310.         print('Recompute with `./tools/reval.py --matlab ...` for your paper.')  
  311.         print('-- Thanks, The Management')  
  312.         print('--------------------------------------------------------------')  
  313.   
  314.     def _do_matlab_eval(self, output_dir='output'):  
  315.         print '-----------------------------------------------------'  
  316.         print 'Computing results with the official MATLAB eval code.'  
  317.         print '-----------------------------------------------------'  
  318.         path = os.path.join(cfg.ROOT_DIR, 'lib''datasets',  
  319.                             'VOCdevkit-matlab-wrapper')  
  320.         cmd = 'cd {} && '.format(path)  
  321.         cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)  
  322.         cmd += '-r "dbstop if error; '  
  323.         cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \  
  324.                .format(self._devkit_path, self._get_comp_id(),  
  325.                        self._image_set, output_dir)  
  326.         print('Running:\n{}'.format(cmd))  
  327.         status = subprocess.call(cmd, shell=True)  
  328.   
  329.     def evaluate_detections(self, all_boxes, output_dir):  
  330.         self._write_voc_results_file(all_boxes)  
  331.         self._do_python_eval(output_dir)  
  332.         if self.config['matlab_eval']:  
  333.             self._do_matlab_eval(output_dir)  
  334.         if self.config['cleanup']:  
  335.             for cls in self._classes:  
  336.                 if cls == '__background__':  
  337.                     continue  
  338.                 filename = self._get_voc_results_file_template().format(cls)  
  339.                 os.remove(filename)  
  340.   
  341.     def competition_mode(self, on):  
  342.         if on:  
  343.             self.config['use_salt'] = False  
  344.             self.config['cleanup'] = False  
  345.         else:  
  346.             self.config['use_salt'] = True  
  347.             self.config['cleanup'] = True  
  348.   
  349. if __name__ == '__main__':  
  350.     from datasets.pascal_voc import pascal_voc  
  351.     d = pascal_voc('trainval''2007')  
  352.     res = d.roidb  
  353.     from IPython import embed; embed()  



http://blog.csdn.net/iamzhangzhuping/article/category/6230157

http://blog.csdn.net/u010668907/article/category/6237110

Faster RCNN代码理解(Python)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值