tensorflow saver 保存和恢复指定 tensor

原创 2017年04月25日 10:36:19

在实践中经常会遇到这样的情况:

1, 用简单的模型预训练参数

2, 把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

                如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:               简单的模型只有一个卷基层,复杂模型有两个。

                卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
    global parameters
    with tf.name_scope(name) asscope:
        weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
            dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
        biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
        conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
        bias = tf.add(conv,biases)
        bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
        conv =tf.maximum(0.1*bias, bias)
        return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.save(sess,'model.ckpt')

就不行了,因为:

    1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v:  print  i

    会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

        <tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

        <tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

        <tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

        <tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

        <tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

        <tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

    同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

     2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

    1,在预训练模型中定义全局变量

parm_dict={}

    并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

    然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

    这样保存后的模型文件就对应到复杂模型上了。

    2,在复杂模型中定义全局变量

parameters= []

    并在“return conv”上面添加下面行

parameters+= [weights, biases]

    然后判断如果是第二个卷积层就不更新parameters。

    接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

    这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

    最后使用下面的代码加载就可以了                              

with tf.Session() as sess:
    ckpt= tf.train.get_checkpoint_state('.')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess,ckpt.model_checkpoint_path)
    else:
        print '  no saver.'
        exit()                    





                
版权声明:本文为博主原创文章,未经博主允许不得转载。

TensorFlow利用saver保存和提取参数

保存参数: import tensorflow as tf W = tf.Variable([[1, 2, 3]], dtype=tf.float32) b = tf.Variable([1], d...
  • winycg
  • winycg
  • 2017年11月19日 01:07
  • 215

TensorFlow saver之指定变量的存取

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。 1. 用saver存取变量; 2. 用saver存取指定变量。 用saver存取变量。 话不多说,先上代码# c...
  • main_h_
  • main_h_
  • 2017年07月03日 14:07
  • 574

tensorflow学习笔记(五):变量保存与导入

如何使用tensorflow内置的参数导出和导入方法:基本用法如果你还在纠结如何保存tensorflow训练好的模型参数,用这个方法就对了import tensorflow as tf """ 变量声...
  • u012436149
  • u012436149
  • 2016年10月21日 15:26
  • 11888

tensorflow保存变量出现错误(提示不能save)

错误名称:Tensorflow - ValueError: Parent directory of trained_variables.ckpt doesn’t exist, can’t saveTh...
  • a18852867035
  • a18852867035
  • 2017年04月28日 22:55
  • 3219

关于Tensorflow计算图与Tensor的理解

关于Tensorflow计算模型tensorflow的编程和我以往接触的编程方式有很大差异。以前的编程,无论是编译类型的语言还是脚本语言,都是一步一步的,变量计算后,就会得到结果,比如c=a+b,当执...
  • qian99
  • qian99
  • 2017年04月23日 14:51
  • 3406

Tensorflow lesson 3---变量Variable

Tensorflow中的变量就是一个放在内存中的tensor结构,用于在计算过程中保存数据,变量的数值可以保存到文件中,也可以从文件中读取1.变量的初始化import tensorflow as tf...
  • mwlwlm
  • mwlwlm
  • 2017年05月10日 14:33
  • 839

tensorflow笔记 :常用函数说明

本文章内容比较繁杂,主要是一些比较常用的函数的用法,结合了网上的资料和源码,还有我自己写的示例代码。建议照着目录来看。1.矩阵操作1.1矩阵生成这部分主要将如何生成矩阵,包括全0矩阵,全1矩阵,随机数...
  • u014595019
  • u014595019
  • 2016年10月13日 11:29
  • 54163

tensorflow saver 保存和恢复指定 tensor

在实践中经常会遇到这样的情况: 1, 用简单的模型预训练参数 2, 把预训练的参数导入复杂的模型后训练复杂的模型 这时就产生一个问题:                 如何加载预训练的参数。 下面就是...
  • xueyingxue001
  • xueyingxue001
  • 2017年04月25日 10:36
  • 1073

tensorflow1.x版本加载saver.restore目录报错

这个错误是最新的错误哈,目前只在tensorflow上的github仓库上面有提出,所以你在百度上面找不到。 是个tensorflow的bug十天前提出的。saver.restore(sess, 'D...
  • u014283248
  • u014283248
  • 2017年03月21日 12:31
  • 11941

TensorFLow 入门 - 用Saver保存和恢复变量

建立文件tensor_save.py, 保存变量v1,v2到checkpoint files中,变量名分别设置为v3,v4。import tensorflow as tf# Create some v...
  • muyiyushan
  • muyiyushan
  • 2017年03月30日 16:27
  • 1822
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:tensorflow saver 保存和恢复指定 tensor
举报原因:
原因补充:

(最多只允许输入30个字)