在Tensorflow中,调用CNN训练的模型来预测图片类别,主要分为以下几步:
1、加载训练后的模型
saver = tf.train.import_meta_graph('./model/my-model-95.meta')
saver.restore(sess,tf.train.latest_checkpoint('./model/'))
上述接口的参数都是训练好的模型存储后的文件,如下图所示:
.meta文件保存了当前图结构
.index文件保存了当前参数名
.data文件保存了当前参数值
2、载入默认的图,提取图中与图像预测相关的张量,主要包括了输入数据和预测值两项,张量的提取依靠的是张量的name属性,需要在训练模型的时候提前指定。
graph = tf.get_default_graph() # 载入图
x = graph.get_tensor_by_name("x:0") # 得到图中的x张量并初始化为0,存储的是输入张量
logits = graph.get_tensor_by_name("logits_eval:0") # 得到图中logits_eval张量,存储预测张量
# 以下为训练模型时为上述张量起的名字
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
logits_eval = tf.multiply(logits,b,name='logits_eval')
3、在新的session中run起来,当然为了得到预测值,需要feed数据也就是需要预测的图像。
feed_dict = {x:imgarr}
classification_result = sess.run(logits,feed_dict)
4、得到上述分类结果以后就能够美滋滋地判断类别了,具体而言
output = tf.argmax(classification_result,1).eval() # 这个语句能得到结果张量中值最大的元素下标也就是类别
需要特别注意的是,在使用模型做图像分类的时候,要保证输入的图像属性和训练模型时候一致,包括尺寸、色彩空间等。如果在训练模型的时候对图像做了预处理,在图像分类时可以用原图的拷贝做图像分类,而将原图移动到指定的类别文件夹。
以下是调用CNN模型做图像分类的一个示例代码:
from skimage import io,transform
import tensorflow as tf
import numpy as np
import os
from PIL import Image
# 类别标签
img_dict = {0:'猫',1:'狗',2:'鸟'}
# 图像预处理,和实际训练保持一致,如果训练时没有这个步骤可以略过
def img_preprocess(imgpath):
print("Reading "+imgpath)
img=Image.open(imgpath)
img = img.convert('RGB')
w,h=img.size
box=[]
if w>h:
offset=(w-h)/2
box=[offset,0,h+offset,h]
else:
offset=(h-w)/2
box=[0,offset,w,w+offset]
img=img.crop(box)
img=img.resize((100,100),Image.ANTIALIAS)
imgarr = np.asarray(img)
img.close()
return np.asarray(imgarr)
with tf.Session() as sess:
# 加载网络模型
saver = tf.train.import_meta_graph('./model/my-model-95.meta')
saver.restore(sess,tf.train.latest_checkpoint('./model/'))
#列出文件夹下所有的目录与文件
rootdir = './test'
list = os.listdir(rootdir)
for i in range(0,len(list)):
path = os.path.join(rootdir,list[i])
imgtest = [] # 存放待判断的图片数组
imgdata = [] # 存放图片数据
imgname = [] # 存放图片名
# 存图片的数据
img = Image.open(path)
imgdata.append(img)
# 存图片名
imgname.append(list[i])
# 存待判断图片数组
img = img_preprocess(path)
imgtest.append(img)
# 载入默认的图
# 其实我们需要求的就是logits,也就是需要通过run模型来得到logits
# logits的计算依赖x的输入,因此需要初始化网咯中的x
graph = tf.get_default_graph()
x = graph.get_tensor_by_name("x:0")
logits = graph.get_tensor_by_name("logits_eval:0")
feed_dict = {x:imgtest}
classification_result = sess.run(logits,feed_dict)
#打印出预测矩阵
print(classification_result)
#打印出预测矩阵每一行最大值的索引
print(tf.argmax(classification_result,1).eval())
#根据索引通过字典对应小动物的分类
output = []
output = tf.argmax(classification_result,1).eval()
# 将图片存入指定文件夹
for i in range(len(output)):
print("该图预测:"+img_dict[output[i]]+"\n")
savepath = "./result/" + str(output[i]) + "/" + str(imgname[i])
imgdata[i].save(savepath)
print("Done!")