[tf1] 保存和加载参数

tf.keras

参考 https://github.com/tensorflow/docs/blob/529ba4346b8fc5e830e762a2f0ee87b3c345c0c9/site/en/r1/guide/keras.ipynb

# Save weights to a TensorFlow Checkpoint file
model.save_weights('./weights/my_model')

# Restore the model's state,
# this requires a model with the same architecture.
model.load_weights('./weights/my_model')

也可以保存为 Keras HDF5 格式

# Save weights to a HDF5 file
model.save_weights('my_model.h5', save_format='h5')

# Restore the model's state
model.load_weights('my_model.h5')

注意,h5 可能会遇到这个问题 https://stackoverflow.com/questions/53740577/does-any-one-got-attributeerror-str-object-has-no-attribute-decode-whi

model.load_weights('my_model.h5') 时会报错 'str' object has no attribute 'decode'。解决办法:

pip freeze | grep h5py
pip install h5py==2.10.0

tf.train.Saver

可以保存指定参数。参考 https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/Saver

self._saver = tf.train.Saver(var_list=self._get_var_list(), \
				max_to_keep=self.max_tf_checkpoints_to_keep)
self._saver.save(
	self._sess,
	os.path.join(checkpoint_dir, 'tf_ckpt'),
	global_step=self.iteration)
self._saver.restore(self._sess,
                    os.path.join(checkpoint_dir,
                                 'tf_ckpt-{}'.format(iteration_number)))

tf.train.Checkpoint

不懂。参考 https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/Checkpoint

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值