有兴趣的可以加qq群点击链接加入群聊【深度学习交流】:
前几天一直在修改模型,但是在修改的时候要加载原始预训练模型,我现在修改过的模型(现模型)有新加的参数,而有些预训练模型中的参数也没有用到,所以这样的情况下对于预训练模型来说,就相当于加载部分模型参数了,然后现模型中的剩余的参数就通过手动初始化完成,其实在加载模型的时候就相当于初始化参数。
也就是说现模型的参数初始化分为两部分:
一,加载部分预训练模型的参数。
二,手动初始化剩下的(预训练模型中没有的)参数。
在做这些之前,先对Saver类说明一下,其中有一个很重要的点要get到:
-
...
-
# Create a saver.
-
saver = tf.train.Saver(...variables...)
-
# Launch the graph and train, saving the model every 1,000 steps.
-
sess = tf.Session()
-
for step
in xrange(
1000000):
-
sess.run(..training_op..)
-
if step %
1000 ==
0:
-
# Append the step number to the checkpoint name:
-
saver.save(sess,
'my-model', global_step=step)
这个是官网的一个例子,请看下面这一句:
saver = tf.train.Saver(...variables...)
其中这个Saver是一个类,上面的那一句就是通过类取得Saver的对象,里面的variables是构造函数传入的参数,请看这个构造函数对这个参数的解释:
__init__
-
__init__(var_list=
None, reshape=
False, sharded=
False,
-
max_to_keep=
5, keep_checkpoint_every_n_hours=
10000.0, name=
None, restore_sequentially=
False, saver_def=
None,
-
builder=
None, defer_build=
False,
-
allow_empty=
False, write_version=tf.train.SaverDef.V2,
-
pad_step_number=
False, save_relative_paths=
False,
-
filename=
None)
__init__是构造器,里面可以传很多参数,其中第一个参数就是var_list,也就是上面的variables.
下面是对var_list参数的解释:
Creates a Saver
.
The constructor adds ops to save and restore variables.
var_list
specifies the variables that will be saved and restored. It can be passed as a dict
or a list:
- A
dict
of names to variables: The keys are the names that will be used to save or restore the variables in the checkpoint files. - A list of variables: The variables will be keyed with their op name in the checkpoint files.
注意到红字所表达的意思:var_list指定要保存和恢复的变量。
所以里面传的参数是要保存和恢复的变量,举个例子说明问题:
保存参数:
-
weight=[weights[
'wc1'],weights[
'wc2'],weights[
'wc3a']]
-
saver = tf.train.Saver(weight)
#创建一个saver对象,.values是以列表的形式获取字典值
-
saver.save(sess,
'model.ckpt')
上面的意思是,只保存weight里的这些变量,如果saver=tf.train.Saver()里面不传入参数,默认保存全部变量
恢复参数:
-
weight=[weights[
'wc1'],weights[
'wc2'],weights[
'wc3a']]
-
saver = tf.train.Saver(weight)
#创建一个saver对象,.values是以列表的形式获取字典值
-
saver.restore(sess, model_filename)
上面这个恢复参数要注意,model_filename是你要恢复的模型文件,整段代码的意思是从model_filename文件里只恢复weight的这些参数,如果model_filename里面没有这些参数,则报错。(当然这些变量你不一定都一一列出,你可以通过遍历的算法得到,详细请看下面的参考文献)
像我的这种情况应该怎么恢复变量呢,也是分为两步:
一,恢复部分预训练模型的参数。
-
weight=[weights[
'wc1'],weights[
'wc2'],weights[
'wc3a']]
-
saver = tf.train.Saver(weight)
#创建一个saver对象,.values是以列表的形式获取字典值
-
saver.restore(sess, model_filename)
二,手动初始化剩下的(预训练模型中没有的)参数。
var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())
保存的时候怎么保存呢?我想保存全部变量,所以要重新写一个对象,名字和恢复的那个saver对象不同:
-
saver_out=tf.train.Saver()
-
saver_out.save(sess,
'file_name')
这个时候就保存了全部变量,如果你想保存部分变量,只需要在构造器里传入想要保存的变量的名字就行了。
通过一段代码看看预训练模型文件里都是什么东西吧:
-
import tensorflow
as tf
-
-
import os
-
from tensorflow.python
import pywrap_tensorflow
-
model_dir=
r'G:\KeTi\C3D'
-
checkpoint_path = os.path.join(model_dir,
"sports1m_finetuning_ucf101.model")
-
# 从checkpoint中读出数据
-
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
-
# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
-
var_to_shape_map = reader.get_variable_to_shape_map()
-
# 输出权重tensor名字和值
-
for key
in var_to_shape_map:
-
print(
"tensor_name: ", key,reader.get_tensor(key).shape)
输出:
-
tensor_name: var_name/wc4a (
3,
3,
3,
256,
512)
-
tensor_name: var_name/wc3a (
3,
3,
3,
128,
256)
-
tensor_name: var_name/wd1 (
8192,
4096)
-
tensor_name: var_name/wc5b (
3,
3,
3,
512,
512)
-
tensor_name: var_name/bd1 (
4096,)
-
tensor_name: var_name/wd2 (
4096,
4096)
-
tensor_name: var_name/wout (
4096,
101)
-
tensor_name: var_name/wc1 (
3,
3,
3,
3,
64)
-
tensor_name: var_name/bc4b (
512,)
-
tensor_name: var_name/wc2 (
3,
3,
3,
64,
128)
-
tensor_name: var_name/bc3a (
256,)
-
tensor_name: var_name/bd2 (
4096,)
-
tensor_name: var_name/bc5a (
512,)
-
tensor_name: var_name/bc2 (
128,)
-
tensor_name: var_name/bc5b (
512,)
-
tensor_name: var_name/bout (
101,)
-
tensor_name: var_name/bc4a (
512,)
-
tensor_name: var_name/bc3b (
256,)
-
tensor_name: var_name/wc4b (
3,
3,
3,
512,
512)
-
tensor_name: var_name/bc1 (
64,)
-
tensor_name: var_name/wc3b (
3,
3,
3,
256,
256)
-
tensor_name: var_name/wc5a (
3,
3,
3,
512,
512)
都是权重和偏置
更多关于变量恢复的文件类型问题,请参考:
1.https://blog.csdn.net/leo_xu06/article/details/79200634
2.https://blog.csdn.net/b876144622/article/details/79962727