tensorflow variable的保存和加载

本文探讨了如何使用tf.train.Saver在TensorFlow中保存和加载变量。通过实验展示了保存全部或部分变量、恢复时的名称匹配规则、在不初始化变量时的错误处理。还介绍了如何加载部分变量进行迁移学习,并且详细阐述了如何处理变量重命名的情况,以便在模型调整后仍能使用旧的ckpt文件。
摘要由CSDN通过智能技术生成

tensorflow提了供tf.train.saver类已完成variable的保存和加载。其中save方法可以用来将计算图中的variable全部或者部分存储到ckpt文件,restore方法可以将ckpt文件中的全部或者部分变量导入计算图中。按照官方定义ckpt文件的作用是: map variable names to tensor values

variable存储和加载的一组实验:

saver_test.py
import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
v3 = tf.get_variable("v3", shape=[4], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
dec_v3 = v3.assign(v3-2)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  dec_v3.op.run()

  print sess.run(v1)
  print sess.run(v2)
  print sess.run(v3)

  # Save the variables to disk.^
  save_path = saver.save(sess, "./saved_model/model.ckpt")
  print("Model saved in file: %s" % save_path)
restore_test.py
import tensorflow as tf

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable(
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值