TensorFlow之模型的保存、恢复以及自定义命令行参数

在我们训练或者测试过程中,总会遇到需要保存训练完成的模型,然后从中恢复继续我们的测试或者其它使用。模型的保存和恢复也是通过tf.train.Saver类去实现,它主要通过将Saver类添加OPS保存和恢复变量到checkpoint。它还提供了运行这些操作的便利方法。

tf.train.Saver(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=tf.SaverDef.V2, pad_step_number=False)

  • var_list:指定将要保存和还原的变量。它可以作为一个dict或一个列表传递.
  • max_to_keep:指示要保留的最近检查点文件的最大数量。创建新文件时,会删除较旧的文件。如果无或0,则保留所有检查点文件。默认为5(即保留最新的5个检查点文件。)
  • keep_checkpoint_every_n_hours:多久生成一个新的检查点文件。默认为10,000小时

模型保存 

 保存我们的模型需要调用Saver.save()方法。save(sess, save_path, global_step=None),checkpoint是专有格式的二进制文件,将变量名称映射到张量值。

import tensorflow as tf

a = tf.Variable([[1.0,2.0]],name="a")
b = tf.Variable([[3.0],[4.0]],name="b")
c = tf.matmul(a,b)

saver=tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(c))
    saver.save(sess, '/tmp/ckpt/test/matmul')

在多次训练的时候可以指定多少间隔生成检查点文件

saver.save(sess, '/tmp/ckpt/test/matmu', global_step=0) ==> filename: 'matmu-0'

saver.save(sess, '/tmp/ckpt/test/matmu', global_step=1000) ==> filename: 'matmu-1000'

恢复模型

 恢复模型的方法是restore(sess, save_path),save_path是以前保存参数的路径,我们可以使用tf.train.latest_checkpoint来获取最近的检查点文件。

import tensorflow as tf

a = tf.Variable([[1.0,2.0]],name="a")
b = tf.Variable([[3.0],[4.0]],name="b")
c = tf.matmul(a,b)

saver=tf.train.Saver(max_to_keep=1)
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(c))
    saver.save(sess, '/tmp/ckpt/test/matmul')

    # 恢复模型
    model_file = tf.train.latest_checkpoint('/tmp/ckpt/test/')
    saver.restore(sess, model_file)
    print(sess.run([c], feed_dict={a: [[5.0,6.0]], b: [[7.0],[8.0]]}))

自定义命令行参数

  • 利用python的argparse包进行自定义参数

import argparse
import sys
 
parser = argparse.ArgumentParser()
parser.add_argument('--fake_data', nargs='?', const=True, type=bool,
                      default=False,
                      help='If true, uses fake data for unit testing.')
parser.add_argument('--max_steps', type=int, default=1000,
                      help='Number of steps to run trainer.')
parser.add_argument('--learning_rate', type=float, default=0.001,
                      help='Initial learning rate')
parser.add_argument('--dropout', type=float, default=0.9,
                      help='Keep probability for training dropout.')
parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
                      help='Directory for storing input data')
parser.add_argument('--log_dir', type=str, default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',
                      help='Summaries log directory')
args = parser.parse_args()

  •  利用tf.app.flags组件

tf.app.flags,它支持应用从命令行接受参数,可以用来指定集群配置等。在tf.app.flags下面有各种定义参数的类型

  • DEFINE_string(flag_name, default_value, docstring)

  • DEFINE_integer(flag_name, default_value, docstring)

  • DEFINE_boolean(flag_name, default_value, docstring)

  • DEFINE_float(flag_name, default_value, docstring)

第一个也就是参数的名字,路径、大小等等。第二个参数提供具体的值。第三个参数是说明文档

tf.app.flags.FLAGS,在flags有一个FLAGS标志,它在程序中可以调用到我们前面具体定义的flag_name。

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('data_dir', '/tmp/tensorflow/mnist/input_data',
                           """数据集目录""")
tf.app.flags.DEFINE_integer('max_steps', 2000,
                            """训练次数""")
tf.app.flags.DEFINE_string('summary_dir', '/tmp/summary/mnist/convtrain',
                           """事件文件目录""")


def main(argv):
    print(FLAGS.data_dir)
    print(FLAGS.max_steps)
    print(FLAGS.summary_dir)
    print(argv)


if __name__=="__main__":
    tf.app.run()

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI追随者

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值