tensorflow model save

一种是传统的Saver类save保存和restore恢复方法
    1. TensorFlow模型简介
        训练了一个神经网络之后,我们希望保存它以便将来使用。那么什么是TensorFlow模型?Tensorflow模型主要包含我们所培训的网络参数的网络设计或图形和值。因此,Tensorflow模型有两个主要的文件:
            a) Meta graph:
                这是一个协议缓冲区,它保存了完整的Tensorflow图形;即所有变量、操作、集合等。该文件以.meta作为扩展名。
            b) .data-00000-of-00001
                .data文件是包含我们训练变量的文件,我们待会将会使用它。
            c) checkpoint
                与此同时,Tensorflow也有一个名为checkpoint的文件,它只保存的最新保存的checkpoint文件的记录。
            d) .index
    2. 保存TensorFlow模型
        在Tensorflow中,我们希望保存所有参数的图和值,我们将创建一个tf.train.Saver()类的实例。
            saver = tf.train.Saver()
        请记住,Tensorflow变量仅在会话中存在。因此,您必须在一个会话中保存模型,调用您刚刚创建的save方法。
            saver.save(sess, 'my-test-model')
        这里,sess是会话对象,而'my-test-model'是保存的模型的名称。让我们来看一个完整的例子:

        codes:
            import tensorflow as tf
            w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
            w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
            saver = tf.train.Saver()
            sess = tf.Session()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, 'my_test_model')

        如果我们在1000次迭代之后保存模型,我们将通过通过global_step来调用save:
            saver.save(sess, 'my_test_model',global_step=1000)
        如果你希望仅保留4个最新的模型,并且希望在训练过程中每两个小时后保存一个模型,那么你可以使用max_to_keep和keep_checkpoint_every_n_hours这样做。
            saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

        #仅仅保存图结构文件
        图存储和加载write_graph/import_graph_def方法
            图存储方法:
                def write_graph(graph_or_graph_def, logdir, name, as_text=True):

                该函数存储一个tensorflow图原型到文件里,其参数含义如下:
                    graph_or_graph_def:tensorflow Graph或GraphDef;
                    logdir:保存图或图原型的目录;
                    as_text:默认为True,即以ASCII方式写到文件里
                    return:返回图或图原型保存的路径

                codes:
                    v = tf.Variable(0, name='my_variable')
                    sess = tf.Session()
                    # tf.train.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt') --> that is ok
                    tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')

            图加载方法:
                def import_graph_def(graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None):

                该函数可加载已存储的"graph_def"到当前默认图里,并从系列化的tensorflow [`GraphDef`]协议缓冲里提取所有的tf.Tensor和tf.Operation到当前图里,其参数如下:
                    graph_def:一个包含图操作OP且要导入GraphDef的默认图;
                    input_map:字典关键字映射,用以从已保存图里恢复出对应的张量值;
                    return_elements:从已保存模型恢复的Ops或Tensor对象;
                    return:从已保存模型恢复后的Ops和Tensorflow列表,其名字位于return_elements;

                codes:
                    with tf.Session() as _sess:
                      with gfile.FastGFile("/tmp/tfmodel/train.pbtxt",'rb') as f:
                        graph_def = tf.GraphDef()
                        graph_def.ParseFromString(f.read())
                        _sess.graph.as_default()
                        tf.import_graph_def(graph_def, name='tfgraph')

        MetaGraph导出和导入export_meta_graph/ import_meta_graph方法
            一个MetaGraph既包含了tensorflow GraphDef,也包含了在跨越进程边界时在图形中运行计算所需的相关元数据,它也可以用来长期存储tensorflow图结构。
            MetaGraph包含继续训练、执行评估或在先前训练的图形上运行推理所需的信息。

            MetaGraph导出方法:
                def export_meta_graph(filename=None, collection_list=None, as_text=False, export_scope=None, clear_devices=False, clear_extraneous_savers=False):
                
                该函数可以导出tensorflow元图及其所需的数据,其参数如下:
                    filename:保存路径及其文件名;
                    collection_list:要收集的字符串键的列表;
                    as_text:为True时导出的文本格式为ASCII编码;
                    export_scope:导出的名字空间,用以删除;
                    clear_devices:导出时将与设备相关的信息去掉,即导出文件不与特定设备环境关联;
                    clear_extraneous_savers:从图中删除与此导出操作无关的任何saver相关信息(保存/恢复操作和SaverDefs)。
                    return:MetaGraphDef proto;

                codes:            
                    # Build the model
                    ...
                    with tf.Session() as sess:
                      # Use the model
                      ...
                    # Export the default running graph and only a subset of the collections.
                    meta_graph_def = tf.train.export_meta_graph(
                        filename='/tmp/my-model.meta',
                        collection_list=["input_tensor", "output_tensor"])

            MetaGraph导入方法:
                def import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs):
                该函数以“MetaGraphDef”协议缓冲区作为输入,如果其参数是一个包含“MetaGraphDef”协议缓冲区的文件,它将以文件内容构造一个协议缓冲区,然后将“graph_def”字段中的所有节点添加到当前图形,并重新创建所有由collection_list收集的列表内容,最后返回由“saver_def”字段构造的saver以供使用,其参数如下:
                    meta_graph_or_file:`MetaGraphDef`协议缓冲区或者包含MetaGraphDef且带有路径的文件名;
                    clear_devices:导入时将与设备相关的信息去掉,即不与导出时的图设备环境关联,可兼容当前设备环境;
                    import_scope:导入名字空间,用以删除;
                    **kwargs:可选的参数;
                    return:在“MetaGraphDef”中由“saver_def”构造的存储模型,如果MetaGraphDef没有保存的变量则会直接返回None;

                codes:
                    ...
                    # Create a saver.
                    saver = tf.train.Saver(...variables...)
                    # Remember the training_op we want to run by adding it to a collection.
                    tf.add_to_collection('train_op', train_op)
                    sess = tf.Session()
                    for step in xrange(1000000):
                        sess.run(train_op)
                        if step % 1000 == 0:
                            # Saves checkpoint, which by default also exports a meta_graph
                            # named 'my-model-global_step.meta'.
                            saver.save(sess, 'my-model', global_step=step)
                     
                     
                    with tf.Session() as sess:
                      new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
                      new_saver.restore(sess, 'my-save-dir/my-model-10000')
                      # tf.get_collection() returns a list. In this example we only want the
                      # first one.
                      train_op = tf.get_collection('train_op')[0]
                      for step in xrange(1000000):
                        sess.run(train_op)

        补充:
            Saver的构造函数如下:
                __init__(
                    var_list=None,
                    reshape=False,
                    sharded=False,
                    max_to_keep=5,
                    keep_checkpoint_every_n_hours=10000.0,
                    name=None,
                    restore_sequentially=False,
                    saver_def=None,
                    builder=None,
                    defer_build=False,
                    allow_empty=False,
                    write_version=tf.train.SaverDef.V2,
                    pad_step_number=False,
                    save_relative_paths=False,
                    filename=None
                )

                保存模型时:
                    var_list:特殊需要保存和恢复的变量和可保存对象列表或字典,默认为空,将会保存所有的可保存对象;
                    max_to_keep:保存多少个最新的checkpoint文件,默认为5,即保存最近五个checkpoint文件;
                    keep_checkpoint_every_n_hours:多久保存checkpoint文件,默认为10000小时,相当于禁用了这个功能;
                    save_relative_paths:为True时,checkpoint文件将不会记录完整的模型路径,而只会仅仅记录模型名字,这方便于将保存下来的模型复制到其他目录并使用的情况;

                恢复模型时:
                    reshape:为True时,允许从已保存checkpoint文件里恢复并重新设定形状不一样的张量,默认为false;
                    sharded:碎片化checkpoint文件到每一个设备,默认false;
                    restore_sequentially:为True时,会在每个设备中顺序地恢复不同的变量,同时可以在恢复比较大的模型时节省内存

            save接口如下:
                save(
                    sess,
                    save_path,
                    global_step=None,
                    latest_filename=None,
                    meta_graph_suffix='meta',
                    write_meta_graph=True,
                    write_state=True
                )

                其参数说明如下:
                    sess:一个建好图的会话,用以运行保存操作;
                    save_path:包含模型名字的绝对路径,最终会自动在模型名字添加相应后缀
                    global_step:该参数会自动添加到save_path名字用以区别不同步骤保存的模型;
                    latest_filename:生成检查点文件的名字,默认是“checkpoint”;
                    meta_graph_suffix:MetaGraphDef元图后缀,默认为“meta”;
                    write_meta_graph:指明是否要保存元图数据,默认为True;
                    write_state:指明是否要写CheckpointStateProto,默认为True


    3.导入训练好的模型
        a)创建网络
            你可以通过编写python代码创建网络,以手工创建每一层,并将其作为原始模型。但是,如果你考虑一下,我们已经在.meta文件中保存了这个网络,我们可以使用tf.train.import()函数来重新创建这个网络:
                saver = tf.train.import_meta_graph('my_test_model-1000.meta')
        b)载入参数
            我们可以通过调用这个保护程序的实例来恢复网络的参数,它是tf.train.Saver()类的一个实例。
                with tf.Session() as sess:
                    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
                    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
                    print(sess.run('w1:0'))


        获取最近保存的所有模型
            last_ckpt = saver.last_checkpoints
            或者
            ckpt = tf.train.get_checkpoint_state("/home/xsr-ai/study/mnist/mnist-model")
        我们要恢复哪一个模型,可以使用如下任一种类似方法:
            saver.restore(last_ckpt[-1])
            saver.restore(last_ckpt[0])
            saver.restore(ckpt.model_checkpoint_path)
            saver.restore(ckpt.all_model_checkpoint_paths[-1])
        使用restore恢复已保存模型
            saver.restore(sess, save_path)

            sess:用以恢复参数模型的会话;
            save_path:已保存模型的路径,通常包含模型名字;


    4.使用导入的模型
        现在你已经了解了如何保存和恢复Tensorflow模型,让我们开发一个实用的例子来恢复任何预先训练的模型,并使用它进行预测、微调或进一步训练。当您使用Tensorflow时,你将定义一个图,该图是feed examples(训练数据)和一些超参数(如学习速率、迭代次数等),它是一个标准的过程,我们可以使用占位符来存放所有的训练数据和超参数。接下来,让我们使用占位符构建一个小网络并保存它。注意,当网络被保存时,占位符的值不会被保存。
        现在,当我们想要恢复它时,我们不仅要恢复图和权重,还要准备一个新的feed_dict,它将把新的训练数据输入到网络中。我们可以通过graph.get_tensor_by_name()方法来引用这些保存的操作和占位符变量。

        codes:
            import tensorflow as tf
     
            sess=tf.Session()    
            #First let's load meta graph and restore weights
            saver = tf.train.import_meta_graph('my_test_model-1000.meta')
            saver.restore(sess,tf.train.latest_checkpoint('./'))
             
             
            # Now, let's access and create placeholders variables and
            # create feed-dict to feed new data
             
            graph = tf.get_default_graph()
            w1 = graph.get_tensor_by_name("w1:0")
            w2 = graph.get_tensor_by_name("w2:0")
            feed_dict ={w1:13.0,w2:17.0}
             
            #Now, access the op that you want to run. 
            op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
             
            print sess.run(op_to_restore,feed_dict)
            #This will print 60 which is calculated 
            #using new values of w1 and w2 and saved value of b1.

        但是,你是否可以在之前图的结构上构建新的网络?当然,您可以通过graph.get_tensor_by_name()方法访问适当的操作,并在此基础上构建图。这是一个真实的例子。在这里,我们使用元图加载一个vgg预训练的网络,并在最后一层中将输出的数量更改为2,以便对新数据进行微调。
            saver = tf.train.import_meta_graph('vgg.meta')
            # Access the graph
            graph = tf.get_default_graph()
            ## Prepare the feed_dict for feeding data for fine-tuning 
             
            #Access the appropriate output for fine-tuning
            fc7= graph.get_tensor_by_name('fc7:0')
             
            #use this if you only want to change gradients of the last layer
            fc7 = tf.stop_gradient(fc7) # It's an identity function
            fc7_shape= fc7.get_shape().as_list()
             
            new_outputs=2
            weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
            biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
            output = tf.matmul(fc7, weights) + biases
            pred = tf.nn.softmax(output)


    示例:
        MNIST代码

            """A very simple MNIST classifier.
            See extensive documentation at
            https://www.tensorflow.org/get_started/mnist/beginners
            """
            from __future__ import absolute_import
            from __future__ import division
            from __future__ import print_function
             
            import argparse
            import sys
             
            from tensorflow.examples.tutorials.mnist import input_data
             
            import tensorflow as tf
             
            FLAGS = None
             
             
            def main(_):
              # Import data
              mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
             
              # Create the model
              x = tf.placeholder(tf.float32, [None, 784])
              W = tf.Variable(tf.zeros([784, 10]))
              b = tf.Variable(tf.zeros([10]))
              y = tf.matmul(x, W) + b
             
              # Define loss and optimizer
              y_ = tf.placeholder(tf.float32, [None, 10])
             
              # The raw formulation of cross-entropy,
              #
              #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
              #                                 reduction_indices=[1]))
              #
              # can be numerically unstable.
              #
              # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
              # outputs of 'y', and then average across the batch.
              cross_entropy = tf.reduce_mean(
                  tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
              train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
             
              sess = tf.InteractiveSession()
              tf.global_variables_initializer().run()
              # Train
              saver = tf.train.Saver()
              for index in range(1000):
                batch_xs, batch_ys = mnist.train.next_batch(100)
                sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
                if index % 100 == 0:
                  print("index: %d" % index)
                  path = saver.save(sess, "/home/xsr-ai/study/mnist/mnist-model/model.ckpt", global_step=index) # , latest_filename="hello"
             
              # Test trained model
              correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
              accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
              print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                  y_: mnist.test.labels}))
             
              ckpt = tf.train.get_checkpoint_state("/home/xsr-ai/study/mnist/mnist-model")
              saver.restore(sess, ckpt.all_model_checkpoint_paths[0])
              print(ckpt)
              print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                  y_: mnist.test.labels}))
             
            if __name__ == '__main__':
              parser = argparse.ArgumentParser()
              parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
                                  help='Directory for storing input data')
              FLAGS, unparsed = parser.parse_known_args()
              tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

一种是比较新颖的SavedModelBuilder类的builder保存和loader文件里的load恢复方法

        builder/loader方法也是可以保存和恢复tensorflow模型的,只是他们源代码是在不同文件里,builder其源代码在tensorflow/python/saved_model/builder_impl.py,
        而loader的源代码则位于tensorflow/python/saved_model/loader_impl.py。
        相较于save和restore方法会生成比较多的模型文件,builder和loader方法则会更简单一些,同时也是saver提供的更高级别的系列化,它也更适合于商业化,按照创作者的说法“它显然是未来!”

        使用builder方法保存模型:
            我们主要使用SavedModelBuilder类来新建一个builder,SavedModelBuilder的参数很简单,就一个export_dir参数即要保存模型的路径,但要确保所保存的目录是未有建立的,否则会导致出错!

            获取builder方法如下:
                builder = tf.saved_model.builder.SavedModelBuilder("/home/xsr-ai/study/mnist/saved-model")
            在训练完后,我们调用如下命令保存模型:
                builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None)
                builder.save()

                add_meta_graph_and_variables的介绍如下:
                    def add_meta_graph_and_variables(sess,tags,signature_def_map=None,assets_collection=None,legacy_init_op=None,clear_devices=False,main_op=None)
                    该函数可以将当前元图添加到SavedModel并保存变量,其参数如下:
                        sess:用于执行添加元图和变量功能的会话;
                        tags:用于保存元图的标签;
                        signature_def_map:用于保存元图的签名;
                        assets_collection:使用SavedModel保存的资源集合;
                        legacy_init_op:在恢复模型操作后,对Op和Ops组的遗留支持;
                        clear_devices:如果默认图形上的设备信息应该被清除,则应该设置为true;
                        main_op:在加载图时执行Op或Ops组的操作。请注意,当main_op被指定时,它将在加载恢复op后运行;
                        return:无返回

                save()的介绍:
                    def save(as_text=False):
                    该函数将“SavedModel”协议缓冲区的数据写入到硬盘里,其参数只有一个as_text,主要用于指明是否按照ASCII编码格式写入到文件里,其返回的是保存模型的路径。

        使用loader方法恢复模型:
            def load(sess, tags, export_dir, **saver_kwargs):
            该函数可以从标签指定的SavedModel加载模型,其参数如下:
                sess:恢复模型的会话;
                tags:用于恢复元图的标签,需与保存时的一致,用于区别不同的模型;
                export_dir:存储SavedModel协议缓冲区和要加载的变量的目录;
                **saver_kwargs:可选的关键字参数传递给saver;
                return:在提供的会话中加载的“MetaGraphDef”协议缓冲区,这可以用于进一步提取signature-defs, collection-defs等;

        load通常使用方法如下:
            with tf.Session() as sess:  
                tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], "/home/xsr-ai/study/mnist/saved-model")

        那么又如何恢复由builder保存的模型呢?我使用如下例子来说明如何使用loader来恢复模型,代码比较简洁,主要是测试恢复模型后,可否正常获取到特定的变量权值:
            import tensorflow as tf
            with tf.Session() as sess:
              tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], "/home/xsr-ai/study/mnist/saved-model")
              var = sess.run('layer2/biases/Variable:0')
              print(var)

        示例:
            MNIST示例:
                """
                A simple MNIST classifier which displays summaries in TensorBoard.
                This is an unimpressive MNIST model, but it is a good example of using
                tf.name_scope to make a graph legible in the TensorBoard graph explorer, and of
                naming summary tags so that they are grouped meaningfully in TensorBoard.
                It demonstrates the functionality of every TensorBoard dashboard.
                """
                from __future__ import absolute_import
                from __future__ import division
                from __future__ import print_function
                 
                import argparse
                import os
                import sys
                 
                import tensorflow as tf
                 
                from tensorflow.examples.tutorials.mnist import input_data
                 
                FLAGS = None
                 
                 
                def train():
                  # Import data
                  mnist = input_data.read_data_sets(FLAGS.data_dir,
                                                    one_hot=True,
                                                    fake_data=FLAGS.fake_data)
                 
                  sess = tf.InteractiveSession()
                  # Create a multilayer model.
                 
                  # Input placeholders
                  with tf.name_scope('input'):
                    x = tf.placeholder(tf.float32, [None, 784], name='x-input')
                    y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
                 
                  with tf.name_scope('input_reshape'):
                    image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
                    tf.summary.image('input', image_shaped_input, 10)
                 
                  # We can't initialize these variables to 0 - the network will get stuck.
                  def weight_variable(shape):
                    """Create a weight variable with appropriate initialization."""
                    initial = tf.truncated_normal(shape, stddev=0.1)
                    return tf.Variable(initial)
                 
                  def bias_variable(shape):
                    """Create a bias variable with appropriate initialization."""
                    initial = tf.constant(0.1, shape=shape)
                    return tf.Variable(initial)
                 
                  def variable_summaries(var):
                    """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
                    with tf.name_scope('summaries'):
                      mean = tf.reduce_mean(var)
                      tf.summary.scalar('mean', mean)
                      with tf.name_scope('stddev'):
                        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
                      tf.summary.scalar('stddev', stddev)
                      tf.summary.scalar('max', tf.reduce_max(var))
                      tf.summary.scalar('min', tf.reduce_min(var))
                      tf.summary.histogram('histogram', var)
                 
                  def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
                    """Reusable code for making a simple neural net layer.
                    It does a matrix multiply, bias add, and then uses ReLU to nonlinearize.
                    It also sets up name scoping so that the resultant graph is easy to read,
                    and adds a number of summary ops.
                    """
                    # Adding a name scope ensures logical grouping of the layers in the graph.
                    with tf.name_scope(layer_name):
                      # This Variable will hold the state of the weights for the layer
                      with tf.name_scope('weights'):
                        weights = weight_variable([input_dim, output_dim])
                        variable_summaries(weights)
                      with tf.name_scope('biases'):
                        biases = bias_variable([output_dim])
                        variable_summaries(biases)
                      with tf.name_scope('Wx_plus_b'):
                        preactivate = tf.matmul(input_tensor, weights) + biases
                        tf.summary.histogram('pre_activations', preactivate)
                      activations = act(preactivate, name='activation')
                      tf.summary.histogram('activations', activations)
                      return activations
                 
                  hidden1 = nn_layer(x, 784, 500, 'layer1')
                 
                  with tf.name_scope('dropout'):
                    keep_prob = tf.placeholder(tf.float32)
                    tf.summary.scalar('dropout_keep_probability', keep_prob)
                    dropped = tf.nn.dropout(hidden1, keep_prob)
                 
                  # Do not apply softmax activation yet, see below.
                  y = nn_layer(dropped, 500, 10, 'layer2', act=tf.identity)
                 
                  with tf.name_scope('cross_entropy'):
                    # The raw formulation of cross-entropy,
                    #
                    # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.softmax(y)),
                    #                               reduction_indices=[1]))
                    #
                    # can be numerically unstable.
                    #
                    # So here we use tf.nn.softmax_cross_entropy_with_logits on the
                    # raw outputs of the nn_layer above, and then average across
                    # the batch.
                    diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
                    with tf.name_scope('total'):
                      cross_entropy = tf.reduce_mean(diff)
                  tf.summary.scalar('cross_entropy', cross_entropy)
                 
                  with tf.name_scope('train'):
                    train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
                        cross_entropy)
                 
                  with tf.name_scope('accuracy'):
                    with tf.name_scope('correct_prediction'):
                      correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
                    with tf.name_scope('accuracy'):
                      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
                  tf.summary.scalar('accuracy', accuracy)
                 
                  # Merge all the summaries and write them out to
                  # /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
                  merged = tf.summary.merge_all()
                  train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
                  test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
                  tf.global_variables_initializer().run()
                 
                  # Train the model, and also write summaries.
                  # Every 10th step, measure test-set accuracy, and write test summaries
                  # All other steps, run train_step on training data, & add training summaries
                 
                  def feed_dict(train):
                    """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
                    if train or FLAGS.fake_data:
                      xs, ys = mnist.train.next_batch(100, fake_data=FLAGS.fake_data)
                      k = FLAGS.dropout
                    else:
                      xs, ys = mnist.test.images, mnist.test.labels
                      k = 1.0
                    return {x: xs, y_: ys, keep_prob: k}
                 
                  for i in range(FLAGS.max_steps):
                    if i % 100 == 0:  # Record summaries and test-set accuracy
                      summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
                      test_writer.add_summary(summary, i)
                      print('Accuracy at step %s: %s' % (i, acc))
                    else:  # Record train set summaries, and train
                      if i % 100 == 99:  # Record execution stats
                        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()
                        summary, _ = sess.run([merged, train_step],
                                              feed_dict=feed_dict(True),
                                              options=run_options,
                                              run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
                        train_writer.add_summary(summary, i)
                        print('Adding run metadata for', i)
                      else:  # Record a summary
                        summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
                        train_writer.add_summary(summary, i)
                 
                  builder = tf.saved_model.builder.SavedModelBuilder("/home/xsr-ai/study/mnist/saved-model")
                  builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING])
                  builder.save()  
                 
                  train_writer.close()
                  test_writer.close()
                 
                 
                def main(_):
                  if tf.gfile.Exists(FLAGS.log_dir):
                    tf.gfile.DeleteRecursively(FLAGS.log_dir)
                  tf.gfile.MakeDirs(FLAGS.log_dir)
                  train()
                 
                 
                if __name__ == '__main__':
                  parser = argparse.ArgumentParser()
                  parser.add_argument('--fake_data', nargs='?', const=True, type=bool,
                                      default=False,
                                      help='If true, uses fake data for unit testing.')
                  parser.add_argument('--max_steps', type=int, default=1000,
                                      help='Number of steps to run trainer.')
                  parser.add_argument('--learning_rate', type=float, default=0.001,
                                      help='Initial learning rate')
                  parser.add_argument('--dropout', type=float, default=0.9,
                                      help='Keep probability for training dropout.')
                  parser.add_argument(
                      '--data_dir',
                      type=str,
                      default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                                           'tensorflow/mnist/input_data'),
                      help='Directory for storing input data')
                  parser.add_argument(
                      '--log_dir',
                      type=str,
                      default="/home/xsr-ai/study/mnist/logdir",
                      help='Summaries log directory')
                  FLAGS, unparsed = parser.parse_known_args()
                  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


    一个pb文件,以及一个variables文件夹,里面存放的是variables.data-00000-of-00001和
    variables.index,与save/restore方法比,没有checkpoint检查点文件以及以“.meta”为后缀的元数据文件,但是多了一个pb文件,这是这两种tensorflow保存和恢复模型方法的区别!

参考链接:
https://blog.csdn.net/tan_handsome/article/details/79303269
https://blog.csdn.net/fly_time2012/article/details/82889418

  • 0
    点赞
  • 3
    收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:大白 设计师:CSDN官方博客 返回首页
评论
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值