首先对预训练模型的scope一定要做好定义,不然恢复起来会比较麻烦。
这里使用tf.get_collection()
1、tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='name')
tf.get_collection(
key,
scope=None
)
Args:
key
: The key for the collection. For example, theGraphKeys
class contains many standard names for collections.scope
: (Optional.) If supplied, the resulting list is filtered to include only items whosename
attribute matches usingre.match
. Items without aname
attribute are never returned if a scope is supplied and the choice orre.match
means that ascope
without special tokens filters by prefix.
Returns:
The list of values in the collection with the given name
, or an empty list if no value has been added to that collection. The list contains the values in the order under which they were collected.
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)用于获取当前图下,给定指定name的所有变量,并返回由这些变量构成的list。
2、申请saver
saver = tf.train.Saver(var_list=var)
这里表示当前的这个saver只对var中的变量进行恢复,其余的不管
3、载入之前预训练的ckpt
saver.restore(sess,MODELPATH)
这里表示指定恢复的变量的权重是从MODELPATH里面来的,MODELPATH是之前预训练模型的ckpt
如果有多个这样的scope需要恢复的话可以多次重复上述步骤
最后,代码总结
var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='name1')
saver.restore(sess,MODELPATH1)
var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='name2')
saver = tf.train.Saver(var_list=var)