FasterRCNN专题:源代码分析1-数据集的加载。

在上一篇博文中,我们对Windows下基于tensorflow的FasterRCNN开发环境搭建进行了介绍,下面从本文开始,我们将对源代码进行详细的介绍,首先将介绍数据处理部分的代码。那么我们自然而然地以train.py函数为入口,可以看到主函数只有两行执行代码:

if __name__ == '__main__':
    train = Train()
    train.train()

创建Train类的实例,之后调用train()函数进行训练。首先来看Train类的初始化函数:

    def __init__(self):
        # Create network
        if cfg.FLAGS.network == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError
        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')

考虑代码cfg.FLAGS.network == 'vgg16',cfg为\lib\config\config.py的实例,config.py主要使用了tensorflow的tf.app.flags模块,用于为Tensorflow程序实现命令行标志,例如操作:

tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.'),

这里第一个参数是参数名称,第二个是参数默认值,第三个是参数描述。在config.py文件中,可以看到FLAGS下定义了很多变量,可以找到cfg.FLAGS.network的定义为:

tf.app.flags.DEFINE_string('network', "vgg16", "The network to be used as backbone")

可以看到这里定义network的string变量初始化为'vgg16',即默认的网络结构,因此如果需要使用其他网络结构我们可以在config.py中修改这句代码。随后执行self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch),这里创建了vgg16类的一个实例net,vgg16网络结构的定义在\lib\nets\vgg16.py,这里不再详述。

考虑代码self.imdb, self.roidb = combined_roidb("voc_2007_trainval"),调用combined_roidb()函数:

def combined_roidb(imdb_names):
    """
    Combine multiple roidbs
    """
    def get_roidb(imdb_name):
        imdb = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(imdb.name))
        imdb.set_proposal_method("gt")
        print('Set proposal method: {:s}'.format("gt"))
        roidb = get_training_roidb(imdb)
        return roidb
    roidbs = [get_roidb(s) for s in imdb_names.split('+')]
    roidb = roidbs[0]
    if len(roidbs) > 1:
        for r in roidbs[1:]:
            roidb.extend(r)
        tmp = get_imdb(imdb_names.split('+')[1])
        imdb = imdb2(imdb_names, tmp.classes)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, roidb

显然voc_2007_trainval是数据集的名称,首先执行代码roidbs = [get_roidb(s) for s in imdb_names.split('+')],实际就是调用了get_roidb("voc_2007_trainval"),进入函数get_roidb,首先执行imdb = get_imdb("voc_2007_trainval"),get_imdb函数的定义在lib\datasets\factory.py中:

for year in ['2007', '2012']:
  for split in ['train', 'val', 'trainval', 'test']:
    name = 'voc_{}_{}'.format(year, split)
    __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))
# Set up coco_2014_<split>
for year in ['2014']:
  for split in ['train', 'val', 'minival', 'valminusminival', 'trainval']:
    name = 'coco_{}_{}'.format(year, split)
    __sets[name] = (lambda split=split, year=year: coco(split, year))
# Set up coco_2015_<split>
for year in ['2015']:
  for split in ['test', 'test-dev']:
    name = 'coco_{}_{}'.format(year, split)
    __sets[name] = (lambda split=split, year=year: coco(split, year))

def get_imdb(name):
  """Get an imdb (image database) by name."""
  if name not in __sets:
    raise KeyError('Unknown dataset: {}'.format(name))
  return __sets[name]()

可以看到首先根据年代(year){2007,2012,2015} 和训练测试类型(split){'train', 'val', 'trainval', 'test'},不同的(split,year)通过lambda匿名调用lib\datasets\coco.py中coco类或者lib\datasets\pascal_voc.py的构造函数,生成coco类或者pascal类实例,最后__sets中的元素为根据不同名字({}_{year}_{split})生成的不同参数的coco类或者pascal_voc构造函数调用,get_imdb函数会返回对应的coco或者pascal_voc实例,对于get_roidb("voc_2007_trainval")可以得到实际返回的是pascal_voc类的构造函数pascal_voc('trainval','2007')初始化的实例,即imdb为pascal_voc类的一个实例。回到combined_roidb函数,接下来执行代码:

print('Loaded dataset `{:s}` for training'.format(imdb.name))

现在我们可以通过一些变量名知道,imdb即图像数据集,roidb即感兴趣区域的数据集,factory即为工厂,它的函数get_imdb根据名字返回图像数据集对应的类的对象。这里imdb_name为“voc_2007_trainval”,我们可以知道get_imdb返回的是pascal_voc (‘trainval’, ‘2007’),而进入pascal_voc文件(lib\datasets\pascal_voc.py):

class pascal_voc(imdb):
    def __init__(self, image_set, year, devkit_path=None):
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path() if devkit_path is None \
            else devkit_path
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__',  # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        print('Step1:')
        self._roidb_handler = self.gt_roidb
        print('Step2:')
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup': True,
                       'use_salt': True,
                       'use_diff': False,
                       'matlab_eval': False,
                       'rpn_file': None}

        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)

我们可以看到pascal_voc是一个继承于imdb类的子类(class pascal_voc(imdb):)在__init__函数中,我们可以看到该类记录了文件路径、目标类别还有图像扩展名等与数据集相关的内容,而父类imdb中则是名字等一般性的内容。由于pascal_voc继承自imdb类,这里的imdb.name其实调用的是基类imdb的成员函数name(),根据之前输入的参数容易得到其值为voc_2007_trainval,故打印输出为:

Loaded dataset `voc_2007_trainval` for training

回到train类的init函数,imdb = get_imdb("voc_2007_trainval")实际执行的是imdb = pascal_voc('trainval','2007'),返回的是pascal_voc类的一个实例。

接下来执行代码:

  1. imdb.set_proposal_method("gt")

imdb.set_proposal_method同样是基类imdb的成员函数,代码如下:

 def set_proposal_method(self, method):
        method = eval('self.' + method + '_roidb')
        self.roidb_handler = method

其中,eval() 函数用来执行一个字符串表达式,并返回表达式的值。故其实执行的是 self.roidb_handler = self.gt_roidb,gt_roidb的定义在继承类pascal_voc中。接着执行代码:

roidb = get_training_roidb(imdb)

函数get_training_roidb定义如下:

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    if True:
        print('Appending horizontally-flipped training examples...')
        imdb.append_flipped_images()
        print('done')

    print('Preparing training data...')
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    return imdb.roidb

这里首先执行 imdb.append_flipped_images()函数,其定义在基类imdb中:

   def append_flipped_images(self):
        num_images = self.num_images
        widths = self._get_widths()
        for i in range(num_images):
            boxes = self.roidb[i]['boxes'].copy()
            oldx1 = boxes[:, 0].copy()
            oldx2 = boxes[:, 2].copy()
            boxes[:, 0] = widths[i] - oldx2 - 1
            boxes[:, 2] = widths[i] - oldx1 - 1
            assert (boxes[:, 2] >= boxes[:, 0]).all()
            entry = {'boxes': boxes,
                     'gt_overlaps': self.roidb[i]['gt_overlaps'],
                     'gt_classes': self.roidb[i]['gt_classes'],
                     'flipped': True}
            self.roidb.append(entry)
        self._image_index = self._image_index * 2

首先看代码:

num_images = self.num_images,这里调用imdb类的成员函数num_images   

def num_images(self):
        return len(self.image_index)

因此,其实际是image_index的长度,而image_index的初始化在类pascal_voc的__init__函数中初始化:

 self._image_index = self._load_image_set_index()

进一步查看函数_load_image_set_index():

   def _load_image_set_index(self):
        """
        Load the indexes listed in this dataset's image set file.
        """
        # Example path to image set file:
        # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
        image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index

由前面构造函数pascal_voc可知,这里image_set_file=’XXX/data/VOCdevkit2007/VOC2007/ImageSets/Main/trainval.txt‘,XXX/data/VOCdevkit2007/VOC2007/ImageSets/Main/有如下几个文件:

内容类似为:

000005
000027
000028
000033
000042
000045
000048
000058

即JPEGImages文件夹下图片的名字(无后缀),test.txt是测试集,train.txt是训练集,val.txt是验证集,trainval.txt是训练和验证集.VOC2007中,trainval大概是整个数据集的50%,test也大概是整个数据集的50%;train大概是trainval的50%,val大概是trainval的50%。因此image_index存储了训练和验证集所有图像的名字,对应pascal的成员变量_image_index,其在构造函数中完成初始化,存储训练和验证集trainval.txt中所有图像的索引,那么num_images为训练验证集图像的数量。 回到append_flipped_images函数,第二行执行:

widths = self._get_widths()

为imdb的成员函数:

    def _get_widths(self):
        return [PIL.Image.open(self.image_path_at(i)).size[0]
                for i in range(self.num_images)]

image_path_at为imdb类中的接口,其实际实现在类pascal_voc中,

    def image_path_at(self, i):
        """
        Return the absolute path to image i in the image sequence.
        """
        return self.image_path_from_index(self._image_index[i])

其中调用函数image_path_from_index:

    def image_path_from_index(self, index):
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, 'JPEGImages',
                                  index + self._image_ext)
        assert os.path.exists(image_path), \
            'Path does not exist: {}'.format(image_path)
        return image_path

显然,image_path_at的作用是返回所有带完整路径的验证图像文件名。PIL.Image.open返回的是PIL.JpegImagePlugin.JpegImageFile类,size[0]对应图像的宽度,故widths存储每个检验集图像的宽度。继续回到append_flipped_images函数,接下来在num_images个循环中依次执行代码:

boxes = self.roidb[i]['boxes'].copy()

这里首先调用imdb的成员函数roidb:

    def roidb(self):
        # A roidb is a list of dictionaries, each with the following keys:
        #   boxes
        #   gt_overlaps
        #   gt_classes
        #   flipped
        if self._roidb is not None:
            return self._roidb
        self._roidb = self.roidb_handler()
        return self._roidb

可以看到第一次调用时会调用roidb_handler函数即pascal_voc的成员函数gt_roidb:

 def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.
        This function loads/saves from/to a cache file to speed up future calls.
        """
        print('here')
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                try:
                    roidb = pickle.load(fid)
                except:
                    roidb = pickle.load(fid, encoding='bytes')
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb
        gt_roidb = [self._load_pascal_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))
        return gt_roidb

cache_file 为“Faster-RCNN-TensorFlow-Python3-master\data\cache\voc_2007_trainval_gt_roidb.pkl”,此文件作用是一个缓存文件,用来加速特征调用,第一次调用时并不存在,因此第一次调用时首先执行

gt_roidb = [self._load_pascal_annotation(index)
                    for index in self.image_index]

_load_pascal_annotation函数的定义如下:

    def _load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            non_diff_objs = [
                obj for obj in objs if int(obj.find('difficult').text) == 0]
            # if len(non_diff_objs) != len(objs):
            #     print 'Removed {} difficult objects'.format(
            #         len(objs) - len(non_diff_objs))
            objs = non_diff_objs
        num_objs = len(objs)
        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)
        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
        overlaps = scipy.sparse.csr_matrix(overlaps)
        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_overlaps': overlaps,
                'flipped': False,
                'seg_areas': seg_areas}

_load_pascal_annotation函数读取每个image_index对应图像的信息,返回一个字典,

字典的内容包括:(num_objs行4列的矩阵,num_objs是图片中的物体数量,4列是每个物体左上角和右下角的坐标),'gt_classes'(包含num_objs个元素的一维矩阵,矩阵中每个元素的取值表示:物体所对应的类别号),'gt_overlaps'(num_objs行21列的矩阵,这个矩阵表示的是:每个物体的box和gt_box的IOU,由于这里本来就只有gt_box,所以IOU为1,因此这个矩阵的每一行中,只有这个物体所对应的类别号所在的列元素为1,其余都为0),'flipped'(有没有翻转,默认没有,所以取值为:False),'seg_areas'(包含num_objs个元素的一维矩阵,矩阵中每个元素的取值表示:这个物体的gt_box的面积)。将每个图像对应的字典信息存入voc_2007_trainval_gt_roidb.pkl文件中,下次直接调用就可以了。gt_roidb函数最终返回所有图像的字典信息,且下次调用不需要再进行读取操作,直接返回图像的字典信息。

    回到append_flipped_images函数,boxes = self.roidb[i]['boxes'].copy()表示第i幅图的目标物体的范围坐标((x1,y1),(x2,y2)),每个图像的信息都存储在List roidb中,即imdb的成员roidb中每个元素对应一个图像的字典(roidb格式:<type 'list'>: [{'boxes': array([[0, 0, 0, 0]], dtype=uint16), 'gt_classes': array([15], dtype=int32), 'gt_overlaps': <1x21 sparse matrix of type '<type 'numpy.float32'>':

  boxes: num_objs,four rows.the proposal.left-up,right-down

  gt_overlaps: len(box)*类别数(即,每个box对应的类别。初始化时,从xml读出来的类别对应类别值是1.0,被压缩保存)

  gt_classes: 每个box的类别索引

  flipped: true,代表图片被水平反转,改变了boxes里第一、三列的值(所有原图都这样的操作,imdb.image_index*2)(cfg.TRAIN.USE_FLIPPED会导致此操作的发生,见train.py 116行)

  seg_areas: box的面积。

也就是说在append_flipped_images函数中,循环for i in range(num_images):在第一次循环时,即执行boxes = self.roidb[0]['boxes'].copy(),imdb的成员roidb存储了所有原图像集的信息字典。随后的代码:

boxes = self.roidb[i]['boxes'].copy()
oldx1 = boxes[:, 0].copy()
oldx2 = boxes[:, 2].copy()
boxes[:, 0] = widths[i] - oldx2 - 1
boxes[:, 2] = widths[i] - oldx1 - 1
assert (boxes[:, 2] >= boxes[:, 0]).all()
entry = {'boxes': boxes,
'gt_overlaps': self.roidb[i]['gt_overlaps'],
'gt_classes': self.roidb[i]['gt_classes'],
'flipped': True}
self.roidb.append(entry)

对每一个图像进行了反转,作为一个新的图像,同样存入信息字典到roidb中。因此,图像集规模变为原来2倍,故最后执行:

self._image_index = self._image_index * 2

接下来执行代码 rdl_roidb.prepare_roidb(imdb),rdl_roidb为roidb(\lid\datasets\roidb.py)的实例,其作用是通过加入一串元数据(metadata)将感兴趣区域数据集(roidb)转换为可训练roidb,代码如下:

def prepare_roidb(imdb):
  """Enrich the imdb's roidb by adding some derived quantities that
  are useful for training. This function precomputes the maximum
  overlap, taken over ground-truth boxes, between each ROI and
  each ground-truth box. The class with maximum overlap is also
  recorded.
  """
  roidb = imdb.roidb
  if not (imdb.name.startswith('coco')):
    sizes = [PIL.Image.open(imdb.image_path_at(i)).size
         for i in range(imdb.num_images)]
  for i in range(len(imdb.image_index)):
    roidb[i]['image'] = imdb.image_path_at(i)
    if not (imdb.name.startswith('coco')):
      roidb[i]['width'] = sizes[i][0]
      roidb[i]['height'] = sizes[i][1]
    # need gt_overlaps as a dense array for argmax
    gt_overlaps = roidb[i]['gt_overlaps'].toarray()
    # max overlap with gt over classes (columns)
    max_overlaps = gt_overlaps.max(axis=1)
    # gt class that had the max overlap
    max_classes = gt_overlaps.argmax(axis=1)
    roidb[i]['max_classes'] = max_classes
    roidb[i]['max_overlaps'] = max_overlaps
    # sanity checks
    # max overlap of 0 => class should be zero (background)
    zero_inds = np.where(max_overlaps == 0)[0]
    assert all(max_classes[zero_inds] == 0)
    # max overlap > 0 => class should not be zero (must be a fg class)
    nonzero_inds = np.where(max_overlaps > 0)[0]
    assert all(max_classes[nonzero_inds] != 0)

前面已经看到了roidb为字典结构其存储了图像集相关信息,可以看到该函数进一步对roidb的图像数据集信息做了处理。

最后回到combined_roidb函数,分别返回imdb(pascal_voc('trainval','2007'))和roidb。对于代码循环部分:

    if len(roidbs) > 1:
        for r in roidbs[1:]:
            roidb.extend(r)
        tmp = get_imdb(imdb_names.split('+')[1])
        imdb = imdb2(imdb_names, tmp.classes)
    else:
        imdb = get_imdb(imdb_names)

表示多个数据集的话,用"+"隔开,依次按上述顺序处理,并依次存入imdb和roidb。至此训练验证集数据处理完成,整个过程是首先根据指定数据集,定义了其对应的数据集处理类,随后数据集处理类读取数据集相关信息,并作了一定的准备工作。

接下来执行:

self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)

RoIDataLayer类的功能比较好理解,构造函数

    def __init__(self, roidb, num_classes, random=False):
        """Set the roidb to be used by this layer during training."""
        self._roidb = roidb
        self._num_classes = num_classes
        # Also set a random flag
        self._random = random
        self._shuffle_roidb_inds()

即将图像字典信息列表roidb赋值给其成员_roidb,分类个数赋值给其成员_num_classes,最后调用_shuffle_roidb_inds随机打乱图像集的顺序,打乱后的顺序存储在成员变量_perm中,函数:

    def _get_next_minibatch_inds(self):
        """Return the roidb indices for the next minibatch."""

        if self._cur + cfg.FLAGS.ims_per_batch >= len(self._roidb):
            self._shuffle_roidb_inds()

        db_inds = self._perm[self._cur:self._cur + cfg.FLAGS.ims_per_batch]
        self._cur += cfg.FLAGS.ims_per_batch

        return db_inds

随后,_get_next_minibatch_inds负责取出下一个训练需要的batch数据对应的roidb索引。

  • 7
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值