Faster R-CNN代码学习(一)——datasets模块
源代码:https://github.com/smallcorgi/Faster-RCNN_TF
datasets模块在lib文件夹下,负责的是对数据集的操作,包含数据集对象的创建、载入过程,这一部分在训练自己的数据集时需要重点修改。
datasets模块主要包含3个py文件,分别为所有数据集类的父类imdb.py;根据数据集特有创建的以VOC为例,pascal_voc.py;用于迅速创建数据集对象的factory.py。
下面依次进行介绍。
imdb.py
- imdb为所有数据集的父类,因此包含了所有数据集共有的属性。
class imdb(object, name):
def __init__(self, name)
self._name = name
self._classes = []
self._num_classes = []
self._image_index = []
self._roidb = None
self._roidb_handler = self.default_roidb # 是一个指针,指向不同的roi生成函数
self.config = {
}
对于每一个数据集,其共有的属性都包含数据集名称name、数据集里有的类classes、数据集的图片样本image_index、数据集中的roi集合以及相关的设置config。
- 由于这些是私有属性,那么需要通过
装饰器property
将其取出,因此下面代码的主要内容为get这些属性。
@property
def name(self):
return self._name
@property
def classes(self):
return self._classes
@property
def num_classes(self):
return len(self._classes)
@property
def image_index(self):
return self._image_index
@property
def num_images(self):
return len(self.image_index)
@property
def roidb_handler(self):
return self._roidb_handler
@roidb_handler.setter
def roidb_handler(self, val):
self._roidb_handler = val
@property
def roidb(self):
# 如果已经有了,那么直接返回,没有就通过指针指向的函数生成
if self._roidb is not None:
return self._roidb
self._roidb = self.roidb_handler()
return self._roidb
# cache_path用来生成roidb缓存文件的文件夹,用来存储数据集的roi
@property
def cache_path(self):
cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache'))
if not osp.exists(cache_path):
os.makedirs(cache_path)
return cache_path
- 部分方法需要依靠具体的数据集及相应路径来制定,因此仅声明接口:
def default_roidb(self):
raise NotImplementedError
def image_path_at(self):
raise NotImplementedError
- 数据集的共有方法:数据翻转扩增、recall指标评估、通过提供的Box_list创建roidb
# 在数据扩增前需要获取每张图片的width,这里引入了python通用的图片处理扩展包PIL
def _get_width(self):
return [PIL.Image.open(self.image_path_at(i)).size[0]
for i in range(self.num_images)]
# 这里对所有的图片进行数据扩增,这部分roidb的属性仅改变了相应x坐标及flipped
def append_flipped_images(self):
num_images = self.num_images
widths = self._get_width()
for i in range(num_images):
boxes = self.roidb[i].boxes.copy()
oldx1 = boxes[:, 0].copy()
oldx2 = boxes[:, 2].copy()
boxes[:, 0] = widths[i] - oldx2 - 1
boxes[:, 2] = widths[i] - oldx1 - 1
assert (boxes[:, 2] >= boxes[:, 0]).all()
entry = {
'boxes': boxes,
'gt_overlaps': self.roidb[i]['gt_overlaps'],
'gt_classes': self.roidb[i]['gt_classes'],
'flipped': True}
self.roidb.append(entry)
self._image_index = self._image_index * 2
recall指标评估是根据候选框来确定候选框的recall值
def evaluate_racall(self, candidate_boxes=None, thresholds=None,
area='