干货!如何修改在TensorFlow框架下训练保存的模型参数名称

干货!如何修改在TensorFlow框架下训练保存的模型参数名称

为何要修改TensorFlow训练的模型参数名?

在TensorFlow框架下的深度学习程序中,我们将训练得到的模型参数进行保存。在我们进行某些训练任务时,也许要从已经保存的预训练模型中载入参数,或者将TensorFlow框架训练得到的参数转换到其他框架使用。在进行上述操作的时候,有可能需要将已训练的模型参数(尤其是参数名称)做出改变。

举个例子,笔者最近在做实验,打算使用DeepLab V2作基准(baseline),并下载预训练的DeepLab V2模型参数进行自己的模型的初始化,然后进行微调(finetune)。由于笔者的任务跟迁移学习比较相关,在程序中需要在某个参数域(variable_scope)下大量重用(reuse)参数。可是,网上下载的DeepLab V2模型不会考虑到笔者自己设置的参数域。因此,在使用下载的模型参数,使用键值对的方式进行参数初始化时,会报错(参数名匹配不上)。因此,笔者需要修改下载的预训练模型中的参数名称,在每个参数前面加上笔者在自己的程序中设定的参数域名。这样在程序中才可以既在指定的参数域下重用参数,又可以使用预训练的参数进行初始化

笔者观察到,介绍TensorFlow保存读取参数的博客很多,但是很少有介绍修改已保存参数的博客。而修改参数在某些情况下是会使用到的。因此在本篇博客中,笔者就介绍怎么在程序中对已保存的模型参数(名称)进行修改,在笔者的程序中,也可以对参数的值进行修改。

如何修改TensorFlow训练保存的参数名?

在笔者最开始进行参数名称修改摸索的时候,网上的资源真是少之又少,搜索了一段时间后。笔者看到一篇文档进行了介绍并附带了代码,大家可以移步这篇知乎专栏:Tensorflow修改已训练模型变量名字的方法。这篇专栏对笔者的帮助比较大,笔者也借鉴了里面的少量代码。

可是,笔者觉得上述专栏里面的做法比较繁琐。因为,从代码里面可以看到,在修改模型参数的时候,进行了读取数据流图(Graph)的操作。可是,在我们使用预训练模型初始化的时候,是按照字典,即键值对的方式进行初始化的。具体解释就是按照我们定义的参数名称,去已保存的模型参数里面读取对应的值来初始化。因此,笔者认为没有必要专门读取数据流图,并进行了更简洁的尝试。

在放出代码之前,笔者先介绍一下用到的两个重要的接口:

  1. tf.contrib.framework.list_variables。将已保存参数的(名称,形状)以列表的形式返回。在更新的TensorFlow版本中,该接口已经被整合到了tf.train.list_variables里面。
  2. tf.contrib.framework.load_variable。可以传入名称,返回读取的已保存参数的值。在更新的TensorFlow版本中,该接口已经被整合到了tf.train.load_variable里面。

在修改保存的参数名称时,做法分为以下6步:

  1. 使用list_variables函数逐个读出已保存的参数名称
  2. 使用load_variable函数逐个读取已保存的参数值
  3. 逐个修改参数名称
  4. 使用已修改的参数名称,结合tf.Variable函数逐个重建参数
  5. 将已重建的参数逐个加入新参数列表
  6. 使用tf.train.Saver().save将新参数列表写入硬盘

下面放出笔者的代码,在代码中,笔者给DeepLab V2预训练的模型参数全加上了前缀“deeplab_v2”。在这里笔者使用的还是许久之前的DeepLab预训练模型,参数保存还是一个ckpt文件(deeplab_resnet.ckpt)。代码如下:

import tensorflow as tf
import argparse
import os

parser = argparse.ArgumentParser(description='')

parser.add_argument("--checkpoint_path", default='../deeplab_resnet/deeplab_resnet.ckpt', help="restore ckpt") #原参数路径
parser.add_argument("--new_checkpoint_path", default='../deeplab_resnet_altered/', help="path_for_new ckpt") #新参数保存路径
parser.add_argument("--add_prefix", default='deeplab_v2/', help="prefix for addition") #新参数名称中加入的前缀名

args = parser.parse_args()


def main():
    if not os.path.exists(args.new_checkpoint_path):
        os.makedirs(args.new_checkpoint_path)
    with tf.Session() as sess:
        new_var_list=[] #新建一个空列表存储更新后的Variable变量
        for var_name, _ in tf.contrib.framework.list_variables(args.checkpoint_path): #得到checkpoint文件中所有的参数(名字,形状)元组
            var = tf.contrib.framework.load_variable(args.checkpoint_path, var_name) #得到上述参数的值

            new_name = var_name
            new_name = args.add_prefix + new_name #在这里加入了名称前缀,大家可以自由地作修改

            #除了修改参数名称,还可以修改参数值(var)

            print('Renaming %s to %s.' % (var_name, new_name))
            renamed_var = tf.Variable(var, name=new_name) #使用加入前缀的新名称重新构造了参数
            new_var_list.append(renamed_var) #把赋予新名称的参数加入空列表

        print('starting to write new checkpoint !')
        saver = tf.train.Saver(var_list=new_var_list) #构造一个保存器
        sess.run(tf.global_variables_initializer()) #初始化一下参数(这一步必做)
        model_name = 'deeplab_resnet_altered' #构造一个保存的模型名称
        checkpoint_path = os.path.join(args.new_checkpoint_path, model_name) #构造一下保存路径
        saver.save(sess, checkpoint_path) #直接进行保存
        print("done !")

if __name__ == '__main__':
    main()

在终端下面运行一下代码:
在这里插入图片描述
可以看到参数名称都被重置了,加上了前缀“deeplab_v2”:
在这里插入图片描述
在代码中设定的保存文件夹下,能够查看已保存的新参数名称的模型参数:
在这里插入图片描述
由于后来的TensorFlow框架在保存模型时已经放弃了保存单个ckpt文件的做法,因此都是得到4个文件,如上所示。

然后我们就可以在代码中愉快地使用新参数名称的模型进行初始化啦~

loader = tf.train.Saver(var_list=restore_vars) #设置一下要初始化哪些参数
checkpoint = tf.train.latest_checkpoint(args.checkpoint_path) #保存的新参数名的模型路径
loader.restore(sess, ckpt_path) #初始化模型参数

到这里,本篇博文就接近尾声了。本篇博文主要讲述了如何修改TensorFlow框架下训练的参数名称,核心还是找出参数名->更改参数名->重建参数->保存。笔者也衷心希望本篇博客能对大家的科研与工作有帮助。

欢迎阅读笔者后续博客,各位读者朋友的支持与鼓励是我最大的动力

written by jiong
谦,亨,君子有终。

  • 14
    点赞
  • 18
    收藏
  • 打赏
    打赏
  • 18
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:编程工作室 设计师:CSDN官方博客 返回首页
评论 18

打赏作者

jiongnima

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值