tf.train.Saver()

 

saver = tf.train.Saver(...variables...)

__init__(
    var_list=None,  # var_list specifies the variables to save and restore
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

var_list: specifies the variables that will be saved and restored. It can be passed as a dict or a list.
    A dict of names to variables:
        The keys are the names that will be used to save or restore the variables in the checkpoint files.
    A list of variables:
        The variables will be keyed with their op name in the checkpoint files.
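The list-vs-dict difference can be sketched as follows (a minimal example; the variable names here are made up for illustration, and the TF 1.x graph-mode API is accessed through tf.compat.v1 so the snippet also runs under TF 2.x):

```python
# Minimal sketch: list var_list vs. dict var_list.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

w = tf.Variable(tf.zeros([3, 3]), name='wc1')
b = tf.Variable(tf.zeros([3]), name='bc1')

# Passed as a list: checkpoint keys are the variables' own op names
# ('wc1', 'bc1').
saver_list = tf.train.Saver([w, b])

# Passed as a dict: the keys become the names in the checkpoint file, so a
# variable can be saved under (or restored from) a name different from its own.
saver_dict = tf.train.Saver({'conv1_weights': w, 'conv1_biases': b})
```

The dict form is what you need when the pretrained checkpoint's names do not match the names in your current graph.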

 

E.g. saving parameters:

weight = [weights['wc1'], weights['wc2'], weights['wc3a']]
saver = tf.train.Saver(weight)  # create a Saver for just these variables
saver.save(sess, 'model.ckpt')

E.g. restoring parameters:

weight = [weights['wc1'], weights['wc2'], weights['wc3a']]
saver = tf.train.Saver(weight)  # create a Saver for just these variables
saver.restore(sess, model_filename)

 

1. Restore a subset of the pretrained model's parameters.

weight = [weights['wc1'], weights['wc2'], weights['wc3a']]
saver = tf.train.Saver(weight)  # create a Saver for just these variables
saver.restore(sess, model_filename)

2. Manually initialize the remaining parameters (those not present in the pretrained model).

var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())

3. To save all variables afterwards, create a second Saver object, with a different name from the one used for restoring:

saver_out = tf.train.Saver()
saver_out.save(sess, 'file_name')

This saves all variables. If you only want to save a subset, simply pass the variables you want to save to the constructor, as shown earlier.
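The three steps above can be sketched end to end as follows (a minimal illustration, not the original C3D model: the variable names `wc1` and `fc_new` are invented, the real restore call is left commented out because `model_filename` is not defined here, and the TF 1.x API is accessed through tf.compat.v1 so the snippet also runs under TF 2.x):

```python
# End-to-end sketch: partial restore, init everything, save everything.
import os
import tempfile
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

wc1 = tf.get_variable('wc1', shape=[3, 3])     # exists in the pretrained checkpoint
fc_new = tf.get_variable('fc_new', shape=[4])  # new layer, not in the checkpoint

restore_saver = tf.train.Saver([wc1])  # step 1: Saver over the pretrained vars only
saver_out = tf.train.Saver()           # step 3: a second Saver over ALL variables

ckpt_path = os.path.join(tempfile.mkdtemp(), 'all_vars.ckpt')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())    # step 2: initialize everything
    # restore_saver.restore(sess, model_filename)  # then overwrite the pretrained vars
    saved_to = saver_out.save(sess, ckpt_path)     # writes wc1 AND fc_new
```

Note the ordering: run the global initializer first, then restore, so that the restored values are not clobbered by the initializer.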

 

Inspecting the contents of a pretrained model file

import tensorflow as tf
import os
from tensorflow.python import pywrap_tensorflow

model_dir = r'G:\KeTi\C3D'
checkpoint_path = os.path.join(model_dir, "sports1m_finetuning_ucf101.model")
# Read data from the checkpoint
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path)  # same thing via tf.train
var_to_shape_map = reader.get_variable_to_shape_map()
# Print each tensor's name and shape
for key in var_to_shape_map:
    print("tensor_name: ", key, reader.get_tensor(key).shape)


output:
tensor_name:  var_name/wc4a (3, 3, 3, 256, 512)
tensor_name:  var_name/wc3a (3, 3, 3, 128, 256)
tensor_name:  var_name/wd1 (8192, 4096)
tensor_name:  var_name/wc5b (3, 3, 3, 512, 512)
tensor_name:  var_name/bd1 (4096,)
tensor_name:  var_name/wd2 (4096, 4096)
tensor_name:  var_name/wout (4096, 101)
tensor_name:  var_name/wc1 (3, 3, 3, 3, 64)
tensor_name:  var_name/bc4b (512,)
tensor_name:  var_name/wc2 (3, 3, 3, 64, 128)
tensor_name:  var_name/bc3a (256,)
tensor_name:  var_name/bd2 (4096,)
tensor_name:  var_name/bc5a (512,)
tensor_name:  var_name/bc2 (128,)
tensor_name:  var_name/bc5b (512,)
tensor_name:  var_name/bout (101,)
tensor_name:  var_name/bc4a (512,)
tensor_name:  var_name/bc3b (256,)
tensor_name:  var_name/wc4b (3, 3, 3, 512, 512)
tensor_name:  var_name/bc1 (64,)
tensor_name:  var_name/wc3b (3, 3, 3, 256, 256)
tensor_name:  var_name/wc5a (3, 3, 3, 512, 512)
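Once you can see the checkpoint's tensor names like this, choosing the subset to restore is just name filtering. A small illustration (plain Python; the names are a few of those from the output above):

```python
# Pick out only the conv kernels ('wc*') from the inspected checkpoint names;
# this is the kind of list you would map back to your model's variables when
# building the partial-restore Saver.
ckpt_names = ['var_name/wc1', 'var_name/bc1', 'var_name/wd1',
              'var_name/wc3a', 'var_name/bd1']
conv_kernels = [n for n in ckpt_names if n.split('/')[-1].startswith('wc')]
print(conv_kernels)  # ['var_name/wc1', 'var_name/wc3a']
```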

Reference:

https://blog.csdn.net/mieleizhi0522/article/details/80535189

 
