Tensorflow详解保存模型(进阶版一):如何有选择的保存变量

当然掌握了基础版还不够,我们来看一下进阶版一:如何有选择的保存变量:

这里还要另外涉及两个函数

tf.variable_scope("xxx") 和 tf.get_variable(name,shape,initializer = initializer)

简单介绍,详细介绍还请大家在查阅其他资料

tf.varibale_scope("xxx"),是定义变量的作用范围的函数,相当于把变量做一个集合方便管理。

tf.get_variable(),是生成一个变量。如果该变量生成过,就继承;如果没有声明过,则重新生成。(注意,必须在会话(sess)外调用,否则会报错类似:使用了未初始化(uninitialized)变量)

我们以C3D的做法举例:

def _variable_on_cpu(name, shape, initializer):
    #with tf.device('/cpu:%d' % cpu_id):
    with tf.device('/cpu:0'):
        var = tf.get_variable(name, shape, initializer=initializer)
    return var

def _variable_with_weight_decay(name, shape, stddev, wd):
    var = _variable_on_cpu(name, shape, tf.truncated_normal_initializer(stddev=stddev))
    if wd is not None:
        weight_decay = tf.nn.l2_loss(var) * wd
        tf.add_to_collection('losses', weight_decay)
    return var

with tf.Graph().as_default():
        
        with tf.variable_scope('var_name') as var_scope:
            weights = {
                'wc1': _variable_with_weight_decay('wc1', [3, 3, 3, 3, 64], 0.04, 0.00),
                'wc2': _variable_with_weight_decay('wc2', [3, 3, 3, 64, 128], 0.04, 0.00),
                'wc3a': _variable_with_weight_decay('wc3a', [3, 3, 3, 128, 256], 0.04, 0.00),
                'wc3b': _variable_with_weight_decay('wc3b', [3, 3, 3, 256, 256], 0.04, 0.00),
                'wc4a': _variable_with_weight_decay('wc4a', [3, 3, 3, 256, 512], 0.04, 0.00),
                'wc4b': _variable_with_weight_decay('wc4b', [3, 3, 3, 512, 512], 0.04, 0.00),
                'wc5a': _variable_with_weight_decay('wc5a', [3, 3, 3, 512, 512], 0.04, 0.00),
                'wc5b': _variable_with_weight_decay('wc5b', [3, 3, 3, 512, 512], 0.04, 0.00),
                'cam':_variable_with_weight_decay('cam', [1,1,512,c3d_model.NUM_CLASSES], 0.04,0.00),
            }
            biases = {
                'bc1': _variable_with_weight_decay('bc1', [64], 0.04, 0.0),
                'bc2': _variable_with_weight_decay('bc2', [128], 0.04, 0.0),
                'bc3a': _variable_with_weight_decay('bc3a', [256], 0.04, 0.0),
                'bc3b': _variable_with_weight_decay('bc3b', [256], 0.04, 0.0),
                'bc4a': _variable_with_weight_decay('bc4a', [512], 0.04, 0.0),
                'bc4b': _variable_with_weight_decay('bc4b', [512], 0.04, 0.0),
                'bc5a': _variable_with_weight_decay('bc5a', [512], 0.04, 0.0),
                'bc5b': _variable_with_weight_decay('bc5b', [512], 0.04, 0.0),
            }

仔细看,这里weights 和 bias 都是字典。且都被划分在了“var_name”这个集合之下:

在“var_name”这个集合下, _variable_with_weight_decay()函数调用了tf.get_variable()函数,这就相当于将tf.get_variable() 产生得变量加入到这个集合中。即在 tf.variable_scope('xxx') 下调用 tf.get_variable() 生成的变量都会被包括在该集合(xxx)中。

这会造成一个什么效果呢?就是在保存变量到模型中的时候Tensorname 会变成 xxx/wc1、xxx/wc2 等。

接下来我们来看如何有选择的保存变量:

import tensorflow as tf

with tf.Graph() as graph:

    model_name = 'xxx/xxx/name'

    with tf.Session() as sess:
     
         init = tf.global_variables_initializer()
         sess.run(init)
     
         variables_to_save = ... # 自己定义
     
         saver = tf.train.Saver(vraiables_to_save)

         for epoch in range(max_epochs)

             training() #
         
             saver.save(sess,model_name)

我们来看代码里的自己定义的部分:这也就是tf.train.Saver函数的输入变量,官方给的文档说,要求要是‘dict’型或者是‘list’型,这里大家获取能明白为什么用字典来声明了。

  • dict of names to variables: The keys are the names that will be used to save or restore the variables in the checkpoint files.
  • A list of variables: The variables will be keyed with their op name in the checkpoint files.

简单来说:

dict型:{变量名:变量值}
list型:[变量值,]

 

举个例子,我们只想保存wc,wc2,bc1,bc2。针对上面的 C3D,我们可以这样操作。

#1.dict型

variables_to_save = {'wc1':weights['wc1'],'wc2':weights['wc2'],'bc1':weights['bc1'],'bc2':weigghts['bc2']}

#2.list型

variables_to_save = [weight['wc1'],weight['wc2'],weight['bc1'],weights['bc2']]

 到这里我们就介绍结束了,可能初次学习会有困难,欢迎留言交流。

  

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值