Tensorflow中保存与恢复模型tf.train.Saver类讲解(恢复部分模型参数的方法)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/mieleizhi0522/article/details/80535189

       有兴趣的可以加qq群点击链接加入群聊【深度学习交流】:

      前几天一直在修改模型,但是在修改的时候要加载原始预训练模型,我现在修改过的模型(现模型)有新加的参数,而有些预训练模型中的参数也没有用到,所以这样的情况下对于预训练模型来说,就相当于加载部分模型参数了,然后现模型中的剩余的参数就通过手动初始化完成,其实在加载模型的时候就相当于初始化参数。

        也就是说现模型的参数初始化分为两部分:

           一,加载部分预训练模型的参数。

          二,手动初始化剩下的(预训练模型中没有的)参数。

 

 

 

在做这些之前,先对Saver类说明一下,其中有一个很重要的点要get到:


 
 
  1. ...
  2. # Create a saver.
  3. saver = tf.train.Saver(...variables...)
  4. # Launch the graph and train, saving the model every 1,000 steps.
  5. sess = tf.Session()
  6. for step in xrange( 1000000):
  7.     sess.run(..training_op..)
  8.     if step % 1000 == 0:
  9.         # Append the step number to the checkpoint name:
  10.         saver.save(sess, 'my-model', global_step=step)

 

 

这个是官网的一个例子,请看下面这一句:

 

saver = tf.train.Saver(...variables...)

其中这个Saver是一个类,上面的那一句就是通过类取得Saver的对象,里面的variables是构造函数传入的参数,请看这个构造函数对这个参数的解释:

 

__init__


 
 
  1. __init__(var_list= None, reshape= False, sharded= False,   
  2. max_to_keep= 5,    keep_checkpoint_every_n_hours= 10000.0,  name= None,    restore_sequentially= False,    saver_def= None
  3.   builder= None,    defer_build= False,   
  4. allow_empty= False,    write_version=tf.train.SaverDef.V2, 
  5.   pad_step_number= False,    save_relative_paths= False
  6.   filename= None)

__init__是构造器,里面可以传很多参数,其中第一个参数就是var_list,也就是上面的variables.

下面是对var_list参数的解释:

Creates a Saver.

The constructor adds ops to save and restore variables.

var_list specifies the variables that will be saved and restored. It can be passed as a dict or a list:

  • dict of names to variables: The keys are the names that will be used to save or restore the variables in the checkpoint files.
  • A list of variables: The variables will be keyed with their op name in the checkpoint files.

注意到红字所表达的意思:var_list指定要保存和恢复的变量。                             

 

 

 

所以里面传的参数是要保存和恢复的变量,举个例子说明问题:

保存参数:


 
 
  1. weight=[weights[ 'wc1'],weights[ 'wc2'],weights[ 'wc3a']]
  2. saver = tf.train.Saver(weight) #创建一个saver对象,.values是以列表的形式获取字典值
  3. saver.save(sess, 'model.ckpt')

上面的意思是,只保存weight里的这些变量,如果saver=tf.train.Saver()里面不传入参数,默认保存全部变量

恢复参数:


 
 
  1. weight=[weights[ 'wc1'],weights[ 'wc2'],weights[ 'wc3a']]
  2. saver = tf.train.Saver(weight) #创建一个saver对象,.values是以列表的形式获取字典值
  3. saver.restore(sess, model_filename)

上面这个恢复参数要注意,model_filename是你要恢复的模型文件,整段代码的意思是从model_filename文件里只恢复weight的这些参数,如果model_filename里面没有这些参数,则报错。(当然这些变量你不一定都一一列出,你可以通过遍历的算法得到,详细请看下面的参考文献)

 

 

 

像我的这种情况应该怎么恢复变量呢,也是分为两步:

一,恢复部分预训练模型的参数。

 


 
 
  1. weight=[weights[ 'wc1'],weights[ 'wc2'],weights[ 'wc3a']]
  2. saver = tf.train.Saver(weight) #创建一个saver对象,.values是以列表的形式获取字典值
  3. saver.restore(sess, model_filename)

二,手动初始化剩下的(预训练模型中没有的)参数。

var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())
 
 

 

保存的时候怎么保存呢?我想保存全部变量,所以要重新写一个对象,名字和恢复的那个saver对象不同:


 
 
  1. saver_out=tf.train.Saver()
  2. saver_out.save(sess, 'file_name')

这个时候就保存了全部变量,如果你想保存部分变量,只需要在构造器里传入想要保存的变量的名字就行了。

通过一段代码看看预训练模型文件里都是什么东西吧:


 
 
  1. import tensorflow as tf
  2. import os
  3. from tensorflow.python import pywrap_tensorflow
  4. model_dir= r'G:\KeTi\C3D'
  5. checkpoint_path = os.path.join(model_dir, "sports1m_finetuning_ucf101.model")
  6. # 从checkpoint中读出数据
  7. reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
  8. # reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
  9. var_to_shape_map = reader.get_variable_to_shape_map()
  10. # 输出权重tensor名字和值
  11. for key in var_to_shape_map:
  12. print( "tensor_name: ", key,reader.get_tensor(key).shape)

输出:


 
 
  1. tensor_name: var_name/wc4a ( 3, 3, 3, 256, 512)
  2. tensor_name: var_name/wc3a ( 3, 3, 3, 128, 256)
  3. tensor_name: var_name/wd1 ( 8192, 4096)
  4. tensor_name: var_name/wc5b ( 3, 3, 3, 512, 512)
  5. tensor_name: var_name/bd1 ( 4096,)
  6. tensor_name: var_name/wd2 ( 4096, 4096)
  7. tensor_name: var_name/wout ( 4096, 101)
  8. tensor_name: var_name/wc1 ( 3, 3, 3, 3, 64)
  9. tensor_name: var_name/bc4b ( 512,)
  10. tensor_name: var_name/wc2 ( 3, 3, 3, 64, 128)
  11. tensor_name: var_name/bc3a ( 256,)
  12. tensor_name: var_name/bd2 ( 4096,)
  13. tensor_name: var_name/bc5a ( 512,)
  14. tensor_name: var_name/bc2 ( 128,)
  15. tensor_name: var_name/bc5b ( 512,)
  16. tensor_name: var_name/bout ( 101,)
  17. tensor_name: var_name/bc4a ( 512,)
  18. tensor_name: var_name/bc3b ( 256,)
  19. tensor_name: var_name/wc4b ( 3, 3, 3, 512, 512)
  20. tensor_name: var_name/bc1 ( 64,)
  21. tensor_name: var_name/wc3b ( 3, 3, 3, 256, 256)
  22. tensor_name: var_name/wc5a ( 3, 3, 3, 512, 512)

都是权重和偏置

 

更多关于变量恢复的文件类型问题,请参考:

1.https://blog.csdn.net/leo_xu06/article/details/79200634 

2.https://blog.csdn.net/b876144622/article/details/79962727

                                                                                微笑

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值