1、将 BN（Batch Normalization）层合并到卷积层中：打开 merge_bn.py 文件，并根据实际情况修改其中的文件路径（训练用 prototxt、训练得到的 caffemodel 快照、部署用 prototxt 以及输出模型路径）：
import numpy as np
import sys,os
import caffe
# Training-side inputs: the net definition that still contains the
# BatchNorm/Scale layers, and the snapshot whose weights get folded.
train_proto, train_model = (
    'MobileNetSSD_train.prototxt',
    'mobilenet_iter_3000.caffemodel',  # should be your snapshot caffemodel
)
# Deploy-side: the BN-free net definition and the file the folded
# weights are saved to.
deploy_proto, save_model = (
    'MobileNetSSD_deploy.prototxt',
    'MobileNetSSD_deploy.caffemodel',
)
def merge_bn(net, nob, eps=1e-5):
    """Fold BatchNorm + Scale parameters into the preceding layer's weights.

    For each parameterised layer ``key`` of ``net`` (other than the
    ``*/bn`` and ``*/scale`` layers themselves):

    * if ``net`` has no ``key + "/bn"`` layer, every blob of ``key`` is
      copied unchanged into ``nob.params[key]``;
    * otherwise the ``key + "/bn"`` (BatchNorm) and ``key + "/scale"``
      (Scale) parameters are folded into the weight and bias written to
      ``nob.params[key]``::

          sf   = 1 / bn[2]            (sf stays 0 when bn[2] == 0)
          mean = bn[0] * sf
          var  = bn[1] * sf
          rstd = 1 / sqrt(var + eps)
          w    = w * rstd * scale[0]
          b    = (b - mean) * rstd * scale[0] + scale[1]

      ``b`` starts as zeros when the layer has no bias blob in ``net``.

    Args:
        net: source caffe.Net holding the trained weights, including the
            ``*/bn`` and ``*/scale`` layers.
        nob: destination caffe.Net (the deploy net); its parameters are
            overwritten in place. Every folded layer must have a bias blob
            in ``nob``, because blob index 1 is written.
        eps: value added to the variance before the square root. Defaults
            to 1e-5, the constant previously hard-coded here.
    """
    # Iterate the (ordered) params mapping directly: iterkeys()/has_key()
    # exist only in Python 2.
    for key in net.params:
        conv = net.params[key]
        if not isinstance(conv, caffe._caffe.BlobVec):
            continue
        if key.endswith("/bn") or key.endswith("/scale"):
            # Consumed when folding into their parent layer below.
            continue
        if key + "/bn" not in net.params:
            # No BatchNorm after this layer: copy every blob verbatim.
            for i, blob in enumerate(conv):
                nob.params[key][i].data[...] = blob.data
            continue

        bn = net.params[key + "/bn"]
        scale = net.params[key + "/scale"]
        wt = conv[0].data
        # Output channels are the first axis of the weight blob.
        channels = wt.shape[0]
        bias = conv[1].data if len(conv) > 1 else np.zeros(channels)
        mean = bn[0].data
        var = bn[1].data
        # bn[2] is the stored scale factor: the saved mean/var must be
        # multiplied by its reciprocal (left at 0 when the factor is 0).
        scalef = bn[2].data
        if scalef != 0:
            scalef = 1. / scalef
        mean = mean * scalef
        var = var * scalef
        rstd = 1. / np.sqrt(var + eps)
        gamma = scale[0].data
        beta = scale[1].data
        # Broadcast the per-output-channel factors over the remaining weight
        # axes; (channels, 1, 1, 1) for 4-D conv weights, (channels, 1) for
        # 2-D weights.
        bcast = (channels,) + (1,) * (wt.ndim - 1)
        wt = wt * rstd.reshape(bcast) * gamma.reshape(bcast)
        bias = (bias - mean) * rstd * gamma + beta
        nob.params[key][0].data[...] = wt
        nob.params[key][1].data[...] = bias
# Only run the conversion when executed as a script, so importing this module
# (e.g. to reuse merge_bn) does not load both nets and overwrite save_model.
if __name__ == '__main__':
    # Trained net (with BN/Scale layers) and the BN-free deploy net.
    net = caffe.Net(train_proto, train_model, caffe.TRAIN)
    net_deploy = caffe.Net(deploy_proto, caffe.TEST)
    merge_bn(net, net_deploy)
    net_deploy.save(save_model)
2、使用 demo.py 加载合并后的模型，对 images 目录中的图片进行检测：
import numpy as np
import sys,os
import cv2
# Put the pycaffe package of the caffe-ssd checkout at the front of the
# import path so the `import caffe` that follows picks it up.
caffe_root = '/home/xxx/caffe-ssd/'
sys.path[0:0] = [caffe_root + 'python']
import caffe
# Deploy network definition and the BN-folded weights written by merge_bn.py.
#net_file= 'deploy.prototxt'
net_file = 'MobileNetSSD_deploy.prototxt'
caffe_model = 'MobileNetSSD_deploy.caffemodel'
#caffe_model = 'mobilenet_iter_73000.caffemodel'
# Directory whose entries are passed to detect() one by one.
test_dir = "images"
if not os.path.exists(caffe_model):
    print("MobileNetSSD_deploy.caffemodel does not exist,")
    print("use merge_bn.py to generate it.")
    # sys.exit instead of the interactive-shell exit() helper, and with a
    # non-zero status so the shell sees the failure.
    sys.exit(1)
net = caffe.Net(net_file, caffe_model, caffe.TEST)
# Label names indexed by the integer class id found in column 1 of
# detection_out; id 0 is the background class.
# NOTE(review): 'safely' may be meant as 'safety'; it must match the training
# labelmap, so it is left as-is.
CLASSES = (
    'background',
    'person',
    'safely',
    'sign',
    'key',
    'universal_key',
    'dock',
    'State',
    'circle',
    'pole',
    'glove',
)
def preprocess(src):
    """Turn a frame into the normalised 300x300 network input.

    The frame is resized to 300x300, then every value is shifted by -127.5
    and scaled by 0.007843 (about 1/127.5), giving roughly [-1, 1] for
    8-bit input. A new array is returned; ``src`` is not modified.
    """
    resized = cv2.resize(src, (300, 300))
    centred = resized - 127.5
    return centred * 0.007843
def postprocess(img, out):
    """Convert raw detection rows into pixel boxes, scores and labels.

    Args:
        img: the original (un-resized) image; only its height and width are
            used, to scale the normalised box coordinates.
        out: result of ``net.forward()``; ``out['detection_out'][0, 0]`` is
            read as rows whose column 1 is the label, column 2 the score and
            columns 3..6 the normalised box corners (x1, y1, x2, y2).

    Returns:
        ``(boxes, scores, labels)`` where ``boxes`` is an int32 array of
        pixel corners ``[x1, y1, x2, y2]`` per row.
    """
    rows = out['detection_out'][0, 0]
    height, width = img.shape[:2]
    scale = np.array([width, height, width, height])
    boxes = rows[:, 3:7] * scale
    return boxes.astype(np.int32), rows[:, 2], rows[:, 1]
def detect(imgfile):
    """Detect objects in one image file, draw them and show the result.

    Args:
        imgfile: path of the image to read with ``cv2.imread``.

    Returns:
        False if the user pressed ESC in the "SSD" window (the caller stops
        iterating), True otherwise. An unreadable / non-image file is skipped
        with a message and also returns True, so one stray file does not
        abort the whole directory run.
    """
    origimg = cv2.imread(imgfile)
    if origimg is None:
        # cv2.imread returns None instead of raising for unreadable files.
        print("skip %s: not a readable image" % imgfile)
        return True
    img = preprocess(origimg).astype(np.float32)
    # HWC -> CHW, the layout of the Caffe 'data' blob.
    net.blobs['data'].data[...] = img.transpose((2, 0, 1))
    out = net.forward()
    box, conf, cls = postprocess(origimg, out)
    for (x1, y1, x2, y2), score, label in zip(box, conf, cls):
        # Caffe-SSD's DetectionOutput emits a placeholder row labelled -1
        # when nothing is detected; CLASSES[-1] would mislabel it.
        if label < 0:
            continue
        # Plain ints: some OpenCV versions reject numpy scalars as points.
        p1 = (int(x1), int(y1))
        p2 = (int(x2), int(y2))
        cv2.rectangle(origimg, p1, p2, (0, 255, 0))
        # Keep the caption inside the image near the top/left edges.
        p3 = (max(p1[0], 15), max(p1[1], 15))
        title = "%s:%.2f" % (CLASSES[int(label)], score)
        cv2.putText(origimg, title, p3, cv2.FONT_ITALIC, 0.6, (0, 255, 0), 1)
        print("%s %s" % (title, p3))
    cv2.imshow("SSD", origimg)
    k = cv2.waitKey(0) & 0xff
    # Exit if ESC pressed
    return k != 27
# Run detect() on every entry of test_dir in a stable (sorted) order;
# stop as soon as detect() reports that ESC was pressed.
for f in sorted(os.listdir(test_dir)):
    if not detect(os.path.join(test_dir, f)):
        break
3、对视频文件或摄像头（camera）画面进行检测：
import numpy as np
import sys,os
import cv2
# Put the pycaffe package of the caffe-ssd checkout at the front of the
# import path so the `import caffe` that follows picks it up.
caffe_root = '/home/xxx/caffe-ssd/'
sys.path[0:0] = [caffe_root + 'python']
import caffe
# Run inference on GPU 0: select the device, then switch Caffe to GPU mode.
# NOTE(review): presumably needs a CUDA-enabled Caffe build; keep this order.
caffe.set_device(0)
caffe.set_mode_gpu()
# Deploy network definition and the BN-folded weights written by merge_bn.py.
#net_file= 'deploy.prototxt'
net_file = 'MobileNetSSD_deploy.prototxt'
caffe_model = 'MobileNetSSD_deploy.caffemodel'
#caffe_model = 'mobilenet_iter_73000.caffemodel'
# Not referenced by the video/camera functions in this script.
test_dir = "images"
if not os.path.exists(caffe_model):
    print("MobileNetSSD_deploy.caffemodel does not exist,")
    print("use merge_bn.py to generate it.")
    # sys.exit instead of the interactive-shell exit() helper, and with a
    # non-zero status so the shell sees the failure.
    sys.exit(1)
net = caffe.Net(net_file, caffe_model, caffe.TEST)
# Label names indexed by the integer class id found in column 1 of
# detection_out; id 0 is the background class, followed by the 20
# PASCAL VOC categories.
CLASSES = (
    'background',
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
)
def preprocess(src):
    """Turn a frame into the normalised 300x300 network input.

    The frame is resized to 300x300, then every value is shifted by -127.5
    and scaled by 0.007843 (about 1/127.5), giving roughly [-1, 1] for
    8-bit input. A new array is returned; ``src`` is not modified.
    """
    resized = cv2.resize(src, (300, 300))
    centred = resized - 127.5
    return centred * 0.007843
def postprocess(img, out):
    """Convert raw detection rows into pixel boxes, scores and labels.

    Args:
        img: the original (un-resized) frame; only its height and width are
            used, to scale the normalised box coordinates.
        out: result of ``net.forward()``; ``out['detection_out'][0, 0]`` is
            read as rows whose column 1 is the label, column 2 the score and
            columns 3..6 the normalised box corners (x1, y1, x2, y2).

    Returns:
        ``(boxes, scores, labels)`` where ``boxes`` is an int32 array of
        pixel corners ``[x1, y1, x2, y2]`` per row.
    """
    rows = out['detection_out'][0, 0]
    height, width = img.shape[:2]
    scale = np.array([width, height, width, height])
    boxes = rows[:, 3:7] * scale
    return boxes.astype(np.int32), rows[:, 2], rows[:, 1]
def detect(imgfile):
    """Detect objects in one video frame, draw them and show the result.

    Args:
        imgfile: the frame itself (despite the name, not a file path), as
            returned by ``VideoCapture.read()``; may be None when the capture
            has no frame to give.

    Returns:
        False when there is nothing to process (``imgfile`` is None) or the
        user pressed ESC in the "SSD" window; True otherwise. Callers may use
        a False result to stop their frame loop.
    """
    origimg = imgfile
    if origimg is None:
        # Nothing to run on; preprocess() would fail on a None frame.
        return False
    img = preprocess(origimg).astype(np.float32)
    # HWC -> CHW, the layout of the Caffe 'data' blob.
    net.blobs['data'].data[...] = img.transpose((2, 0, 1))
    out = net.forward()
    box, conf, cls = postprocess(origimg, out)
    for (x1, y1, x2, y2), score, label in zip(box, conf, cls):
        # Caffe-SSD's DetectionOutput emits a placeholder row labelled -1
        # when nothing is detected; CLASSES[-1] would mislabel it.
        if label < 0:
            continue
        # Plain ints: some OpenCV versions reject numpy scalars as points.
        p1 = (int(x1), int(y1))
        p2 = (int(x2), int(y2))
        cv2.rectangle(origimg, p1, p2, (0, 255, 0))
        # Keep the caption inside the image near the top/left edges.
        p3 = (max(p1[0], 15), max(p1[1], 15))
        title = "%s:%.2f" % (CLASSES[int(label)], score)
        cv2.putText(origimg, title, p3, cv2.FONT_ITALIC, 0.6, (0, 255, 0), 1)
        print("%s %s" % (title, p3))
    cv2.imshow("SSD", origimg)
    # Wait only 1 ms so playback keeps going; this also lets the window
    # refresh. ESC asks the caller to stop.
    k = cv2.waitKey(1) & 0xff
    return k != 27
def fromcamera(device=200):
    """Run detect() on live frames from a camera.

    Args:
        device: index passed to ``cv2.VideoCapture``; defaults to 200, the
            value this function always used. NOTE(review): OpenCV treats
            ``camera_id + domain_offset`` here, so 200 presumably selects
            camera 0 through the V4L backend; pass 0 for the default backend.

    The loop ends when the capture closes, a frame cannot be read, or
    detect() returns False. The capture is always released.
    """
    capture = cv2.VideoCapture(device)
    print(capture.isOpened())
    frameNum = 1
    try:
        while capture.isOpened():
            ok, frame = capture.read()
            if not ok:
                # read() failed (device lost / stream ended): frame is None.
                break
            print(frameNum)
            frameNum += 1
            # Only an explicit False stops the loop; None means "continue".
            if detect(frame) is False:
                break
    finally:
        capture.release()
def fromvideo(path="/home/xxx/caffe-ssd/examples/videos/ILSVRC2015_train_00755001.mp4"):
    """Run detect() on every frame of a video file.

    Args:
        path: video file to open; defaults to the sample clip this function
            always used.

    Stops at the end of the video (``read()`` reports failure) or when
    detect() returns False. The capture is always released.
    """
    cap_video = cv2.VideoCapture(path)
    try:
        while True:
            success, frame_mp4 = cap_video.read()
            if not success:
                # End of the video: the frame is None and must not be passed
                # to detect() (preprocess() would fail on it).
                break
            # Only an explicit False stops the loop; None means "continue".
            if detect(frame_mp4) is False:
                break
    finally:
        cap_video.release()
if __name__ == '__main__':
    # Process the video file by default; run with the argument "camera"
    # to use the live camera instead.
    if len(sys.argv) > 1 and sys.argv[1] == 'camera':
        fromcamera()
    else:
        fromvideo()