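# 在 Python 文件的最上面加入下面几行，使 `import caffe` 能找到 Caffe 的 Python 接口：
# (Add the following lines at the very top of the Python file so that `import caffe` can locate Caffe's Python bindings.)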
import sys
sys.path.append("/home/zhangqi/Desktop/caffe-master/python")
sys.path.append("/home/zhangqi/Desktop/caffe-master/python/caffe")
#############################################################################################
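# 记得按本机环境修改 caffe_forward.py 中的路径（Caffe 的 python 目录以及模型、数据文件路径）。
# (Remember to change the paths in caffe_forward.py -- the Caffe python directory and the model/data file paths -- to match your machine.)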
import sys
sys.path.append("/home/zhangqi/Desktop/caffe-master/python")
sys.path.append("/home/zhangqi/Desktop/caffe-master/python/caffe")
import caffe
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
# Input files for the forward pass; adjust to the local machine.
#   model_defination -- network definition (prototxt) used for inference
#   weights          -- trained parameters loaded into that network
#   data_path        -- text file holding the board to evaluate
model_defination, weights, data_path = (
    '/usr/xhh/model/general_prediction/cnn/forward_network.prototxt',
    '/usr/xhh/model/general_prediction/cnn/train_iter_146000.caffemodel',
    '/usr/xhh/model/general_prediction/cnn/current_position.txt',
)
def _rank_positions(scores, excluded):
    """Return every position index, ordered from highest to lowest score.

    Args:
        scores: 1-D sequence of per-position scores.
        excluded: position indices (ints or numeric strings) whose score is
            forced to 0 before ranking; blank entries are skipped.

    Returns:
        A list containing each index of *scores* exactly once.
    """
    # Copy, so the network's output blob is never modified in place.
    values = np.array(scores, dtype=float)
    for item in excluded:
        if str(item).strip():
            values[int(item)] = 0
    # sorted() stays stable with reverse=True, so tied scores (e.g. several
    # zeroed positions) keep ascending index order and no index is repeated.
    # Looking each sorted value back up with list.index() would return the
    # first match for every tie.
    return sorted(range(len(values)), key=lambda k: values[k], reverse=True)


def load_net(list, top_k=50):
    """Run the CNN on the board in ``data_path`` and print the ranked positions.

    The network is rebuilt from ``model_defination`` / ``weights`` on every
    call and run on the CPU.

    Args:
        list: position indices (typically strings from the command line)
            whose scores are zeroed before ranking; blank entries are ignored.
            The name shadows the builtin but is kept for caller compatibility.
        top_k: number of ranked positions to print after the best one
            (default 50).

    Returns:
        The printed ranking: at most ``top_k`` position indices, best first.
    """
    caffe.set_mode_cpu()
    net = caffe.Net(model_defination, weights, caffe.TEST)
    net.blobs['data'].reshape(1, 3, 19, 19)
    board = get_img_datum(data_path)
    net.blobs['data'].data[...] = board.reshape(1, 3, 19, 19)
    outputs = net.forward()
    # The output blob is named 'loss' in this deploy net; row 0 is the single
    # batch item. Presumably one score per board cell (19*19) -- confirm
    # against forward_network.prototxt.
    ranked = _rank_positions(np.asarray(outputs['loss'])[0], list)
    # Print the best index first, then the top-k list (which repeats it),
    # one index per line.
    print(ranked[0])
    top = ranked[:top_k]
    for position in top:
        print(position)
    return top
def _token_to_plane(token, board_size):
    """Convert one digit string into a (board_size, board_size) plane, column by column."""
    cells = board_size * board_size
    if len(token) > cells:
        raise ValueError('token has %d digits, expected at most %d'
                         % (len(token), cells))
    flat = np.zeros(cells)
    flat[:len(token)] = [int(ch) for ch in token]
    # order='F' fills the row index fastest: the k-th digit lands at
    # row k % board_size, column k // board_size.
    return flat.reshape((board_size, board_size), order='F')


def get_img_datum(data_path, board_size=19):
    """Parse a board file into a (3, board_size, board_size) float array.

    The first non-blank line of *data_path* must contain at least three
    whitespace-separated tokens. Each of the first three tokens is a string of
    single-digit values that fills one channel. The k-th digit goes to row
    ``k % board_size`` and column ``k // board_size``. A token shorter than
    ``board_size**2`` leaves the remaining cells at 0.

    Args:
        data_path: path of the text file to read.
        board_size: side length of the square board (default 19).

    Returns:
        numpy float64 array of shape (3, board_size, board_size).

    Raises:
        ValueError: if the file has no non-blank line, the line has fewer than
            three tokens, a token is longer than ``board_size**2``, or a
            character is not a digit.
    """
    with open(data_path) as fh:
        tokens = next((ln.split() for ln in fh if ln.strip()), None)
    if tokens is None:
        raise ValueError('no board data found in %s' % data_path)
    if len(tokens) < 3:
        raise ValueError('%s: expected at least 3 tokens, got %d'
                         % (data_path, len(tokens)))
    img = np.zeros((3, board_size, board_size))
    for channel, token in enumerate(tokens[:3]):
        img[channel] = _token_to_plane(token, board_size)
    return img
def parse_excluded_positions(args):
    """Parse command-line words such as ``['[3,', '45]']`` into ``['3', '45']``.

    The words are concatenated without separators, which undoes the shell
    splitting ``[3, 45]`` on the space. Surrounding square brackets are then
    removed if present, and the result is split on commas. An empty argument
    list, or ``[]``, yields ``['']``.
    """
    joined = ''.join(args).strip()
    return joined.strip('[]').split(',')


if __name__ == '__main__':
    # Parse into a function-local value instead of module globals, so the
    # builtin `list` is not rebound for the whole module.
    load_net(parse_excluded_positions(sys.argv[1:]))