Restoring a Subset of Variables in TensorFlow

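Sometimes we need to restore only some of the variables in a TensorFlow graph. For example, we may want to merge two already-trained subgraphs into one larger graph; extend a trained network while keeping the trained part unchanged; or train several parts of a graph alternately.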

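When restoring a model, TensorFlow assigns values strictly by matching variable names. For example, if the constructed Graph contains a variable with name='net/frame/cnn_1', then at restore time TensorFlow looks in the checkpoint file for a value stored under the same name. If one exists, it is assigned to the variable; otherwise an exception is raised.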
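The name-matching rule can be modelled in a few lines of plain Python. This is only a sketch of the behaviour described above, not TensorFlow's implementation: the dictionaries stand in for the graph and the checkpoint, and `restore_by_name`, `KeyNotFoundError` and the values are hypothetical names made up for illustration.

```python
class KeyNotFoundError(KeyError):
    """Stand-in for the error raised when a name is missing from the checkpoint."""


def restore_by_name(graph_values, checkpoint_values, names_to_restore):
    """Copy checkpoint values into graph_values, matching strictly by name."""
    for name in names_to_restore:
        if name not in checkpoint_values:
            raise KeyNotFoundError("Key %s not found in checkpoint" % name)
        graph_values[name] = checkpoint_values[name]
    return graph_values


# Hypothetical names and values, for illustration only.
graph = {'net/frame/cnn_1': 0.0, 'net/audio/cnn_1': 0.0}
checkpoint = {'net/frame/cnn_1': 1.5}

# Restoring only the name that exists in the checkpoint succeeds.
restore_by_name(graph, checkpoint, ['net/frame/cnn_1'])

# Asking for every graph variable fails, because 'net/audio/cnn_1' is missing.
try:
    restore_by_name(graph, checkpoint, sorted(graph))
    failed = False
except KeyNotFoundError:
    failed = True
```

The second call is the situation a partial restore avoids: the list of names handed to the restore step has to be limited to what the checkpoint actually contains.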

So to restore only part of a Graph, we need to (1) list the variables whose values should come from the checkpoint, and (2) tell tf.train.Saver to restore exactly those variables.

1. List the variables to be restored from the checkpoint

In TensorFlow, every variable's name is unique within the Graph, so a variable can be looked up by its name. Names can also have multiple levels, for example:

with tf.variable_scope('frame', reuse=reuse) as scope:
        nets_frame = tf.layers.conv1d(frame_input, filters=2048, kernel_size=5, name='conv1d')

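The variables created by this layer (its kernel and bias) are named under the prefix 'frame/conv1d', for example 'frame/conv1d/kernel:0'. Note that nets_frame itself is the layer's output tensor, not a variable.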

Based on this, the following code can select any set of variables by name:

import tensorflow as tf
import tensorflow.contrib.slim as slim
variables = slim.get_variables_to_restore()
variables_to_restore_frame = [v for v in variables if v.name.split('/')[0] == 'frame']

The code above collects every variable whose top-level scope is 'frame' into a list. To list all variables whose names contain 'image' instead:

variables_to_restore_image = [v for v in variables if 'image' in v.name]
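The two filters behave differently, which is easy to check without TensorFlow. The sketch below uses a hypothetical `FakeVariable` stand-in that has only a `name` attribute, and the variable names are invented for illustration.

```python
from collections import namedtuple

# Minimal stand-in for a TensorFlow variable: only the `name` attribute is used.
FakeVariable = namedtuple('FakeVariable', ['name'])

# Hypothetical variable names, for illustration only.
variables = [
    FakeVariable('frame/conv1d/kernel:0'),
    FakeVariable('frame/conv1d/bias:0'),
    FakeVariable('audio/conv1d/kernel:0'),
    FakeVariable('frame_image/fc1/kernel:0'),
    FakeVariable('predict/image_head/bias:0'),
]

# Exact match on the top-level scope: 'frame_image/...' is NOT selected.
frame_vars = [v for v in variables if v.name.split('/')[0] == 'frame']

# Substring match: selects any name containing 'image', at any level.
image_vars = [v for v in variables if 'image' in v.name]

frame_names = [v.name for v in frame_vars]
image_names = [v.name for v in image_vars]
```

The substring filter is convenient but broad: it matches 'image' anywhere in the name, so comparing the top-level scope is the safer choice when scopes share a prefix or a word.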
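
2. Tell tf.train.Saver to restore these variables

# A Saver built from a variable list only saves/restores those variables
saver_frame = tf.train.Saver(variables_to_restore_frame)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    # init_op and frame_checkpoint are defined in the complete example
    sess.run(init_op)
    saver_frame.restore(sess, frame_checkpoint)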

It is enough to pass the variables to be restored when creating the Saver.
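Because matching is by name, a list of variables only works when the graph and the checkpoint use the same names. tf.train.Saver also accepts a dict that maps checkpoint names to variables, which can help when a subgraph was trained under a different scope. The sketch below only builds the checkpoint-side names in plain Python; `checkpoint_key` and the scope 'frame_net' are hypothetical, assumed purely for illustration.

```python
def checkpoint_key(variable_name, graph_scope, checkpoint_scope):
    """Map a graph variable name to the name it is assumed to have in the checkpoint."""
    # Drop the output-index suffix (':0'); checkpoint keys do not carry it.
    op_name = variable_name.split(':')[0]
    parts = op_name.split('/')
    if parts[0] != graph_scope:
        raise ValueError('%s is not under scope %s' % (variable_name, graph_scope))
    parts[0] = checkpoint_scope
    return '/'.join(parts)


# Hypothetical names: the graph uses 'frame', the old checkpoint used 'frame_net'.
graph_names = ['frame/conv1d/kernel:0', 'frame/fc1/bias:0']
key_map = {checkpoint_key(n, 'frame', 'frame_net'): n for n in graph_names}
```

With real variables, a dict of the form {checkpoint_key(v.name, 'frame', 'frame_net'): v for v in variables_to_restore_frame} could be passed to tf.train.Saver in place of the list.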

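Below is a complete example that restores two subgraphs from separate checkpoints into one larger graph and then saves the merged graph to a new checkpoint: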

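# 1. Construct the network
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorlayer as tl  # provides tl.layers.get_variables_with_name, used below

# TAG_NUM, keep_prob, the FRAME_*/AUDIO_* feature sizes and the *_checkpoint
# paths are not defined in this listing; they are assumed to be set elsewhere.
global_step = tf.Variable(0, name="global_step", trainable=False)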


def construct_network(frame_input, audio_input, tags_input, reuse, is_training):
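   """
   Build the frame and audio sub-networks and the joint prediction head.

   :param frame_input: frame features, shape (batch, FRAME_FEAT_LEN, FRAME_FEAT_DIM)
   :param audio_input: audio features, shape (batch, AUDIO_FEAT_LEN, AUDIO_FEAT_DIM)
   :param tags_input: tag labels, shape (batch, TAG_NUM)
   :param reuse: whether to reuse variables in the 'frame', 'audio' and 'predict' scopes
   :param is_training: True to run batch norm and dropout in training mode
   :return: dict with keys 'loss', 'predict', 'confidence' and 'L2'
   """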
   with tf.variable_scope('frame', reuse=reuse) as scope:
       nets_frame = tf.layers.conv1d(frame_input, filters=2048, kernel_size=5, name='conv1d')
       nets_frame = slim.batch_norm(nets_frame,
                              decay=0.9997,
                              epsilon=0.001,
                              is_training=is_training)
       # global max pooling over the time axis
       nets_frame = tf.reduce_max(nets_frame, axis=1, name='max_pool')

       fc_frame = tf.layers.dense(nets_frame, 2048, name='fc1')
       # tf.layers.dropout expects the drop rate, not the keep probability
       fc_frame = tf.layers.dropout(fc_frame, rate=1.0 - keep_prob, training=is_training)
       fc_frame = tf.nn.relu(fc_frame)

   with tf.variable_scope('audio', reuse=reuse) as scope:
       nets_audio = tf.layers.conv1d(audio_input, filters=2048, kernel_size=5, name='conv1d')
       nets_audio = slim.batch_norm(nets_audio,
                              decay=0.9997,
                              epsilon=0.001,
                              is_training=is_training)
       # global max pooling over the time axis
       nets_audio = tf.reduce_max(nets_audio, axis=1, name='max_pool')

       fc_audio = tf.layers.dense(nets_audio, 2048, name='fc1')
       # tf.layers.dropout expects the drop rate, not the keep probability
       fc_audio = tf.layers.dropout(fc_audio, rate=1.0 - keep_prob, training=is_training)
       fc_audio = tf.nn.relu(fc_audio)


   with tf.variable_scope('predict', reuse=reuse) as scope:
       total_vector = tf.concat([fc_frame, fc_audio], axis=1)
       predict = tf.layers.dense(total_vector, TAG_NUM, name='predict')
       predict_confidence = tf.sigmoid(predict, name='confidence')  # (0,1)
       loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=predict,
                                                                     labels=tags_input)) * 1000

       # Sum of L2 penalties over the trainable variables of all three scopes
       L2 = 0
       l2_regularizer = tf.contrib.layers.l2_regularizer(1.0)
       for scope_name in ('frame', 'audio', 'predict'):
           for w in tl.layers.get_variables_with_name(scope_name, True, True):
               L2 += l2_regularizer(w)


   result = dict()
   result['loss'] = loss
   result['predict'] = predict
   result['confidence'] = predict_confidence
   result['L2'] = L2
   return result


frame_fea_placeholder = tf.placeholder(dtype=tf.float32,
                                      shape=(None, FRAME_FEAT_LEN, FRAME_FEAT_DIM),
                                      name='frame_feat')

audio_fea_placeholder = tf.placeholder(dtype=tf.float32,
                                      shape=(None, AUDIO_FEAT_LEN, AUDIO_FEAT_DIM),
                                      name='audio_feat')

tags_placeholder = tf.placeholder(dtype=tf.float32,
                                 shape=(None, TAG_NUM),
                                 name='tags')

train_nets = construct_network(frame_fea_placeholder, audio_fea_placeholder,
                              tags_placeholder, reuse=False, is_training=True)

# training op
learning_rate = tf.train.exponential_decay(0.0001, global_step,
                                          2000, 0.95,
                                          staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-08,
                                  use_locking=False)
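train_opt = slim.learning.create_train_op(train_nets['loss'], optimizer, global_step)
# create_train_op already runs the UPDATE_OPS collection (e.g. batch-norm moving
# averages) by default, so update_ops is collected here only for reference.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)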

# 2. run session
step = 0
init_op = [tf.global_variables_initializer(),
          tf.local_variables_initializer()]

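# Split the variables by top-level scope. The lists are built after the optimizer
# was created, so they also contain Adam slot variables such as
# 'frame/conv1d/kernel/Adam:0'; the checkpoints are expected to contain them too.
variables = slim.get_variables_to_restore()
variables_to_restore_frame = [v for v in variables if v.name.split('/')[0] == 'frame']
variables_to_restore_audio = [v for v in variables if v.name.split('/')[0] == 'audio']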

saver_frame = tf.train.Saver(variables_to_restore_frame)
saver_audio = tf.train.Saver(variables_to_restore_audio)
saver = tf.train.Saver()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
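   # Initialize everything first, then overwrite the restored parts
   sess.run(init_op)
   saver_frame.restore(sess, frame_checkpoint)
   saver_audio.restore(sess, audio_checkpoint)
   # Save the merged graph (all variables) to a new checkpoint
   saver.save(sess, save_checkpoint)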
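In the example above, only the 'frame' and 'audio' scopes are restored; everything else, such as the 'predict' head and global_step, keeps the value set by init_op. The plain-Python sketch below makes that split explicit. `split_by_top_scope` is a hypothetical helper and the names are a small illustrative subset of what the graph would contain.

```python
def split_by_top_scope(variable_names, restored_scopes):
    """Group names by top-level scope; anything else is left to the initializer."""
    restored = {scope: [] for scope in restored_scopes}
    initialized_only = []
    for name in variable_names:
        top = name.split('/')[0]
        if top in restored:
            restored[top].append(name)
        else:
            initialized_only.append(name)
    return restored, initialized_only


# Illustrative subset of variable names from the merged graph.
names = [
    'global_step:0',
    'frame/conv1d/kernel:0',
    'audio/conv1d/kernel:0',
    'predict/predict/kernel:0',
]
restored, initialized_only = split_by_top_scope(names, ['frame', 'audio'])
```

This is why sess.run(init_op) comes before the two restore calls: variables in initialized_only would otherwise be uninitialized when the full Saver tries to save them.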