tensorflow的tf.train.Saver()函数使用小技巧

tensorflow的saver是很重要的,不光在保存模型文件的时候用到,在微调网络的过程中,加载预训练模型的时候也会用到;下面就一些实际工程中遇到的问题做一些讲解。

  • Saver类
def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False,
               filename=None)
 var_list: 变量的列表,如果为None,则默认的变量为图中可保存的变量
 max_to_keep: 允许保存的最多的模型个数,当超过这个数值时,后面的模型会替换掉之前保存的模型
  • 经常用到的方法
saver = tf.train.Saver()

saver.save(sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True)     # 用于保存训练的模型
           
 saver.recover_last_checkpoints(checkpoint_paths)     # 用于从最近一次的训练结果恢复模型
 
 saver.restore(sess, save_path)      # 加载模型,可以指定加载某个模型,不一定非得最近一次
  • 微调网络过程中遇到的问题
  1. 网络只需加载一部分预训练模型的权重怎么办;或者说网络中某些层的权重,预训练模型中没有。
    解决办法:在定义saver对象的时候,把网络中这些层排除掉即可。然后在用restore从预训练模型中加载权重时就不会报错了,网络其余层没有从预训练模型加载权重的就需要初始化啦。
# 比如我网络中Logits层在预训练模型中没有
# 指定加载某些变量的权重
all_vars = tf.trainable_variables()
var_to_skip = [v for v in all_vars if v.name.startswith('Logits')]
print("got pretrained model, var_to_skip:\n" + " \n".join([x.name for x in var_to_skip]))
var_to_restore = [v for v in all_vars if not (v.name.startswith('Logits'))]
saver = tf.train.Saver(var_to_restore, max_to_keep=20)
sess.run(tf.global_variables_initializer())     # 初始化其余层的变量
saver.restore(sess, pretrained_model)           # 利用saver.restore恢复指定层的权重
  1. 保存是时候还要用之前定义的saver吗?
    我们之前定义的saver为了正确加载预训练模型,是把网络中以‘Logits’开头的变量排除了的;所以,如果还用这个saver来save训练模型的话,模型中会没有‘Logits’层的权重的。
    解决办法:重新再定义一个包含网络全部变量的saver对象用于保存模型,一个图中可以定义多个saver对象哟~
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值