如何加载训练好的CNN模型来做图像数据集分类?

在Tensorflow中,调用CNN训练的模型来预测图片类别,主要分为以下几步:

1、加载训练后的模型

saver = tf.train.import_meta_graph('./model/my-model-95.meta')
saver.restore(sess,tf.train.latest_checkpoint('./model/'))

上述接口的参数都是训练好的模型存储后的文件,如下图所示:
模型文件

.meta文件保存了当前图结构
.index文件保存了当前参数名
.data文件保存了当前参数值

2、载入默认的图,提取图中与图像预测相关的张量,主要包括了输入数据和预测值两项,张量的提取依靠的是张量的name属性,需要在训练模型的时候提前指定。

graph = tf.get_default_graph()    # 载入图
x = graph.get_tensor_by_name("x:0")    # 得到图中的x张量并初始化为0,存储的是输入张量
logits = graph.get_tensor_by_name("logits_eval:0")    # 得到图中logits_eval张量,存储预测张量

# 以下为训练模型时为上述张量起的名字
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
logits_eval = tf.multiply(logits,b,name='logits_eval')

3、在新的session中run起来,当然为了得到预测值,需要feed数据也就是需要预测的图像。

feed_dict = {x:imgarr}
classification_result = sess.run(logits,feed_dict)

4、得到上述分类结果以后就能够美滋滋地判断类别了,具体而言

output = tf.argmax(classification_result,1).eval()    # 这个语句能得到结果张量中值最大的元素下标也就是类别

需要特别注意的是,在使用模型做图像分类的时候,要保证输入的图像属性和训练模型时候一致,包括尺寸、色彩空间等。如果在训练模型的时候对图像做了预处理,在图像分类时可以用原图的拷贝做图像分类,而将原图移动到指定的类别文件夹。

以下是调用CNN模型做图像分类的一个示例代码:

from skimage import io,transform
import tensorflow as tf
import numpy as np
import os
from PIL import Image


# 类别标签
img_dict = {0:'猫',1:'狗',2:'鸟'}


# 图像预处理,和实际训练保持一致,如果训练时没有这个步骤可以略过
def img_preprocess(imgpath):
    print("Reading "+imgpath)
    
    img=Image.open(imgpath)
    img = img.convert('RGB')
    w,h=img.size
    box=[]
    
    if w>h:
        offset=(w-h)/2
        box=[offset,0,h+offset,h]
    else:
        offset=(h-w)/2
        box=[0,offset,w,w+offset]
    
    img=img.crop(box)
    img=img.resize((100,100),Image.ANTIALIAS)
    imgarr = np.asarray(img)
    img.close()

    return np.asarray(imgarr)


with tf.Session() as sess:
    # 加载网络模型
    saver = tf.train.import_meta_graph('./model/my-model-95.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./model/'))
        
    #列出文件夹下所有的目录与文件
    rootdir = './test'
    list = os.listdir(rootdir)
    
    for i in range(0,len(list)):
        path = os.path.join(rootdir,list[i])
        imgtest = [] # 存放待判断的图片数组
        imgdata = [] # 存放图片数据
        imgname = [] # 存放图片名
        
        # 存图片的数据
        img = Image.open(path)
        imgdata.append(img)
        
        # 存图片名
        imgname.append(list[i])
        
        # 存待判断图片数组
        img = img_preprocess(path)
        imgtest.append(img)
        
        # 载入默认的图
        # 其实我们需要求的就是logits,也就是需要通过run模型来得到logits
        # logits的计算依赖x的输入,因此需要初始化网咯中的x
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")
        logits = graph.get_tensor_by_name("logits_eval:0")
        
        feed_dict = {x:imgtest}
        classification_result = sess.run(logits,feed_dict)
 
        #打印出预测矩阵
        print(classification_result)
        
        #打印出预测矩阵每一行最大值的索引
        print(tf.argmax(classification_result,1).eval())
        
        #根据索引通过字典对应小动物的分类
        output = []
        output = tf.argmax(classification_result,1).eval()
        
        # 将图片存入指定文件夹
        for i in range(len(output)):
            print("该图预测:"+img_dict[output[i]]+"\n")
            savepath = "./result/" + str(output[i]) + "/" + str(imgname[i])
            imgdata[i].save(savepath)
        
        print("Done!")

请添加图片描述

  • 2
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值