本文转载自:https://blog.csdn.net/huachao1001/article/details/110957491?spm=1001.2014.3001.5501
本文介绍一些不常规的操作:
如何只加载部分参数? 如何从两个模型中加载不同部分参数?
当预训练的模型的命名与当前定义的网络中的参数命名不一致时该怎么办?
1 只加载部分参数
举个例子,对已有的网络结构做了细微修改,例如只改了几层卷积通道数。如果从头训练显然没有finetune收敛速度快,但是模型又没法全部加载。此时,只需将未修改部分参数加载到当前网络即可。假设修改过的卷积层名称包含`conv_``,示例代码如下:
import tensorflow as tf
def restore(sess, ckpt_path):
vars = tf.trainable_variables()
vars = [v for v vars if not "conv_1" in v.name]
saver = tf.train.Saver(var_list=vars)
saver.restore(sess, ckpt_path)
2 从两个预训练模型中加载不同部分参数
如果需要从两个不同的预训练模型中加载不同部分参数,例如,网络中的前半部分用一个预训练模型参数,后半部分用另一个预训练模型中的参数,示例代码如下:
import tensorflow as tf
def restore(sess, ckpt_path):
vars = tf.trainable_variables()
model_1_vars = [v for v vars if "model_1" in v.name]
model_2_vars = [v for v vars if "model_2" in v.name]
saver_1 = tf.train.Saver(var_list=model_1_vars)
saver_2 = tf.train.Saver(var_list=model_2_vars)
saver_1 .restore(sess, ckpt_path)
saver_2 .restore(sess, ckpt_path)
3 从参数名称不一致的模型中加载参数
举个例子,例如,预训练的模型所有的参数有个前缀name_1,现在定义的网络结构中的参数以name_2作为前缀。那么使用如下示例代码即可加载:
import tensorflow as tf
def restore(sess, ckpt_path):
vars = tf.trainable_variables()
vars_dict = dict()
for v in vars:
key = v.name.split(':')[0]
if key.startswith("name_2/"):
key = key.replace("name_2/", "name_1/")
vars_dict[key] = v
saver =tf.train.Saver(var_list=vars_dict)
saver.restore(sess, ckpt_path)
注意: 使用上面代码时,要确保参数的shape一致,否则会无法加载参数。
如果不知道预训练的ckpt中参数名称,可以使用如下代码打印:
for name, shape in tf.train.list_variables(ckpt_path):
print(name)
# 或者
var_list = tf.compat.v1.get_collection(tf.v1.GraphKeys.GLOBAL_VARIABLES)
# 或者
names = [n.name for n in session.graph_def.node]
其实如果了解Tensorflow的图原理和基本特性,就比较好理解上面的操作了,推荐:
深入理解tensorflow的基本概念:Graph、Operation、Tensor、Node的区别