【TensorFlow系列】【二】如何从ckpt文件中拷贝权值到新的变量中

在基于TensorFlow做fine-tuning或者迁移学习时,面临的一个问题就是如何从已有的模型中,将其模型参数拷贝到自定义的新模型中。

本文讲述如下两个问题:

1、如何从ckpt模型文件中获取权值的名字?

2、如何将权值拷贝到新的变量中?

 

具体见代码注释:

import tensorflow as tf

#从ckpt文件中获取variable变量的名字
def get_trainable_variables_name_from_ckpt(meta_graph_path,ckpt_path):
    #定义一个新的graph
    graph = tf.Graph()
    #将其设置为默认图:
    with graph.as_default():
        with tf.Session() as session:
            #加载计算图
            saver = tf.train.import_meta_graph(meta_graph_path)
            #加载模型到session中关联的graph中,即将模型文件中的计算图加载到这里的graph中
            saver.restore(session,ckpt_path)
            v_names = []
            #获取session所关联的图中可被训练的variable
            #使用tf.trainable_variables()获取variable时,只有在该函数前面定义的variable才会被获取到
            #在其后面定义不会被获取到,
            for v in tf.trainable_variables():
                v_names.append(v)
            return v_names
#利用pywrap_tensorflow获取ckpt文件中的所有变量,得到的是variable名字与shape的一个map
from tensorflow.python import pywrap_tensorflow
def get_all_variables_name_from_ckpt(ckpt_path):
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    all_var = reader.get_variable_to_shape_map()
    #reader.get_variable_to_dtype_map()
    return all_var


#从cpkt文件中拷贝模型的参数到自定义的变量中
def copy_var_from_ckpt(session,dst_var_name,dst_var,ckpt_path,meta_graph_path):
    #定义一个新的graph
    graph = tf.Graph()
    #将其设置为默认图:
    with graph.as_default():
        with tf.Session() as sess:
            #加载计算图
            saver = tf.train.import_meta_graph(meta_graph_path)
            #加载模型到session中关联的graph中,即将模型文件中的计算图加载到这里的graph中
            saver.restore(sess,ckpt_path)
            v_names = []
            #获取session所关联的图中可被训练的variable
            #使用tf.trainable_variables()获取variable时,只有在该函数前面定义的variable才会被获取到
            #在其后面定义不会被获取到,
            for v in tf.trainable_variables():
                v_names.append(v)
            if dst_var_name in v_names:
                #获取tensor
                tensor = graph.get_tensor_by_name(dst_var_name)
                #获取tensor的值,即网络中权值
                weight = sess.run(tensor)
                #拷贝权值,注意,需要使用dst_var所在的session
                #使用assign操作来拷贝dst_var是一个variable,weight是一个array
                session.run(dst_var.assign(weight))

 

转载于:https://my.oschina.net/u/3800567/blog/1637800

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值