如何使用训练好的Tensorflow模型的案例及分析

1. 什么是Tensorflow模型?

我们知道Tensorflow是由张量和计算模型组成,Tensorflow中的所有计算都会被转换为计算图上的节点,TensorFlow内部会将运算过程表示为一个数据流图。当你训练好一个神经网络后,同时系统将模型结果保存下来。因此,什么是Tensorflow模型?Tensorflow模型主要包含网络设计(或者网络图)和训练好的网络参数的值。所以Tensorflow模型有两个主要的文件:
在这里插入图片描述

1.1. Meta图:

Meta图是一个协议缓冲区(protocol buffer),它保存了完整的Tensorflow图;比如所有的变量、运算、集合等。这个文件的扩展名是.meta。

1.2. Checkpoint 文件

这是一个二进制文件,它保存了权重、偏置项、梯度以及其他所有的变量的取值,扩展名为.ckpt。但是, 从0.11版本开始,Tensorflow对改文件做了点修改,checkpoint文件不再是单个.ckpt文件,而是如下两个文件:
在这里插入图片描述

1.3. 关于saver

Saver的作用是将我们训练好的模型的参数保存下来,以便下一次继续用于训练或测试。

saver.save(sess, 'save/water_model.ckpt')

Saver类训练完后,是以checkpoints文件形式保存。一般地,Saver会自动的管理Checkpoints文件。我们可以指定保存最近的N个Checkpoints文件,当然每一步都保存ckpt文件也是可以的。

  • saver()可以选择global_step参数来为ckpt文件名添加数字标记:
  • max_to_keep参数定义saver()将自动保存的最近n个ckpt文件,默认n=5,即保存最近的5个检查点ckpt文件。若n=0或者None,则保存所有的ckpt文件。
  • keep_checkpoint_every_n_hours与max_to_keep类似,定义每n小时保存一个ckpt文件。
...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

2. 训练模型及会话介绍

2.1. 模型网络结构

模型中的输入/输出变量名称,需要在训练模型时定义到网络结构中,以便重新加载(计算图)网络结构图时可以识别,如下代码中定义的“x”、“keep_prob”、“y_conv”,通过占位符中定义出变量。

x = tf.placeholder("float", shape=[None, 20 * 20 * 1 ],name='x') #图片像素20*20
......
keep_prob = tf.placeholder("float",name='keep_prob')
......
y_conv=tf.nn.softmax(tf.matmul(h_fc2_drop, W_fc3) + b_fc3,name='y_conv')  # softmax层,计算输出标签

2.2. Tensorflow运行模型——会话

Tensorflow中使用会话一般有两种:
第一种模式需要明确调用会话生产函数和关闭会话。

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

for i in range(80):
    ......
    training.run(feed_dict={x: batch_x, y: batch_y, keep_drop: 0.5})

sess.close()

第二种模式是Tensorflow可以通过Python的上下文管理器来使用会话,如下“with tf.Session() as sess”,只要把所有计算放在“with”的内部就可以。当上下文管理器退出时候,系统将会自动释放所有资源。

本模式案例代码,详见后续内容。

3. 加载模型的两种方法

3.1. 定义模型网络结构

3.1.1. 定义网络结构

定义模型网络结构就是使用原训练模型的全套网络结构代码,以及其中使用到变量,也就是明示网络结构。

3.1.2. 恢复参数

使用saver.restore()函数恢复模型参数,与保存模型的文件路径名称保持一致,特别注意,文件名称不是全名称,不带文件名称最后的后缀,例如:
saver.restore(sess, “save/water_model.ckpt”) #使用模型,参数和之前的代码保持一致

#占位符,通过为输入图像和目标输出类别创建节点,来开始构建计算图。
x = tf.placeholder("float", shape=[None, 20 * 20 * 1 ],name='x') #图片像素20*20
y_ = tf.placeholder("float", shape=[None, 10],name='y_')          #输出为10个数字
......
x_test = input_data('imgs_lib/img2_0.png')  # 按目录取图片

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("开始测试!")
    saver.restore(sess, "save/water_model.ckpt") #使用模型,参数和之前的代码保持一致
    
    ret=sess.run(y_conv,feed_dict={x:x_test,keep_prob: 1.0})
    predint=sess.run(tf.argmax(ret,1))
    y_pre = predint

例如加载模型时遇到类似如下错误:“Assign requires shapes of both tensors to match. lhs shape= [3,3,1,16] rhs shape= []…”,原因是训练时所保存的模型与测试加载模型的网络结构不一致造成的。产生的原因是模型使用者和训练者可能是不同的人,或者,不同机器设备等等。
在这里插入图片描述
出现报错的原因是“移动了不用变量的位置”,不用的变量代码不仅要保留,位置也不能变。这个问题困扰了我好几天,如果要训练一次模型,将会耗费很长时间!

3.2. 加载模型网络结构

从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

3.2.1. 导入创建网络结构图

使用tf.train.import()函数加载以前保存的网络。
saver = tf.train.import_meta_graph(‘save/water_model.ckpt.meta’)
注意,import_meta_graph将保存在.meta文件中的图添加到当前的图中。所以,创建了一个图/网络,但是我们使用需要加载训练的参数到这个图中。

3.2.2. 导入加载参数

调用由tf.train.Saver()创建的对象saver中的restore方法来恢复网络中的参数。
saver.restore(sess,tf.train.latest_checkpoint(“save/”))

3.2.3. 网络结构中输入/输出参数

通过graph,获取训练模型时定义变量标识名称。详见3.1章节中定义。

        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name('x:0')
        keep_prob = graph.get_tensor_by_name('keep_prob:0')
        y_conv = graph.get_tensor_by_name('y_conv:0')

3.2.3. 示例代码

所附代码是绕开第一种方法而实现的,没有了网上常见的网络结构代码,也看不到网络模型。

'''
Created on 2019年5月23日
@author: xiaoyw
'''
import tensorflow as tf
import cv2
import numpy as np
import datetime
import os

#输入图片文件名,图片为20*20
def input_data(file_name):
    img = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
    x_data = []
    x_data.append(img.flatten())            
    x_data = np.array(x_data)
    x_data = x_data.astype("float")
    x_data = np.multiply(x_data, 1.0 / 255.0)
           
    return x_data
def modele_test():
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.import_meta_graph('save/water_model.ckpt.meta')
        print("模型初始化!")
        saver.restore(sess,tf.train.latest_checkpoint("save/"))
        
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name('x:0')
        keep_prob = graph.get_tensor_by_name('keep_prob:0')
        y_conv = graph.get_tensor_by_name('y_conv:0')

        file_list = readImgFileName('imgs_lib')   #测试文件路径为imgs_lib
        for file_row in file_list:
            x_test = input_data(file_row[1])
   
            feed_dict = {x:x_test,keep_prob:1.0}            
            ret = sess.run(y_conv,feed_dict)

            predint=sess.run(tf.argmax(ret,1))   # 返回预测结果
            preddigit = predint[0]        # 预测数字
            digit = file_row[0]           # 真实数字
            y_pre = ret[0]    #预测返回概率            
            
            print('输入标签: ' + str(digit) + ', 预测结果: ' + str(preddigit) )
            if int(digit)!=int(preddigit):
                print ('参考概率:' + str(y_pre))
                print('图片文件名称:' + file_row[1])

#读取源文件列表,并拆分为出数据标签
def readImgFileName(path):
    list = [file[2] for file in os.walk(path)]
    file_names = [file_name.strip('.png') for file_name in list[0]]
    file_list = []
    for f0 in file_names:
        kk=f0.split('_')
        tmp = []
        tmp.append(kk[len(kk)-1])
        tmp.append(path + '/' + f0 + '.png')
        file_list.append(tmp)
    
    return file_list

if __name__ == '__main__':
    nowTime=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')#现在
    print('start :{}'.format(nowTime))
    modele_test()

但是需要注意的是,使用feed_dict设置tensor的时候,需要你给出的值类型与占位符定义的类型相同。

参考:
《一份快速完整的Tensorflow模型保存和恢复教程(译)》 CSDN博客 Meringue_zz 2018年1月
《TensorFlow CNN卷积神经网络实现工况图分类识别(一)》 CSDN博客 肖永威 2019年3月
《使用Python开发工具Jupyter Notebook学习Tensorflow入门及Tensorboard实践》 CSDN博客 肖永威 2019年1月
《TensorFlow学习笔记:Saver与Restore》 简书 DexterLei 2017年10月
《TensorFlow模型保存/载入的两种方法》 CSDN博客 thriving_fcl 2017年5月

  • 6
    点赞
  • 57
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

肖永威

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值