假设我们定义了一个keras模型,并且使用它的save_weights函数保存了一些参数.现在我们只定义这个模型的一部分,并且使用load_weights去加载我们保存的这个完整的模型,会发生什么?
首先看源代码,load_weights实际上是调用了tensorflow_core/python/keras/engine/network.py文件中Network类的load_weights函数,而在这个函数中,分别对save_format为tf和h5两种类型的文件做了不同的处理,我们这里不是h5文件,那么显然就应该是tf文件.源代码就不贴了,根据源代码,我猜测读取tf文件的函数是tensorflow_core/python/training/tracking/util.py中的TrackableSaver类的restore函数.我在这个函数的doc中发现了这样一句话.
If the checkpoint has not been consumed completely, then the list of restore ops will grow as more objects are added to the dependency graph.
这句话好像在说如果我们的checkpoing文件中的内容没有被全部加载,那么剩下没有加载的变量将会在我们使用这个图的时候动态的加载.
那么来做实验吧.首先我们定义了一个图,就是prosody-tacotron中的global style token(gst)层.这个图有两种模式,'reference'和'weight',在'reference'模式下模型参数会被完全定义,而在'weight'模式下,模型只会定义gst嵌入,而不会定义多头注意力层和卷积层的变量.下面是效果
hp.gst_mode = 'reference'
model = get_gst()
for i in model.trainable_variables:
print(i.name)
"""
输出:
gst_tokens:0
gst/multihead_attention/attention_v:0
gst/multihead_attention/attention_g:0
gst/multihead_attention/attention_b:0
gst/multihead_attention/conv1d/kernel:0
gst/multihead_attention/conv1d/bias:0
gst/multihead_attention/conv1d_1/kernel:0
gst/multihead_attention/conv1d_1/bias:0
"""
hp.gst_mode = &#