1.首先搭建Fast rcnn的环境(可以去网上搜)
2.然后可以看到fast rcnn的工程目录下有个lib目录
该目录下有4个子目录,分别是:
datasets
fast_rcnn
roi_data_layer
utils
要修改读写数据的接口,主要改的是datasets目录;fast_rcnn目录下主要存放Python的训练和测试脚本以及训练的配置文件;roi_data_layer目录下主要是一些ROI处理操作;utils目录下是一些通用操作,比如非极大值抑制(NMS)以及计算bounding box重叠率等常用功能。
使用fast rcnn训练网络主要是构建自己的IMDB子类,修改后的datasets目录下的pascal_voc.py、factory.py、imdb.py、__init__.py依次如下:
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import datasets
import datasets.pascal_voc
import os
import datasets.imdb
import xml.dom.minidom as minidom
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import cPickle
import subprocess
class pascal_voc(datasets.imdb):
    """Image database for a custom ('car', 'bike') detection set.

    Everything is read from a single directory, ``devkit_path``:

    * ``<image_set>.txt`` -- one image path (relative to ``devkit_path``)
      per line; these strings are the image "indexes";
    * ``annotations.txt`` -- one ground-truth box per line, see
      ``_load_pascal_annotation``;
    * the image files themselves.

    Region proposals are read from ``train.mat`` located one directory above
    ``cache_path``.  Ground-truth and merged roidbs are pickled under
    ``cache_path`` using file names derived from ``self.name`` (which is the
    ``image_set`` string), so stale cache files must be deleted by hand
    whenever the underlying data changes.
    """

    def __init__(self, image_set, devkit_path=None, year='2007'):
        """Build the database and load the image index.

        Args:
            image_set: Name of the image set; also used as the imdb name and
                as the stem of the ``<image_set>.txt`` index file.
            devkit_path: Directory holding the index file, the annotation
                file and the images.  Required in practice.
            year: Tag used only to build the ``VOC<year>`` results directory
                and the IJCV proposal directory.  The methods that need it
                previously read ``self._year`` although it was never set.

        Raises:
            ValueError: If ``devkit_path`` is None.
            AssertionError: If a required path does not exist.
        """
        datasets.imdb.__init__(self, image_set)
        if devkit_path is None:
            raise ValueError('devkit_path is required: it is the directory '
                             'holding the images and annotation files')
        self._image_set = image_set
        self._year = year
        self._devkit_path = devkit_path
        self._data_path = os.path.join(self._devkit_path)
        self._classes = ('__background__',  # always index 0
                         'car', 'bike')
        self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
        self._image_ext = '.jpg'
        # Validate the directories before touching any file inside them so
        # that the most specific message is the one reported.
        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
        # PASCAL specific config options
        self.config = {'cleanup': True,
                       'use_salt': True,
                       'top_k': 2000}

    def image_path_at(self, i):
        """Return the path to image ``i`` in the image sequence."""
        return self.image_path_from_index(self._image_index[i])

    def image_path_from_index(self, index):
        """Construct an image path from the image's "index" identifier.

        The index is joined onto the data directory as-is (no extension is
        appended), and the resulting path must exist.
        """
        image_path = os.path.join(self._data_path, index)
        assert os.path.exists(image_path), \
            'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """Load the indexes listed in this dataset's image set file.

        Reads ``<data_path>/<image_set>.txt`` and returns its lines with
        surrounding whitespace stripped.
        """
        image_set_file = os.path.join(self._data_path,
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            return [line.strip() for line in f]

    def gt_roidb(self):
        """Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future
        calls.
        """
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            # NOTE: unpickling executes arbitrary code; only keep cache files
            # that this program wrote itself.
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = self._load_pascal_annotation()
        with open(cache_file, 'wb') as fid:
            cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))
        return gt_roidb

    def selective_search_roidb(self):
        """Return the database of selective search regions of interest.

        Ground-truth ROIs are also included unless the image set is 'test'.
        This function loads/saves from/to a cache file to speed up future
        calls.
        """
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_selective_search_roidb.pkl')
        if os.path.exists(cache_file):
            # NOTE: unpickling executes arbitrary code; only keep cache files
            # that this program wrote itself.
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print('{} ss roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        if self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            ss_roidb = self._load_selective_search_roidb(gt_roidb)
            roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
        else:
            roidb = self._load_selective_search_roidb(None)
        with open(cache_file, 'wb') as fid:
            cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print('wrote ss roidb to {}'.format(cache_file))
        return roidb

    def _load_selective_search_roidb(self, gt_roidb):
        """Build a roidb from the proposals stored in ``train.mat``.

        The file is expected one directory above ``cache_path`` and must
        contain a ``boxes`` variable with one box array per image.
        """
        filename = os.path.abspath(os.path.join(self.cache_path, '..',
                                                'train' + '.mat'))
        assert os.path.exists(filename), \
            'Selective search data not found at: {}'.format(filename)
        raw_data = sio.loadmat(filename)['boxes'].ravel()
        # Swap the column pairs and shift by one -- presumably converting
        # MATLAB's 1-based (y1, x1, y2, x2) rows to 0-based (x1, y1, x2, y2);
        # confirm against the script that produced the .mat file.
        box_list = [boxes[:, (1, 0, 3, 2)] - 1 for boxes in raw_data]
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def selective_search_IJCV_roidb(self):
        """Return the database of selective search (IJCV) regions of interest.

        Ground-truth ROIs are also included.
        This function loads/saves from/to a cache file to speed up future
        calls.
        """
        cache_file = os.path.join(
            self.cache_path,
            '{:s}_selective_search_IJCV_top_{:d}_roidb.pkl'.
            format(self.name, self.config['top_k']))
        if os.path.exists(cache_file):
            # NOTE: unpickling executes arbitrary code; only keep cache files
            # that this program wrote itself.
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print('{} ss roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = self.gt_roidb()
        ss_roidb = self._load_selective_search_IJCV_roidb(gt_roidb)
        roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)
        with open(cache_file, 'wb') as fid:
            cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print('wrote ss roidb to {}'.format(cache_file))
        return roidb

    def _load_selective_search_IJCV_roidb(self, gt_roidb):
        """Build a roidb from per-image IJCV ``.mat`` proposal files.

        Only the first ``config['top_k']`` boxes of every image are kept.
        """
        IJCV_path = os.path.abspath(os.path.join(self.cache_path, '..',
                                                 'selective_search_IJCV_data',
                                                 'voc_' + self._year))
        assert os.path.exists(IJCV_path), \
            'Selective search IJCV data not found at: {}'.format(IJCV_path)

        top_k = self.config['top_k']
        box_list = []
        for index in self.image_index:
            filename = os.path.join(IJCV_path, index + '.mat')
            raw_data = sio.loadmat(filename)
            box_list.append(
                (raw_data['boxes'][:top_k, :] - 1).astype(np.uint16))
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_pascal_annotation(self):
        """Load ground-truth boxes from ``<data_path>/annotations.txt``.

        Despite the name (kept from the stock Fast R-CNN code) no XML is
        parsed.  Every non-blank line holds whitespace-separated fields:
        field 1 is the class name (a key of ``_class_to_ind``) and fields
        2-5 are ``x1 y1 x2 y2``.  Field 0 is not read here (presumably the
        image name -- confirm against the annotation file).  Each line yields
        one roidb entry; the entries are later paired with images by
        position, so the line order is assumed to match the image index.

        Returns:
            A list of dicts with keys 'boxes', 'gt_classes', 'gt_overlaps'
            and 'flipped'.
        """
        boxes_per_line = 1  # the format carries exactly one box per line
        annotation_file = os.path.join(self._data_path, 'annotations.txt')
        gt_roidb = []
        # ``with`` guarantees the file is closed even if a line is malformed.
        with open(annotation_file) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # A blank line used to end parsing silently and drop
                    # every annotation after it; skip it instead.
                    continue
                boxes = np.zeros((boxes_per_line, 4), dtype=np.uint16)
                gt_classes = np.zeros((boxes_per_line), dtype=np.int32)
                overlaps = np.zeros((boxes_per_line, self.num_classes),
                                    dtype=np.float32)
                cls = self._class_to_ind[str(fields[1])]
                for i in range(boxes_per_line):
                    x1 = float(fields[2 + i * 4])
                    y1 = float(fields[3 + i * 4])
                    x2 = float(fields[4 + i * 4])
                    y2 = float(fields[5 + i * 4])
                    boxes[i, :] = [x1, y1, x2, y2]
                    gt_classes[i] = cls
                    overlaps[i, cls] = 1.0
                gt_roidb.append({'boxes': boxes,
                                 'gt_classes': gt_classes,
                                 'gt_overlaps':
                                     scipy.sparse.csr_matrix(overlaps),
                                 'flipped': False})
        return gt_roidb

    def _write_voc_results_file(self, all_boxes):
        """Write one detection results file per foreground class.

        Args:
            all_boxes: ``all_boxes[class][image]`` is either an empty list or
                an array whose rows are ``x1, y1, x2, y2, ..., score`` (the
                score is read from the last column).

        Returns:
            The competition id used as the file-name prefix.
        """
        use_salt = self.config['use_salt']
        comp_id = 'comp4'
        if use_salt:
            comp_id += '-{}'.format(os.getpid())

        # VOCdevkit/results/VOC2007/Main/comp4-44503_det_test_aeroplane.txt
        path = os.path.join(self._devkit_path, 'results', 'VOC' + self._year,
                            'Main', comp_id + '_')
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print('Writing {} VOC results file'.format(cls))
            filename = path + 'det_' + self._image_set + '_' + cls + '.txt'
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    # ``dets == []`` is not a reliable emptiness test when
                    # dets is an ndarray; len() works for lists and arrays.
                    if len(dets) == 0:
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in range(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))
        return comp_id

    def _do_matlab_eval(self, comp_id, output_dir='output'):
        """Run the MATLAB ``voc_eval`` wrapper on previously written results.

        The command is executed through the shell from the
        ``VOCdevkit-matlab-wrapper`` directory next to this file; its exit
        status is not inspected.
        """
        rm_results = self.config['cleanup']

        path = os.path.join(os.path.dirname(__file__),
                            'VOCdevkit-matlab-wrapper')
        cmd = 'cd {} && '.format(path)
        cmd += '{:s} -nodisplay -nodesktop '.format(datasets.MATLAB)
        cmd += '-r "dbstop if error; '
        cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\',{:d}); quit;"' \
            .format(self._devkit_path, comp_id,
                    self._image_set, output_dir, int(rm_results))
        print('Running:\n{}'.format(cmd))
        subprocess.call(cmd, shell=True)

    def evaluate_detections(self, all_boxes, output_dir):
        """Write the detections to disk and evaluate them with MATLAB."""
        comp_id = self._write_voc_results_file(all_boxes)
        self._do_matlab_eval(comp_id, output_dir)

    def competition_mode(self, on):
        """Turn competition mode on or off.

        When on, result files are neither salted with the process id nor
        cleaned up after evaluation; when off, both options are enabled.
        """
        enabled = not on
        self.config['use_salt'] = enabled
        self.config['cleanup'] = enabled
if __name__ == '__main__':
    # Manual smoke test: build the database, force the roidb to be computed
    # (or loaded from the cache) and drop into an interactive shell where
    # ``d`` and ``res`` can be inspected.
    d = datasets.pascal_voc('trainval', '/home/k40/fast-rcnn/BDCI/second/')
    res = d.roidb
    from IPython import embed
    embed()
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""Factory method for easily getting imdbs by name."""
# Registry mapping a dataset name to a zero-argument callable that builds
# the corresponding imdb; it is filled in by get_imdb().
__sets = {}
import datasets.pascal_voc
import numpy as np
# Image-set name handed to datasets.pascal_voc by get_imdb(); it is also the
# key under which the dataset is registered.
imageset = 'pascal_voc'
# Dataset directory handed to datasets.pascal_voc as devkit_path.
# NOTE(review): machine-specific absolute path -- adjust for your setup.
devkit = '/home/k40/fast-rcnn/BDCI/second/trainval'
'''
def _selective_search_IJCV_top_k(split, top_k):
"""Return an imdb that uses the top k proposals from the selective search
IJCV code.
"""
imdb = datasets.pascal_voc(split)
imdb.roidb_handler = imdb.selective_search_IJCV_roidb
imdb.config['top_k'] = top_k
return imdb
'''
# Set up voc_<year>_<split> using selective search "fast" mode
'''
for year in ['2007', '2012']:
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}'.format(year, split)
__sets[name] = (lambda split=split, year=year:
datasets.pascal_voc(split, year))
'''
# Set up voc_<year>_<split>_top_<k> using selective search "quality" mode
# but only returning the first k boxes
'''
for top_k in np.arange(1000, 11000, 1000):
for split in ['train', 'val', 'trainval', 'test']:
name = 'voc_{}_{}_top_{:d}'.format(split, top_k)
__sets[name] = (lambda split=split,top_k=top_k:
_selective_search_IJCV_top_k(split,top_k))
'''
def get_imdb(name):
    """Get an imdb (image database) by name.

    The 'pascal_voc' entry is (re-)registered on every call, bound to the
    module-level ``imageset`` and ``devkit`` values current at that moment.

    Raises:
        KeyError: If ``name`` is not a registered dataset.
    """
    __sets['pascal_voc'] = (lambda imageset=imageset, devkit=devkit:
                            datasets.pascal_voc(imageset, devkit))
    # ``dict.has_key`` exists only on Python 2; ``in`` works everywhere.
    if name not in __sets:
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()
def list_imdbs():
    """List all registered imdbs.

    Returns a new list of names.  Entries are registered inside get_imdb(),
    so the list is empty until get_imdb() has been called at least once.
    """
    # list(...) keeps the return type a real list on Python 3 as well, where
    # dict.keys() would otherwise be a live view of the registry.
    return list(__sets.keys())
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import os
import os.path as osp
import PIL
from utils.cython_bbox import bbox_overlaps
import numpy as np
import scipy.sparse
import datasets
class imdb(object):
    """Image database.

    Base class holding the class list, the image index and a lazily built
    roidb.  Subclasses provide ``image_path_at``, a roidb handler and
    ``evaluate_detections``.
    """

    def __init__(self, name):
        """Create an empty database called ``name``."""
        self._name = name
        self._num_classes = 0
        self._classes = []
        self._image_index = []
        self._obj_proposer = 'selective_search'
        self._roidb = None
        self._roidb_handler = self.default_roidb
        # Use this dict for storing dataset specific config options
        self.config = {}

    @property
    def name(self):
        """Name given at construction time."""
        return self._name

    @property
    def num_classes(self):
        """Number of classes, derived from the class list."""
        return len(self._classes)

    @property
    def classes(self):
        """Sequence of class names."""
        return self._classes

    @property
    def image_index(self):
        """List of image identifiers."""
        return self._image_index

    @property
    def roidb_handler(self):
        """Callable that produces the roidb on first access."""
        return self._roidb_handler

    @roidb_handler.setter
    def roidb_handler(self, val):
        self._roidb_handler = val

    @property
    def roidb(self):
        """The roidb, built by ``roidb_handler`` on first access.

        A roidb is a list of dictionaries, each with the following keys:
          boxes
          gt_overlaps
          gt_classes
          flipped
        """
        if self._roidb is None:
            self._roidb = self.roidb_handler()
        return self._roidb

    @property
    def cache_path(self):
        """Absolute path of ``<ROOT_DIR>/data/cache``, created if missing."""
        cache_path = osp.abspath(osp.join(datasets.ROOT_DIR, 'data', 'cache'))
        if not os.path.exists(cache_path):
            os.makedirs(cache_path)
        return cache_path

    @property
    def num_images(self):
        """Number of entries in the image index."""
        return len(self.image_index)

    def image_path_at(self, i):
        """Return the path of image ``i``; must be provided by subclasses."""
        raise NotImplementedError

    def default_roidb(self):
        """Default roidb handler; must be replaced by subclasses."""
        raise NotImplementedError

    def evaluate_detections(self, all_boxes, output_dir=None):
        """
        all_boxes is a list of length number-of-classes.
        Each list element is a list of length number-of-images.
        Each of those list elements is either an empty list []
        or a numpy array of detection.

        all_boxes[class][image] = [] or np.array of shape #dets x 5
        """
        raise NotImplementedError

    def append_flipped_images(self):
        """Append a horizontally mirrored copy of every roidb entry.

        Only the x coordinates of the boxes are mirrored; overlaps and
        classes are shared with the source entry.  The image index is
        doubled so that it stays aligned with the roidb.
        """
        # ``import PIL`` alone does not guarantee the Image submodule is
        # loaded; importing it here makes ``PIL.Image`` safe to use.
        import PIL.Image

        num_images = self.num_images
        widths = [PIL.Image.open(self.image_path_at(i)).size[0]
                  for i in range(num_images)]
        for i in range(num_images):
            boxes = self.roidb[i]['boxes'].copy()
            oldx1 = boxes[:, 0].copy()
            oldx2 = boxes[:, 2].copy()
            boxes[:, 0] = widths[i] - oldx2 - 1
            boxes[:, 2] = widths[i] - oldx1 - 1
            # Fails when a box reaches past the image width (with unsigned
            # box dtypes the subtraction above wraps around in that case).
            assert (boxes[:, 2] >= boxes[:, 0]).all()
            entry = {'boxes': boxes,
                     'gt_overlaps': self.roidb[i]['gt_overlaps'],
                     'gt_classes': self.roidb[i]['gt_classes'],
                     'flipped': True}
            self.roidb.append(entry)
        self._image_index = self._image_index * 2

    def evaluate_recall(self, candidate_boxes, ar_thresh=0.5):
        """Measure how well ``candidate_boxes`` cover the ground truth.

        Args:
            candidate_boxes: One box array per image.
            ar_thresh: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(ar, gt_overlaps, recalls, thresholds)``: twice the area
            under the recall curve, the sorted best overlap found for each
            ground-truth box, and the recall at each IoU threshold from 0.5
            to 1.0.
        """
        # Record max overlap value for each gt box
        # Return vector of overlap values
        gt_overlaps = np.zeros(0)
        for i in range(self.num_images):
            gt_inds = np.where(self.roidb[i]['gt_classes'] > 0)[0]
            gt_boxes = self.roidb[i]['boxes'][gt_inds, :]

            boxes = candidate_boxes[i]
            if boxes.shape[0] == 0:
                continue
            # np.float was removed from NumPy; np.float64 is the same dtype.
            overlaps = bbox_overlaps(boxes.astype(np.float64),
                                     gt_boxes.astype(np.float64))

            _gt_overlaps = np.zeros((gt_boxes.shape[0]))
            # Greedy one-to-one matching: take the best remaining
            # (candidate, gt) pair, record it, then retire both by writing
            # -1 into the candidate's row and the gt's column.  The maxima
            # are recomputed each round because ``overlaps`` is mutated.
            for j in range(gt_boxes.shape[0]):
                argmax_overlaps = overlaps.argmax(axis=0)
                max_overlaps = overlaps.max(axis=0)
                gt_ind = max_overlaps.argmax()
                gt_ovr = max_overlaps.max()
                assert(gt_ovr >= 0)
                box_ind = argmax_overlaps[gt_ind]
                _gt_overlaps[j] = overlaps[box_ind, gt_ind]
                assert(_gt_overlaps[j] == gt_ovr)
                overlaps[box_ind, :] = -1
                overlaps[:, gt_ind] = -1

            gt_overlaps = np.hstack((gt_overlaps, _gt_overlaps))

        num_pos = gt_overlaps.size
        gt_overlaps = np.sort(gt_overlaps)
        step = 0.001
        # np.minimum clips the last threshold to 1.0.
        thresholds = np.minimum(np.arange(0.5, 1.0 + step, step), 1.0)
        recalls = np.zeros_like(thresholds)
        for i, t in enumerate(thresholds):
            recalls[i] = (gt_overlaps >= t).sum() / float(num_pos)
        ar = 2 * np.trapz(recalls, thresholds)

        return ar, gt_overlaps, recalls, thresholds

    def create_roidb_from_box_list(self, box_list, gt_roidb):
        """Turn per-image proposal boxes into roidb entries.

        Args:
            box_list: One box array per image, in image-index order.
            gt_roidb: Ground-truth roidb aligned with ``box_list``, or None.
                When given, each proposal is assigned its maximum overlap
                with any ground-truth box, stored in the column of that
                box's class.

        Returns:
            A list of roidb entries whose ``gt_classes`` are all zero.
        """
        assert len(box_list) == self.num_images, \
            'Number of boxes must match number of ground-truth images'
        roidb = []
        for i in range(self.num_images):
            boxes = box_list[i]
            num_boxes = boxes.shape[0]
            overlaps = np.zeros((num_boxes, self.num_classes),
                                dtype=np.float32)

            if gt_roidb is not None:
                gt_boxes = gt_roidb[i]['boxes']
                gt_classes = gt_roidb[i]['gt_classes']
                # np.float was removed from NumPy; np.float64 is equivalent.
                gt_overlaps = bbox_overlaps(boxes.astype(np.float64),
                                            gt_boxes.astype(np.float64))
                argmaxes = gt_overlaps.argmax(axis=1)
                maxes = gt_overlaps.max(axis=1)
                I = np.where(maxes > 0)[0]
                overlaps[I, gt_classes[argmaxes[I]]] = maxes[I]

            overlaps = scipy.sparse.csr_matrix(overlaps)
            roidb.append({'boxes': boxes,
                          'gt_classes': np.zeros((num_boxes,),
                                                 dtype=np.int32),
                          'gt_overlaps': overlaps,
                          'flipped': False})
        return roidb

    @staticmethod
    def merge_roidbs(a, b):
        """Append the entries of ``b`` onto ``a`` image by image.

        ``a`` is modified in place and returned; both lists must have the
        same length.
        """
        assert len(a) == len(b)
        for i in range(len(a)):
            a[i]['boxes'] = np.vstack((a[i]['boxes'], b[i]['boxes']))
            a[i]['gt_classes'] = np.hstack((a[i]['gt_classes'],
                                            b[i]['gt_classes']))
            a[i]['gt_overlaps'] = scipy.sparse.vstack([a[i]['gt_overlaps'],
                                                       b[i]['gt_overlaps']])
        return a

    def competition_mode(self, on):
        """Turn competition mode on or off."""
        pass
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
from .imdb import imdb
from .pascal_voc import pascal_voc
from . import factory
import os.path as osp
# Directory two levels above the one containing this file; imdb.cache_path
# resolves <ROOT_DIR>/data/cache relative to it.
ROOT_DIR = osp.join(osp.dirname(__file__), '..', '..')
# We assume your matlab binary is in your path and called `matlab'.
# If either is not true, just add it to your path and alias it as matlab, or
# you could change this file.
MATLAB = 'matlab'
# http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python
def _which(program):
    """Locate an executable, similar to the shell's ``which``.

    Args:
        program: Either a bare command name, which is searched for in every
            directory listed in the PATH environment variable, or a path
            containing a directory part, which is checked as given.

    Returns:
        The path of the first executable regular file found, or None when
        there is none (including when PATH is not set at all).
    """
    import os

    def is_exe(fpath):
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    fpath, _ = os.path.split(program)
    if fpath:
        # A directory component was supplied: do not consult PATH.
        return program if is_exe(program) else None

    # A missing PATH used to raise KeyError; treat it as "nothing to search".
    search_path = os.environ.get("PATH")
    if search_path is None:
        return None
    for path in search_path.split(os.pathsep):
        # Entries may be wrapped in double quotes; drop them.
        path = path.strip('"')
        exe_file = os.path.join(path, program)
        if is_exe(exe_file):
            return exe_file
    return None
# Fail at import time when the MATLAB binary cannot be found, because the
# detection evaluation shells out to it.
if _which(MATLAB) is None:
    msg = ("MATLAB command '{}' not found. "
           "Please add '{}' to your PATH.").format(MATLAB, MATLAB)
    raise EnvironmentError(msg)
然后运行tools/train.py,生成数据。
注意:data/cache下的pkl文件需要特别说明一下。如果再次训练时修改了数据库(比如添加或者删除了一些样本),但数据库的名字还是原来那个(比如我这里训练的数据库叫KakouTrain),就必须先把data/cache/目录下该数据库的缓存文件(.pkl)删除掉;否则程序不会重新读取相应的数据库,而是直接加载之前读入并缓存的pkl文件,这样修改后的数据并没有进入网络,训练用的仍然是老版本的数据。