# coding=utf-8
import os
import urllib.request

import matplotlib.pyplot as pyplot
import numpy as np
import skimage
import skimage.io as io
import skimage.transform
import urllib3

from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace, models
print('required modules imported.')
# Path to the caffe2 model zoo bundled with the conda install.
# Raw strings: these Windows paths contain backslashes that would otherwise
# be (invalid) escape sequences in a normal string literal.
CAFFE_MODELS = r'D:\Anaconda\envs\pytorch\Lib\site-packages\caffe2\python\models'
print('caffe_model path : {}'.format(CAFFE_MODELS))
# Test image to classify.
IMG_LOCATION = r'F:\cocoDataAugment\data\night\night_1.jpg'
# (model dir, init net file, predict net file, mean file, input size in px)
MODEL = 'squeezenet', 'init_net.pb', 'predict_net.pb', 'ilsvrc_2012_mean.py', 227
# Gist with the ImageNet class-index -> label mapping.
codes = "https://gist.githubusercontent.com/aaronmarkham/cd3a6b6ac071eca6f7b4a6e40e6038aa/raw/9edb4038a37da6b5a44c3b5bc52e448ff09bfe5b/alexnet_codes"
print('config set!')
def crop_center(img, cropx, cropy):
    """Return the centered cropy x cropx window of an (H, W, C) image."""
    height, width, _channels = img.shape
    left = width // 2 - cropx // 2
    top = height // 2 - cropy // 2
    return img[top:top + cropy, left:left + cropx]
def rescale(img, input_height, input_width):
    """Resize an (H, W, C) image so its shorter side matches the model input,
    preserving the aspect ratio; square images are resized directly.

    Returns the rescaled image (skimage.transform.resize output).
    """
    print("Original image shape:" + str(img.shape) + " and remember it should be in H, W, C!")
    print("Model's input shape is:{}x{}".format(input_height, input_width))
    aspect = img.shape[1] / float(img.shape[0])
    print("Original aspect ratio: " + str(aspect))
    # if/elif/else (instead of three independent ifs) guarantees imgScaled is
    # always bound before the final print/return.
    if aspect > 1:
        # landscape orientation - wide image: fix the height, scale the width
        res = int(aspect * input_height)
        imgScaled = skimage.transform.resize(img, (input_height, res))
    elif aspect < 1:
        # portrait orientation - tall image: fix the width, scale the height
        res = int(input_width / aspect)
        imgScaled = skimage.transform.resize(img, (res, input_width))
    else:
        # already square: resize straight to the model input size
        imgScaled = skimage.transform.resize(img, (input_height, input_width))
    print("New image shape:" + str(imgScaled.shape) + " in HWC")
    return imgScaled
# Resolve the model directory and build the paths to the model files.
CAFFE_MODELS = os.path.expanduser(CAFFE_MODELS)
MEAN_FILE = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[3])
# The mean .py file is not shipped with the model, so instead of loading it
# (np.load(MEAN_FILE).mean(1).mean(1)) a constant per-channel mean is used.
mean = 128
print('mean is set to : {}'.format(mean))
INPUT_IMAGE_SIZE = MODEL[4]
INIT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[1])  # .../models/squeezenet/init_net.pb
print('INIT_NET:{}'.format(INIT_NET))
PREDICT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[2])  # .../models/squeezenet/predict_net.pb
print('predict_net : {}'.format(PREDICT_NET))
# Sanity-check that both protobuf files exist before trying to load the model.
if not os.path.exists(INIT_NET):
    print(INIT_NET, ' not found@')
else:
    print('Found ', INIT_NET, "...Now looking for ", PREDICT_NET)
    if not os.path.exists(PREDICT_NET):
        print(PREDICT_NET, ' not found!')
    else:
        print('all needed files found! loading model in the next block')
# Load the test image as float32 in [0, 1], then shape it for the network.
img = skimage.img_as_float(skimage.io.imread(IMG_LOCATION)).astype(np.float32)
img = rescale(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
img = crop_center(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
# HWC -> CHW, then RGB -> BGR channel order
img = img.transpose(2, 0, 1)
img = img[::-1, :, :]
# scale back to [0, 255] and subtract the mean
img = img * 255 - mean
# add the leading batch dimension: NCHW
img = img[np.newaxis, :, :, :].astype(np.float32)
print('img shape nchw:{}'.format(img.shape))
# Load the model, run the prediction, and collect the raw result.
# The .pb files are binary protobufs, so they must be opened in binary mode
# ('rb'); opening them as utf-8 text is exactly what produced the
# UnicodeDecodeError mentioned in the notes below.
with open(INIT_NET, 'rb') as f:
    init_net = f.read()
with open(PREDICT_NET, 'rb') as f:
    predict_net = f.read()
# Build the predictor from the serialized nets and run the one-image batch.
p = workspace.Predictor(init_net, predict_net)
result = p.run([img])
result = np.asarray(result)
print('result shape :{}'.format(result.shape))
# NOTE(review): the following lines were bare prose in the source (a
# SyntaxError in Python 3); converted to comments and translated to English.
# Reference: https://zhuanlan.zhihu.com/p/34701037
# Because Python 3 was used, the Python 2 model could not be loaded, failing
# with:
#   UnicodeDecodeError: 'utf-8' codec can't decode byte 0xf8 in position 1:
#   invalid start byte
# so the Python 2 build of caffe2 was re-downloaded and reinstalled.
# This chapter is a small collection of the earlier topics; the key points
# are loading the model (read()), prediction (workspace.Predictor()) and the
# analysis of the prediction results (unfinished).
# Follow-up: installing caffe2 via conda on win10 (with the other environment
# conflicts) did not succeed; result processing will be studied later on the
# lab machine.
# the rest of this is digging through the results
# Flatten the raw output and drop the duplicated batch entry (fixed: the
# original referenced the undefined name `results` here).
results = np.delete(result, 1)
index = 0    # ImageNet class index of the best match (1-based)
highest = 0  # highest probability seen so far
# rows of (class_index, probability)
arr = np.empty((0, 2), dtype=object)
for i, r in enumerate(results):
    # imagenet index begins with 1!
    i = i + 1
    arr = np.append(arr, np.array([[i, r]]), axis=0)
    if r > highest:
        highest = r
        index = i
# top 3 results (Python 3 print(); the original used Python 2 statements)
print("Raw top 3 results:", sorted(arr, key=lambda x: x[1], reverse=True)[:3])
# now we can grab the code list (urllib.request replaces Python 2's urllib2)
response = urllib.request.urlopen(codes)
# look up the predicted class index in the human-readable label list
for line in response:
    # urlopen yields bytes; decode before string processing
    line = line.decode('utf-8')
    code, label = line.partition(":")[::2]
    if code.strip() == str(index):
        print(MODEL[0], "infers that the image contains ", label.strip()[1:-2],
              "with a ", highest * 100, "% probability")