tensorflow模型保存与加载

tensorflow版本为1.4.1

tensorflow提供了Saver类用于模型的保存与导入。该类定义在tensorflow/python/training/saver.py.中。

Saver类的默认初始化函数如下:

__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

由于该初始化函数均有缺省值,因此我们常用的创建一个Saver对象的操作为tf.train.Saver()

下面解释一下常用的参数:

  • var_list: Variable/SaveableObject的列表,或者是一个字典(mapping names to SaveableObjects)。默认为None,即保存所有可保存的对象。
  • reshape: 当为True时,表示从一个checkpoint中恢复参数时允许参数shape发生变化。当我们reshape了一个变量又希望加载旧模型时,该操作就很有用。
  • max_to_keep:为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。该参数用于指定保存最近的N个Checkpoints文件,默认为5.
  • keep_checkpoint_every_n_hours: 为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。该参数用于指定保留Checkpoints文件的时间,默认为10000小时

保存模型需要用到save函数:

save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)

下面介绍其中的参数:

  • sess: 保存模型要求必须有一个加载了计算图的会话,而且所有变量必须已被初始化。
  • save_path: 模型保存路径及保存名称。
  • global_step: 如果提供的话,这个数字会添加到save_path后面,用于区分不同训练阶段的结果。

加载模型的函数为:

restore(
    sess,
    save_path
)
  • sess: 加载模型要求必须有一个加载了计算图的会话,但是不要求变量初始化。
  • save_path: 模型保存路径

下面给一个模型保存的例子,例子还是使用 “单层感知机实现mnist数字分类”

# -*- coding: utf-8 -*-
import tensorflow as tf
from input_data import read_data_sets
import os

# don't show INFO 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
# read mnist
mnist = read_data_sets('MNIST_data', one_hot=True)
# single layer perceptron: y = wx + b
# input
x = tf.placeholder(tf.float32, [None, 784])
# weights
W = tf.Variable(tf.random_normal([784,10], stddev=0.1))
# bias
b = tf.Variable(tf.zeros([10]))
# softmax 
y = tf.nn.softmax(tf.matmul(x,W) + b)
# output
y_ = tf.placeholder(tf.float32, [None, 10])
# cross_entropy loss
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
# optimization with gradient descend, the learning rate is set as 0.02
train_step = tf.train.GradientDescentOptimizer(0.02).minimize(cross_entropy)
# initalize all variables
init = tf.global_variables_initializer()
# start a new session
sess = tf.Session()
sess.run(init)
m_saver = tf.train.Saver()

# 2000 iterations
for i in range(2000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if i % 100 == 0:
        m_saver.save(sess, './model/mnist_slp', global_step=i)

# computer the accuracy
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

# close the session
sess.close()

最后可以在model文件下看到保存的文件:

这里写图片描述

说明:

  • 尽管每100次保存一次模型,但实际上只会保存最近的5次。
  • 每次保存都会产生3个文件,其中mnist_slp-1900.data-00000-of-00001存放的是模型参数,mnist_slp-1900.meta中存放的是计算图。

当我们加载模型时,如下:

# -*- coding: utf-8 -*-
import tensorflow as tf
from input_data import read_data_sets
import os

# don't show INFO 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

# read mnist
mnist = read_data_sets('MNIST_data', one_hot=True)


# single layer perceptron: y = wx + b
# input
x = tf.placeholder(tf.float32, [None, 784])

# weights
W = tf.Variable(tf.random_normal([784,10], stddev=0.1))

# bias
b = tf.Variable(tf.zeros([10]))

# softmax 
y = tf.nn.softmax(tf.matmul(x,W) + b)

# output
y_ = tf.placeholder(tf.float32, [None, 10])

# cross_entropy loss
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

# optimization with gradient descend, the learning rate is set as 0.02
train_step = tf.train.GradientDescentOptimizer(0.02).minimize(cross_entropy)


# start a new session
sess = tf.Session()

m_saver = tf.train.Saver()

# load the model
m_saver.restore(sess, './model/mnist_slp-1900')

# computer the accuracy
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

# close the session
sess.close()
阅读更多
版权声明:本文为博主原创文章,转载请注明出处 https://blog.csdn.net/shuzfan/article/details/79197432
文章标签: tensorflow
个人分类: TensorFlow
上一篇tensorflow创建变量以及根据名称查找变量
下一篇latex对修改内容进行高亮
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭