saver = tf.train.Saver(...variables...)
__init__(
var_list=None, #var_list指定要保存和恢复的变量
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)
var_list :specifies the variables that will be saved and restored. It can be passed as a dict or a list.
A dict of names to variables:
The keys are the names that will be used to save or restore the variables in the checkpoint files.
A list of variables:
The variables will be keyed with their op name in the checkpoint files.
eg:保存参数:
weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
saver.save(sess,'model.ckpt')
eg:恢复参数:
weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
saver.restore(sess, model_filename)
一,恢复部分预训练模型的参数。
weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
saver.restore(sess, model_filename)
二,手动初始化剩下的(预训练模型中没有的)参数。
var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())
三.我想保存全部变量,所以要重新写一个对象,名字和恢复的那个saver对象不同:
saver_out=tf.train.Saver()
saver_out.save(sess,'file_name')
这个时候就保存了全部变量,如果你想保存部分变量,只需要在构造器里传入想要保存的变量的名字就行了。
查看预训练模型文件内容
import tensorflow as tf
import os
from tensorflow.python import pywrap_tensorflow
model_dir=r'G:\KeTi\C3D'
checkpoint_path = os.path.join(model_dir, "sports1m_finetuning_ucf101.model")
# 从checkpoint中读出数据
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
var_to_shape_map = reader.get_variable_to_shape_map()
# 输出权重tensor名字和值
for key in var_to_shape_map:
print("tensor_name: ", key,reader.get_tensor(key).shape)
output:
tensor_name: var_name/wc4a (3, 3, 3, 256, 512)
tensor_name: var_name/wc3a (3, 3, 3, 128, 256)
tensor_name: var_name/wd1 (8192, 4096)
tensor_name: var_name/wc5b (3, 3, 3, 512, 512)
tensor_name: var_name/bd1 (4096,)
tensor_name: var_name/wd2 (4096, 4096)
tensor_name: var_name/wout (4096, 101)
tensor_name: var_name/wc1 (3, 3, 3, 3, 64)
tensor_name: var_name/bc4b (512,)
tensor_name: var_name/wc2 (3, 3, 3, 64, 128)
tensor_name: var_name/bc3a (256,)
tensor_name: var_name/bd2 (4096,)
tensor_name: var_name/bc5a (512,)
tensor_name: var_name/bc2 (128,)
tensor_name: var_name/bc5b (512,)
tensor_name: var_name/bout (101,)
tensor_name: var_name/bc4a (512,)
tensor_name: var_name/bc3b (256,)
tensor_name: var_name/wc4b (3, 3, 3, 512, 512)
tensor_name: var_name/bc1 (64,)
tensor_name: var_name/wc3b (3, 3, 3, 256, 256)
tensor_name: var_name/wc5a (3, 3, 3, 512, 512)
参考:
https://blog.csdn.net/mieleizhi0522/article/details/80535189