Tensorflow1.x加载模型的方法

代码地址:查看完整代码

一个错误的使用

之前有同学问过我这个问题,TF加载模型,跑出来的结果不对,代码见incurrect_usage.py,正确率和猜的一样,怀疑是模型加载那里出问题了。

#****************** incurrent usage.py*********************
x = tf.placeholder(tf.float32, [None, 784], name="input")
y_ = tf.placeholder(tf.float32, [None, 10], name="label")
pred = forward(x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.import_meta_graph('./model/mnist_model-4000.meta')
    saver.restore(sess, './model/mnist_model-4000')
    correct_prediction = tf.equal(tf.argmax(pred,1), tf.argmax(test_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    acc = sess.run(accuracy, feed_dict={x: test_image, y_: test_label})

跑出来的正确率都在0.1左右,训练正确率都在0.9以上,再差也不会这样,所以加载模型哪里出错了。

Tensorflow加载模型的方法

本例使用tf.train.Saver()保存模型的方法,执行saver.save(sess, model_name)后,会得到3个名为model_name的文件,.data-00000-of-00001中保存了网络训练的参数,.meta保存了网络的图结构。

Tensorflow在加载模型的时候就需要上述的两个东西,网络参数和图结构,而加载图有两种方式,重新搭建网络直接用.meta文件。

重新搭建网络

顾名思义,在测试代码中重新把训练时forward的流程再搭一遍,这样就能得到由训练好的参数得到forward的结果。

#****************** test with network.py*********************
x = tf.placeholder(tf.float32, [None, 784], name="input")
y_ = tf.placeholder(tf.float32, [None, 10], name="label")
pred = forward(x)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, './model/mnist_model-4000')
    correct_prediction = tf.equal(tf.argmax(pred,1), tf.argmax(test_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    acc = sess.run(accuracy, feed_dict={x: test_image, y_: test_label})

因为forward流程和训练时一样,所以直接在训练代码里拿来用,已经重新搭建了图,就不要加载.meta文件了,所以直接restore参数文件就可以了。

拿测试集中前5000个样本做测试,测试结果:

test with network: 
INFO:tensorflow:Restoring parameters from ./model/mnist_model-4000
accuracy is:  0.9784

网络结构:

在这里插入图片描述

使用.meta文件构建图

使用.meta文件需要注意,在训练时最好为输入和输出取一个名字,因为需要直接从.meta保存的图结构中取输入和输出,有名字的时候会更明确一些。

像这样:

x = tf.placeholder(tf.float32, [None, 784], name="input")
y_ = tf.placeholder(tf.float32, [None, 10], name="label")

加载.meta代码如下:

#****************** test with meta.py*********************
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./model/mnist_model-4000.meta')
    saver.restore(sess, tf.train.latest_checkpoint("./model/"))
    
    graph = tf.get_default_graph()
    input_x = graph.get_operation_by_name("input").outputs[0]
    feed_dict = {"input:0":test_image, "label:0":test_label}
    pred = graph.get_tensor_by_name("output:0")
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(test_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    acc = sess.run(accuracy, feed_dict=feed_dict)

使用.meta文件,直接根据名字找到对应的输出和输出,获取默认图结构,不需要重新初始化参数。

拿测试集中前5000个样本做测试,测试结果:

test with .meta:
INFO:tensorflow:Restoring parameters from ./model/mnist_model-4000
accuracy is:  0.9784

测试结果和重新构建网络是一样的。

网络结构:

在这里插入图片描述

使用.meta测试时,网络输出那里出现了两个分支,猜测是.meta保存了训练时测试accuracy那部分图,我在测试的代码里又写了一个测试accuracy的部分,所以两部分都被保存了,但不影响测试的结果。

错误的原因

很容易猜到,图加载了两次,已经重建网络了,然后又加载了.meta,导致图的结构乱了,看图:

在这里插入图片描述

网络的结构已经变了,所以加载训练好的模型时,要么重建图,要么加载.meta,混合起来就容易出错。

TODO:使用滑动平均如何加载模型

参考

Mnist网络backbone:点击前往

TF加载模型方法: 点击前往

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,针对您的问题,我可以给您提供一个基于 TensorFlow 2.x 的食物图像识别模型的简单实现思路和步骤。 首先,我们需要准备一组带有标注的食品图像数据集,可以使用现有的公开数据集,例如 Food-101 数据集,该数据集包括 101 类不同类型的食品图像,并且已经经过标注。 接下来,我们可以使用 TensorFlow 2.x 中的 Keras 库来构建食物图像识别模型,具体的步骤如下: 1. 加载数据集并进行预处理。我们可以使用 TensorFlow 中的 ImageDataGenerator 类来进行数据增强和预处理操作,例如对图像进行旋转、缩放、翻转等操作,以增加数据集的多样性和鲁棒性。 2. 定义模型结构。我们可以使用 Keras 中提供的各种卷积神经网络(CNN)模型,例如 VGG、ResNet、Inception 等,也可以自己设计模型结构。根据数据集的大小和复杂度,我们可以选择使用较浅的网络结构,以避免过拟合,或者使用较深的网络结构,以提高模型的准确率。 3. 编译模型。在编译模型时,我们需要指定损失函数、优化器和评估指标。对于分类任务,我们可以选择使用交叉熵损失函数,常用的优化器有 Adam、SGD、RMSprop 等,评估指标可以选择准确率、精确率、召回率等。 4. 训练模型。在训练模型时,我们可以根据数据集的大小和计算资源的限制,选择适当的批量大小和训练轮数。可以使用 Keras 中的 fit() 方法来进行模型训练,同时可以指定验证集和回调函数,以监控模型的训练过程和性能。 5. 测试模型。在测试模型时,我们可以使用测试集来评估模型的准确率和其他指标。可以使用 Keras 中的 evaluate() 方法来进行模型测试,并输出测试结果。 以上就是一个基于 TensorFlow 2.x 的食物图像识别模型的简单实现思路和步骤。当然,实际应用中还需要根据具体的场景和需求进行进一步的优化和调整。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值