我用的Faster-RCNN是tensorflow版本,fork自githubFaster-RCNN_TF
参考博客http://www.cnblogs.com/CarryPotMan/p/5390336.html
在用自己的数据训练Faster-RCNN,tensorflow版本(一)中我们详细介绍了Faster-rcnn_TF中pascal_voc数据的读写接口,接下来介绍一下,如何编写自己的数据读写接口。
3、编写自己的数据读写接口
我们要用自己的数据进行训练,就得编写自己数据的读写接口,下面参考pascal_voc.py来编写。根据用自己的数据训练Faster-RCNN,tensorflow版本(一)中对pascal_voc.py文件的分析,发现,pascal_voc.py用了非常多的路径拼接,很麻烦,我们不用这么麻烦,简单一点就可以。
3.1、介绍一下我自己的训练数据集格式
我主要是从自然图片中检测出文本,因此我只有background 和text两类物体,我并没有像pascal_voc数据集里面一样每个图像用一个xml来标注,先说一下我的数据格式:
所有需要用到的数据我都放在了目录Data/ID_card/下面。
目录Data/ID_card/下面包含2个文件夹,分别是train,test。
先介绍train,目录Data/ID_card/train/里面包含:
1、所有的训练图片
2、gt_ID_card.txt
3、train.txt
我把train集合中所有图片的gt,集中放在了一个gt_ID_card.txt文件里面,gt_ID_card.txt格式如下:
以第一行为例:
ID_card/back_1.jpg: 是图片的名字;
数字1:代表该张图片上只有一个文本(text);
后面的四个数值:分别是文本框左上角和右下角的坐标。我的图片里面只有一行文本,所以只有一组文本框的坐标。
train.txt文件存放的是所有图片的名字,没有后缀,如下图:
3.2、编写自己的数据读写接口ID_card.py
主要修改的关键函数就是:def _load_annotation(self)——读取图片gt。
编写自己的数据读写接口ID_card.py,内容如下:
#coding:utf-8
# --------------------------------------------------------
#
# Written by lisiqi
# --------------------------------------------------------
import datasets
import os
import datasets.imdb
import xml.dom.minidom as minidom
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import cPickle
import subprocess
class ID_card(datasets.imdb):
def __init__(self, image_set, data_path=None):
datasets.imdb.__init__(self, 'ID_card_' + image_set) #image_set 为train或者val或者trainval或者test。
self._image_set = image_set # image_set以train为例
self._data_path = data_path # 数据所在的路径,根据传进来的参数data_path而定。传进来的参数data_path在我这里就是Data/ID_card/
self._classes = ('__background__','text') #object的类别,只有两类:背景和文本
self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) #构成字典{'__background__':'0','text':'1'}
self._image_ext = '.jpg' #图片后缀
self._image_index = self._load_image_set_index() #读取train.txt,获取图片名称(该图片名称没有后缀.jpg)
# Default to roidb handler
self._roidb_handler = self.gt_roidb #获取图片的gt
# PASCAL specific config options
self.config = {
'cleanup' : True,
'use_salt' : True,
'top_k' : 2000}
assert os.path.exists(self._data_path), \ #如果路径Data/ID_card不存在,退出
'Image Path does not exist: {}'.format(self._data_path)
def image_path_at(self, i):#获得_image_index 下标为i的图像的路径
"""
Return the absolute path to image i in the image sequence.
"""
return self.image_path_from_index(self._image_index[i])
def image_path_from_index(self, index):#根据_image_index获取图像路径
"""
Construct an image path from the image's "index" identifier.
"""
image_path = os.path.join(self._data_path, index, self._image_ext)
assert os.path.exists(image_path), \
'Path does not exist: {}'.format(image_path)
return image_path
def _load_image_set_index(self):#已做修改
"""
Load the indexes listed in this dataset's image set file.
得到图片名称的list。这个list里面是集合self._image_set=train中所有图片的名字(注意,图片名字没有后缀.jpg)
"""
image_set_file = os.path.join(self._data_path, self._image_set, self._image_set + '.txt')
#image_set_file是Data/ID_card/train/train.txt
#之所以要读这个train.txt文件,是因为train.txt文件里面写的是集合train中所有图片的名字(没有后缀.jpg)
assert os.path.exists(image_set_file), \
'Path does not exist: {}'.format(image_set_file)
with open(image_set_file) as f: #读取train.txt,获取图片名称(没有后缀.jpg)
image_index = [x.strip() for x in