I often use py-faster-rcnn at work for image detection and recognition, so the training procedure is worth writing down. Below is a summary of my own practice, put together with reference to some material found online:
py-faster-rcnn on GitHub: https://github.com/rbgirshick/py-faster-rcnn
The data uses the VOC 2007 format.
1. Building the dataset
Programs/tools: a VOC2007 directory and labelImg
Workflow: rename the images to 6-digit numbers, annotate them with labelImg, generate the four txt files (train.txt, val.txt, test.txt, trainval.txt) from the xml annotations, and put the jpg, xml and txt files in the locations shown in the layout diagram (JPEGImages/, Annotations/ and ImageSets/Main/ respectively).
For data-generation utility code, see the post "faster-rcnn之生成训练数据"; a minimal split-script sketch is also given below.
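The four txt files simply list image ids (the 6-digit file names without extension, one per line) under ImageSets/Main/. Here is a minimal split-script sketch; it assumes the standard VOC2007 layout and uses split ratios of my own choosing, so adjust voc_root and the ratios to your setup:

#!/usr/bin/env python
# Sketch: build trainval/train/val/test splits from the annotated images.
import os
import random

voc_root = 'data/VOCdevkit2007/VOC2007'  # adjust to your own path
ids = [f[:-4] for f in os.listdir(os.path.join(voc_root, 'Annotations'))
       if f.endswith('.xml')]
random.shuffle(ids)

n_trainval = int(len(ids) * 0.8)        # assumed: 80% trainval, 20% test
trainval, test = ids[:n_trainval], ids[n_trainval:]
n_train = int(len(trainval) * 0.75)     # assumed: 75% train, 25% val within trainval
train, val = trainval[:n_train], trainval[n_train:]

out_dir = os.path.join(voc_root, 'ImageSets', 'Main')
for name, split in [('trainval', trainval), ('train', train),
                    ('val', val), ('test', test)]:
    with open(os.path.join(out_dir, name + '.txt'), 'w') as f:
        f.write('\n'.join(sorted(split)) + '\n')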
2. Modifying the network files
train.prototxt:
models/pascal_voc/VGG16/faster_rcnn_end2end/train.prototxt (the train.prototxt for VGG16)
Line 11: 'num_classes': 2, change the value to the number of damage classes + 1 (background counts as one class)
Line 530: 'num_classes': 2, change the value to the number of damage classes + 1 (background counts as one class)
Line 620: num_output: 2, change the value to the number of damage classes + 1 (background counts as one class)
Line 643: num_output: 8, this should be (number of damage classes + 1) * 4; the 4 is the four bounding-box regression coordinates per class
test.prototxt:
models/pascal_voc/VGG16/faster_rcnn_end2end/test.prototxt (the test.prototxt for VGG16)
Line 567: num_output: 2, change the value to the number of damage classes + 1 (background counts as one class)
Line 592: num_output: 8, this should be (number of damage classes + 1) * 4; the 4 is the four bounding-box regression coordinates per class
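The two num_output edits correspond to the cls_score and bbox_pred layers. As a quick sanity check (my own helper, not part of py-faster-rcnn), the prototxt can be parsed with Caffe's protobuf bindings to print the fields that must match your class count:

#!/usr/bin/env python
# Sketch: print the class-count-dependent fields of a py-faster-rcnn prototxt,
# i.e. the 'num_classes' entries of the Python data layers and the num_output
# of the cls_score / bbox_pred layers.
from caffe.proto import caffe_pb2
from google.protobuf import text_format

def check_prototxt(path):
    net = caffe_pb2.NetParameter()
    with open(path) as f:
        text_format.Merge(f.read(), net)
    for layer in net.layer:
        if layer.type == 'Python' and 'num_classes' in layer.python_param.param_str:
            print('{}: {}'.format(layer.name, layer.python_param.param_str))
        if layer.name in ('cls_score', 'bbox_pred'):
            print('{}: num_output = {}'.format(layer.name,
                                               layer.inner_product_param.num_output))

check_prototxt('models/pascal_voc/VGG16/faster_rcnn_end2end/train.prototxt')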
pascal_voc.py
lib/datasets/pascal_voc.py: modify line 31 (the self._classes tuple) to list your own classes
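For example, with a single 'hand' class (the class used in the prediction script below), the tuple around line 31 would become:

# lib/datasets/pascal_voc.py, around line 31: keep '__background__' at index 0
# and replace the 20 VOC class names with your own.
self._classes = ('__background__',  # always index 0
                 'hand')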
3. Running the training
Whenever the data changes, remember to clear the cache: rm -rf data/cache
From a terminal, go to the py-faster-rcnn directory and run:
./experiments/scripts/faster_rcnn_end2end.sh 0 VGG16 pascal_voc
The 0 means run on GPU 0 (change as needed); VGG16 is the network to train; pascal_voc is the dataset.
4. Inference
Straight to the code:
#!/usr/bin/env python
import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse

CLASSES = ('__background__',
           'hand')

def get_detections(im, class_name, dets, thresh=0.5):
    """Collect the detected bounding boxes whose score is above thresh."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return None
    bboxs = []
    for i in inds:
        bbox = dets[i, :4]
        bboxs.append([int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])])
    return bboxs

def frcn_predict(net, img_im):
    # Detect all object classes and regress object bounds
    scores, boxes = im_detect(net, img_im)
    # Apply per-class NMS and score thresholding
    CONF_THRESH = 0.65
    NMS_THRESH = 0.15
    res_dict = {}
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        boxs = get_detections(img_im, cls, dets, thresh=CONF_THRESH)
        if boxs is not None:
            res_dict[cls] = boxs
    return res_dict

def get_init_net():
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    prototxt = r'models/online/models/pascal_voc/VGG16/faster_rcnn_end2end/hand_test.prototxt'
    caffemodel = r'models/online/faster_rcnn_models/vgg16_faster_rcnn_hand_iter_500000.caffemodel'
    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\n').format(caffemodel))
    caffe.set_mode_gpu()
    caffe.set_device(0)
    cfg.GPU_ID = 0
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)
    print '\n\nLoaded network {:s}'.format(caffemodel)
    # Warmup on a dummy image
    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for i in xrange(2):
        _, _ = im_detect(net, im)
    return net

if __name__ == '__main__':
    net = get_init_net()
    img_path = r'data/VOCdevkit/VOC2007_lisa/JPEGImages/5500.jpg'
    im = cv2.imread(img_path)
    res = frcn_predict(net, im)
    print(res)
The returned res has the format {label: [box_1, box_2, ...]}, where each box is [x1, y1, x2, y2]. Example:
{'hand': [[482, 347, 570, 438], [52, 289, 147, 362], [104, 261, 273, 375]]}
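To visualize the result, the boxes in res can be drawn back onto the image with OpenCV; a minimal sketch (the output file name result.jpg is just an example):

# Draw each detected box and its class label on the image, then save it.
for label, boxes in res.items():
    for (x1, y1, x2, y2) in boxes:
        cv2.rectangle(im, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(im, label, (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
cv2.imwrite('result.jpg', im)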
References:
https://www.zhihu.com/question/57091642/answer/165134753
http://blog.csdn.net/otengyue/article/details/79243559