Faster RCNN复现
Faster RCNN源码解读1-整体流程和各个子流程梳理
Faster RCNN源码解读2-_anchor_component()为图像建立anchors(核心和关键1)
Faster RCNN源码解读3.1-_region_proposal() 筛选anchors-_proposal_layer()(核心和关键2)
Faster RCNN源码解读3.2-_region_proposal()筛选anchors-_anchor_target_layer()(核心和关键2)
Faster RCNN源码解读3.3-_region_proposal() 筛选anchors-_proposal_target_layer()(核心和关键2)
Faster RCNN源码解读4-其他收尾工作:ROI_pooling、分类、回归等
Faster RCNN源码解读5-损失函数
运行环境:
Python 3.7
TensorFlow 1.14
Windows(仅 CPU)
调试过程没有记录,总的原则就是兵来将挡、水来土掩:遇到问题,就带着问题去搜索,总有人和你遇到过相同的问题并且已经解决了。不踩坑永远没法长记性,哈哈。
demo.py
#!/usr/bin/env python
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------
"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import _init_paths
from model.config import cfg
from model.test import im_detect
from model.nms_wrapper import nms
from utils.timer import Timer
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse
from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
# Class labels. The position of a label in this tuple is the column index
# used to slice the per-class scores and boxes in demo(); index 0 is the
# background entry, which demo() skips.
CLASSES = (
    '__background__',
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
)

# --net choice -> 1-tuple holding the checkpoint file name for that network.
NETS = {
    'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',),
    'res101': ('res101_faster_rcnn_iter_110000.ckpt',),
}

# --dataset choice -> 1-tuple holding the directory name used in the
# checkpoint path.
DATASETS = {
    'pascal_voc': ('voc_2007_trainval',),
    'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',),
}
def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw the detections of one class whose score is at least ``thresh``.

    Args:
        im: Image array indexed as ``im[row, col, channel]``. The channel
            axis is reversed before display (assumes the BGR order that
            ``cv2.imread`` yields -- TODO confirm for other callers).
        class_name: Label used in the box captions and the figure title.
        dets: 2-D array with one detection per row. The first four columns
            are read as ``x1, y1, x2, y2`` and the last column as the score.
        thresh: Minimum score a detection needs in order to be drawn.

    Returns:
        None. A new matplotlib figure is created only when at least one
        detection passes the threshold; otherwise nothing is drawn.
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if inds.size == 0:
        # Nothing confident enough for this class: do not open a figure.
        return

    # Reverse the channel axis so matplotlib shows the expected colours.
    im = im[:, :, (2, 1, 0)]
    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')

    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        # Rectangle takes (x, y) of one corner plus width and height.
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1],
                          fill=False, edgecolor='red', linewidth=3.5)
        )
        # Caption sits just above the box's (x1, y1) corner.
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                 fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()
def demo(sess, net, image_name):
    """Run detection on one demo image and plot the confident boxes per class.

    The image is loaded from ``<cfg.DATA_DIR>/demo/<image_name>``, passed to
    ``im_detect``, and then, for every non-background entry of ``CLASSES``,
    the class's boxes are filtered with NMS and handed to
    ``vis_detections``.

    Args:
        sess: Session object forwarded unchanged to ``im_detect``.
        net: Network object forwarded unchanged to ``im_detect``.
        image_name: File name of the image inside the ``demo`` directory.

    Raises:
        IOError: If the image file cannot be read.
    """
    # Load the demo image.
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)
    if im is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # instead of raising; fail here with the path rather than deep
        # inside im_detect.
        raise IOError('Could not read image {:s}'.format(im_file))

    # Detect all object classes and regress object bounds, timing the call.
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time, boxes.shape[0]))

    # Visualize detections for each class.
    CONF_THRESH = 0.8  # minimum score for a box to be drawn
    NMS_THRESH = 0.3   # threshold passed to nms()
    # Start at 1 because index 0 of CLASSES is the background entry.
    for cls_ind, cls in enumerate(CLASSES[1:], start=1):
        # Each class owns four consecutive box columns and one score column.
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH)
#参数设置
def parse_args():
    """Parse the demo's command-line options.

    Returns:
        The argparse namespace with two attributes: ``demo_net`` (from
        ``--net``, one of the keys of ``NETS``, default ``'res101'``) and
        ``dataset`` (from ``--dataset``, one of the keys of ``DATASETS``,
        default ``'pascal_voc_0712'``).
    """
    cli = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    cli.add_argument(
        '--net',
        dest='demo_net',
        default='res101',
        choices=NETS.keys(),
        help='Network to use [vgg16 res101]',
    )
    cli.add_argument(
        '--dataset',
        dest='dataset',
        default='pascal_voc_0712',
        choices=DATASETS.keys(),
        help='Trained dataset [pascal_voc pascal_voc_0712]',
    )
    return cli.parse_args()
if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    cfg.USE_GPU_NMS = False  # Do not use the GPU NMS implementation
    args = parse_args()

    # Checkpoint path derived from the command-line choices.
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default',
                           NETS[demonet][0])
    print(tfmodel)

    # Machine-specific res101 checkpoint. It used to overwrite ``tfmodel``
    # unconditionally, which made --net/--dataset meaningless for the path
    # and would pair a res101 checkpoint with a vgg16 graph. It is now only
    # a fallback for res101 when the derived checkpoint is absent.
    local_res101_ckpt = "E:/sxl_Programs/TargetDetection/tf-faster-rcnn-windows-master/data/output/res101/voc_2007_trainval+voc_2012_trainval/res101_faster_rcnn_iter_110000.ckpt"
    if demonet == 'res101' and not os.path.isfile(tfmodel + '.meta'):
        tfmodel = local_res101_ckpt

    if not os.path.isfile(tfmodel + '.meta'):
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # Session configuration.
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # Init session.
    sess = tf.Session(config=tfconfig)

    # Build the network that matches the requested backbone.
    if demonet == 'vgg16':
        net = vgg16()
    elif demonet == 'res101':
        net = resnetv1(num_layers=101)
    else:
        raise NotImplementedError
    # The class count is taken from CLASSES (21 entries) instead of being
    # repeated as a literal.
    net.create_architecture("TEST", len(CLASSES),
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))

    # Images are looked up by demo() under <cfg.DATA_DIR>/demo.
    # Stock sample images: '000456.jpg', '000542.jpg', '001150.jpg',
    # '001763.jpg', '004545.jpg'.
    im_names = ['1.jpg', '2.jpg', '3.jpg',
                '4.jpg', '5.jpg']
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(sess, net, im_name)

    plt.show()
运行结果:
我自己搜了几张图测试一下(代码中自带的图就不贴了):
trainval_net.py
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Zheqi He, Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import _init_paths
from model.train_val import get_training_roidb, train_net
from model.config import cfg, cfg_from_file, cfg_from_list, get_output_dir, get_output_tb_dir
from datasets.factory import get_imdb
import datasets.imdb
import argparse
import pprint
import numpy as np
import sys
import tensorflow as tf
from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
class args:
    """Hard-coded training settings read by the script's main section.

    The attributes are accessed directly on the class (``args.cfg_file``,
    ``args.net``, ...); no instance is created and no command line is
    parsed.
    """

    # Config file handed to cfg_from_file().
    cfg_file = ('E:/sxl_Programs/TargetDetection/tf-faster-rcnn-windows-master'
                '/experiments/cfgs/vgg16.yml')
    # Checkpoint passed to train_net() as ``pretrained_model``.
    weight = ('E:/sxl_Programs/TargetDetection/tf-faster-rcnn-windows-master'
              '/data/output/vgg16/vgg16.ckpt')

    # Dataset names for the training and validation roidbs.
    imdb_name = 'voc_2007_trainval'
    imdbval_name = 'voc_2007_test'

    # Passed to train_net() as ``max_iters``.
    max_iters = 100000
    # Passed to get_output_dir() / get_output_tb_dir().
    tag = None
    # Backbone selector; the main section only accepts 'vgg16'.
    net = 'vgg16'

    # Flat key/value list handed to cfg_from_list().
    # NOTE(review): an earlier variant used 'TRAIN.STEPSIZE', '50000'; the
    # current '[20]' looks like a debugging value -- confirm before a real
    # training run.
    set_cfgs = [
        'ANCHOR_SCALES', '[8,16,32]',
        'ANCHOR_RATIOS', '[0.5,1,2]',
        'TRAIN.STEPSIZE', '[20]',
    ]
def combined_roidb(imdb_names):
    """Load the roidb for one dataset, or merge several joined with ``+``.

    Args:
        imdb_names: A dataset name, or several names separated by ``+``.

    Returns:
        A tuple ``(imdb, roidb)``. ``roidb`` is the first dataset's roidb,
        extended in place with the entries of every further dataset. For a
        single name ``imdb`` is ``get_imdb(imdb_names)``; for several names
        it is a ``datasets.imdb.imdb`` named after the whole string and
        using the class list of the second dataset.
    """
    def _load_roidb(name):
        # Fetch the dataset, select the proposal method, build its roidb.
        dataset = get_imdb(name)
        print('Loaded dataset `{:s}` for training'.format(dataset.name))
        dataset.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        return get_training_roidb(dataset)

    names = imdb_names.split('+')
    loaded = [_load_roidb(name) for name in names]
    roidb = loaded[0]

    if len(loaded) <= 1:
        # Single dataset: no merging required.
        return get_imdb(imdb_names), roidb

    # Several datasets: append the remaining roidbs onto the first one.
    for extra in loaded[1:]:
        roidb.extend(extra)
    reference = get_imdb(names[1])
    imdb = datasets.imdb.imdb(imdb_names, reference.classes)
    return imdb, roidb
if __name__ == '__main__':
    # Settings come from the hard-coded ``args`` class, not the command line.
    print('Called with args:')
    # Print the settings so the header above is not left dangling; dunder
    # entries of the class namespace are skipped.
    pprint.pprint({name: value for name, value in vars(args).items()
                   if not name.startswith('_')})

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    np.random.seed(cfg.RNG_SEED)

    # Train set.
    imdb, roidb = combined_roidb(args.imdb_name)
    print('{:d} roidb entries'.format(len(roidb)))

    # Output directory where the models are saved.
    output_dir = get_output_dir(imdb, args.tag)
    print('Output will be saved to `{:s}`'.format(output_dir))

    # Tensorboard directory where the summaries are saved during training.
    tb_dir = get_output_tb_dir(imdb, args.tag)
    print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

    # Also add the validation set, but with no flipped images: the flag is
    # switched off while loading and then restored to its previous value.
    orgflip = cfg.TRAIN.USE_FLIPPED
    cfg.TRAIN.USE_FLIPPED = False
    _, valroidb = combined_roidb(args.imdbval_name)
    print('{:d} validation roidb entries'.format(len(valroidb)))
    cfg.TRAIN.USE_FLIPPED = orgflip

    # Load network; only the VGG16 backbone is wired up here.
    if args.net == 'vgg16':
        net = vgg16()
    else:
        raise NotImplementedError(
            'Unsupported network: {}'.format(args.net))

    train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=args.weight,
              max_iters=args.max_iters)
运行结果: