GoogLeNet 作为曾经名噪一时的分类模型,影响深远。下面列出我实际使用 GoogLeNet 进行测试的代码,以供参考。
其中模型权重文件、均值文件和类别标签文件请自行下载。
# -*- coding=utf-8 -*-
import caffe
import numpy as np
import time
# Directory holding the bvlc_googlenet files (placeholder user path; edit locally).
root = '/home/xxx/caffe/models/bvlc_googlenet/'
# Network definition, trained weights, test image, class-label list and mean
# file, each resolved relative to ``root`` (which already ends with '/').
deploy, caffe_model, img_path, label_path, mean_file = (
    root + file_name
    for file_name in ('deploy.prototxt',
                      'bvlc_googlenet.caffemodel',
                      'dog.jpg',
                      'synset_words.txt',
                      'ilsvrc_2012_mean.npy')
)
# Average the stored mean array over axis 1, then over the (new) axis 1 again,
# leaving one value per entry of axis 0 -- presumably one per colour channel
# if the .npy is laid out channel-first (TODO confirm against the file).
image_mean = np.load(mean_file).mean(axis=1).mean(axis=1)
def set_device_mode(GPU=False, device_id=0):
    """Select the Caffe compute backend.

    Args:
        GPU: if true, run on a GPU; otherwise run on the CPU.
        device_id: index of the GPU to use when ``GPU`` is true (ignored in
            CPU mode). Defaults to 0, the device that used to be hard-coded.
    """
    if GPU:
        # Choose the device first, then switch Caffe into GPU mode
        # (same call order as before).
        caffe.set_device(device_id)
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()
def show_feature_map(net):
    """Print the name and data shape of every blob (feature map) in ``net``.

    One line per blob, formatted as ``<name>\\t<shape>``, in the order that
    ``net.blobs`` yields them.

    Args:
        net: a loaded ``caffe.Net``; only its ``blobs`` mapping is read.
    """
    # .items() works on Python 2 and 3; .iteritems() exists only on Python 2.
    for layer_name, feature_map in net.blobs.items():
        print(layer_name + '\t' + str(feature_map.data.shape))
def show_conv_kernel(net):
    """Print the shape of the first parameter blob of every layer in ``net``.

    Every layer that appears in ``net.params`` is listed, not only
    convolutions. Index 0 is typically the weight blob in Caffe; any further
    blobs (e.g. a bias) are not printed. One line per layer, formatted as
    ``<name>\\t<shape>``.

    Args:
        net: a loaded ``caffe.Net``; only its ``params`` mapping is read.
    """
    # .items() works on Python 2 and 3; .iteritems() exists only on Python 2.
    for layer_name, kernel in net.params.items():
        print(layer_name + '\t' + str(kernel[0].data.shape))
def load_model(deploy, caffe_model):
    """Construct a ``caffe.Net`` in the TEST phase.

    Args:
        deploy: path to the network definition (``.prototxt``).
        caffe_model: path to the trained weights (``.caffemodel``).

    Returns:
        The newly constructed ``caffe.Net``.
    """
    phase = caffe.TEST
    return caffe.Net(deploy, caffe_model, phase)
def image_read_change(image_path, image_mean, net):
    """Load an image, preprocess it, and write it into ``net``'s ``'data'`` blob.

    A ``caffe.io.Transformer`` sized to the ``'data'`` blob is configured,
    the image is read with ``caffe.io.load_image``, and the preprocessed
    result is copied into the blob in place.

    Args:
        image_path: path of the image file to classify.
        image_mean: mean value(s) given to ``Transformer.set_mean`` -- assumed
            to be in the same channel order as the network input (TODO confirm).
        net: a ``caffe.Net`` that has an input blob named ``'data'``.

    Returns:
        The same ``net`` object, now holding the preprocessed image.
    """
    input_blob = net.blobs['data']
    prep = caffe.io.Transformer({'data': input_blob.data.shape})
    prep.set_transpose('data', (2, 0, 1))     # reorder dimensions, HWC -> CHW
    prep.set_mean('data', image_mean)         # mean subtraction
    prep.set_raw_scale('data', 255)           # scale into [0, 255]
    prep.set_channel_swap('data', (2, 1, 0))  # RGB -> BGR
    picture = caffe.io.load_image(image_path)
    # Apply every step configured above, then copy the result into the
    # existing blob array (slice assignment writes in place).
    input_blob.data[...] = prep.preprocess('data', picture)
    return net
def forward_test(net, iterations=10, labels_file=None):
    """Time ``net.forward()`` and print the top predictions.

    Runs ``iterations`` forward passes on whatever is already loaded into the
    network's input. It then prints:

    - the average wall-clock time per pass;
    - the arg-max class index and its label;
    - the five most probable ``(probability, label)`` pairs.

    Only the first item of the ``'prob'`` output batch is inspected.

    Args:
        net: a ``caffe.Net`` whose input is already filled (e.g. by
            ``image_read_change``); ``forward()`` must return a mapping
            containing ``'prob'``.
        iterations: number of timed forward passes. Defaults to 10, the value
            that used to be hard-coded. Must be at least 1.
        labels_file: text file with one label per line, where line ``i`` is
            the label of class ``i`` (lines should contain no tab). Defaults
            to the module-level ``label_path``.

    Raises:
        ValueError: if ``iterations`` is less than 1. Without this check the
            loop would never assign ``output``, and the average would divide
            by zero.
    """
    if iterations < 1:
        raise ValueError('iterations must be >= 1, got %r' % (iterations,))
    if labels_file is None:
        labels_file = label_path

    t1 = time.time()
    for _ in range(iterations):
        output = net.forward()
    t2 = time.time()
    # NOTE: every pass is timed, including the first, which may carry
    # one-time setup cost on some backends.
    print("model test average forward time:%f s" % ((t2 - t1) / float(iterations)))

    output_prob = output['prob'][0]
    pred = int(output_prob.argmax())
    print('The predicted class is : %d' % pred)

    # Read the class names; ndmin=1 keeps a one-line file indexable.
    labels = np.loadtxt(labels_file, str, delimiter='\t', ndmin=1)
    print('The label is : %s' % labels[pred])

    # Indices of the five highest probabilities, largest first.
    top_inds = output_prob.argsort()[::-1][:5]
    # Convert to plain Python values so the printout is a readable list on
    # both Python 2 and 3 (zip() is a lazy iterator on Python 3).
    top5 = [(float(output_prob[i]), str(labels[i])) for i in top_inds]
    print('probabilities and labels: %s' % (top5,))
def main():
    """Run the demo: GPU mode, load GoogLeNet, feed ``img_path``, classify."""
    set_device_mode(GPU=True)  # run on the GPU
    model = load_model(deploy, caffe_model)
    model = image_read_change(img_path, image_mean, model)
    forward_test(model)
    # To also print every blob's shape, call: show_feature_map(model)


if __name__ == '__main__':
    main()