Tensorflow加载预训练模型的特殊操作

本文转载自:https://blog.csdn.net/huachao1001/article/details/110957491?spm=1001.2014.3001.5501

本文介绍一些不常规的操作:

如何只加载部分参数? 如何从两个模型中加载不同部分参数?
当预训练的模型的命名与当前定义的网络中的参数命名不一致时该怎么办?

1 只加载部分参数

举个例子,对已有的网络结构做了细微修改,例如只改了几层卷积通道数。如果从头训练显然没有finetune收敛速度快,但是模型又没法全部加载。此时,只需将未修改部分参数加载到当前网络即可。假设修改过的卷积层名称包含`conv_``,示例代码如下:

import tensorflow as tf
def restore(sess, ckpt_path):
	vars = tf.trainable_variables()
	vars = [v for v vars if not "conv_1" in v.name]
    saver = tf.train.Saver(var_list=vars)
	saver.restore(sess, ckpt_path)

2 从两个预训练模型中加载不同部分参数
如果需要从两个不同的预训练模型中加载不同部分参数,例如,网络中的前半部分用一个预训练模型参数,后半部分用另一个预训练模型中的参数,示例代码如下:

import tensorflow as tf
def restore(sess, ckpt_path):
	vars = tf.trainable_variables()
	model_1_vars = [v for v vars if "model_1" in v.name]
	model_2_vars = [v for v vars if "model_2" in v.name]
    saver_1 = tf.train.Saver(var_list=model_1_vars)
    saver_2 = tf.train.Saver(var_list=model_2_vars)
	saver_1 .restore(sess, ckpt_path)
	saver_2 .restore(sess, ckpt_path)

3 从参数名称不一致的模型中加载参数

举个例子,例如,预训练的模型所有的参数有个前缀name_1,现在定义的网络结构中的参数以name_2作为前缀。那么使用如下示例代码即可加载:

import tensorflow as tf
def restore(sess, ckpt_path):
	vars = tf.trainable_variables()
	vars_dict = dict()
	for v in vars:
	    key = v.name.split(':')[0]
	    if key.startswith("name_2/"):
	        key = key.replace("name_2/", "name_1/")
	    vars_dict[key] = v
	saver =tf.train.Saver(var_list=vars_dict)
	saver.restore(sess, ckpt_path)

注意: 使用上面代码时,要确保参数的shape一致,否则会无法加载参数。

如果不知道预训练的ckpt中参数名称,可以使用如下代码打印:

for name, shape in tf.train.list_variables(ckpt_path):
    print(name)

# 或者
var_list = tf.compat.v1.get_collection(tf.v1.GraphKeys.GLOBAL_VARIABLES)

# 或者
names = [n.name for n in session.graph_def.node]

其实如果了解Tensorflow的图原理和基本特性,就比较好理解上面的操作了,推荐:
深入理解tensorflow的基本概念:Graph、Operation、Tensor、Node的区别

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值