# caffe模型训练后使用python接口测试，做此记录。
# (Notes: testing a trained Caffe model through the Python interface.)
# From: 言有三
import caffe
import numpy as np
import cv2
def start_test(model_proto, model_weight, imgtxt, testsize, enable_crop,
               output_blob='classifier',
               mean=(104.008, 116.669, 122.675),
               device_id=0):
    """Evaluate a trained Caffe classifier on a labelled image list.

    Each non-empty line of ``imgtxt`` is expected to be
    ``<image path> <label>``. The label is compared, as a string, with the
    index of the highest score in the network output ``output_blob``.
    Per-image scores and the final top-1 accuracy are printed.

    Args:
        model_proto: path to the deploy ``.prototxt``.
        model_weight: path to the trained ``.caffemodel``.
        imgtxt: path to the image/label list file.
        testsize: side length used for the center crop or the resize.
        enable_crop: ``1`` to take a ``testsize`` x ``testsize`` center crop;
            any other value resizes the whole image (nearest neighbour).
        output_blob: name of the network output holding the class scores.
        mean: per-channel mean subtracted before inference, in the channel
            order of the loaded image (BGR for ``cv2.imread``).
        device_id: GPU id to run on.

    Returns:
        Top-1 accuracy as a float, or ``None`` if no image was evaluated.
    """
    # Initialise the network. set_device() only selects the GPU; pycaffe
    # stays in CPU mode unless set_mode_gpu() is called as well.
    caffe.set_device(device_id)
    caffe.set_mode_gpu()
    net = caffe.Net(model_proto, model_weight, caffe.TEST)

    # Mean-subtraction / HWC->CHW preprocessing is identical for every
    # image, so build the transformer once instead of per iteration.
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_mean('data', np.array(mean))
    transformer.set_transpose('data', (2, 0, 1))

    with open(imgtxt, 'r') as f:
        lines = f.readlines()

    valid_exts = {'png', 'jpg', 'jpeg', 'tif', 'bmp'}
    use_crop = enable_crop == 1
    if use_crop:
        print("use crop")

    count = 0
    acc = 0
    for line in lines:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. a trailing newline at end of file).
            continue
        # Split on the last whitespace so image paths containing spaces survive.
        imgpath, label = line.rsplit(None, 1)

        imgtype = imgpath.split('.')[-1]
        if imgtype.lower() not in valid_exts:
            print(imgtype, "error")
            continue

        img = cv2.imread(imgpath)
        if img is None:
            print("---------img is empty---------", imgpath)
            continue
        imgheight, imgwidth, channel = img.shape

        # Choose between center-cropping and resizing.
        if use_crop:
            # Clamp at 0: for an image smaller than testsize a negative
            # offset would be read by slicing as "from the end" and crop
            # the wrong region. The smaller crop is left to the transformer,
            # which (per caffe.io.Transformer.preprocess) resizes inputs to
            # the blob's spatial size — verify for your Caffe version.
            cropx = max((imgwidth - testsize) // 2, 0)
            cropy = max((imgheight - testsize) // 2, 0)
            img = img[cropy:cropy + testsize, cropx:cropx + testsize, 0:channel]
        else:
            img = cv2.resize(img, (testsize, testsize),
                             interpolation=cv2.INTER_NEAREST)

        out = net.forward_all(data=np.asarray([transformer.preprocess('data', img)]))
        result = out[output_blob][0]
        print("result=", result)
        predict = np.argmax(result)
        if str(label) == str(predict):
            acc += 1
        count += 1

    if count == 0:
        # Avoid ZeroDivisionError when every line was skipped.
        print("no valid images were evaluated")
        return None
    accuracy = float(acc) / float(count)
    print("acc=", accuracy)
    return accuracy
if __name__ == "__main__":
    # Evaluate the checkpoint below on the listed validation images using
    # 96x96 center crops (enable_crop=1).
    start_test('deploy.prototxt', 'models/mobilenet_finetune_iter_2000.caffemodel', 'all_shuffle_val.txt', 96, 1)