Fine-tuning---TensorFlow basics---loading part of the parameters of models A and B to initialize part of model C

Restoring all parameters

When every parameter has been saved, all of them can be loaded back as shown below.

In typical experiments the model is saved with the Saver class:

saver = tf.train.Saver()
saver.save(sess,"model.ckpt")

Code for loading:

saver.restore(sess,"model.ckpt")
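For intuition, here is a TensorFlow-free sketch of full save/restore. It assumes, as a simplification, that a checkpoint is just a name-to-value mapping stored in a file (JSON here); the files written by tf.train.Saver are binary, but the idea of restoring by name is the same. `save_all` and `restore_all` are hypothetical helpers, not TensorFlow API.

```python
import json
import os
import tempfile


def save_all(variables, path):
    """Write every variable (name -> value) to a JSON stand-in for a checkpoint."""
    with open(path, "w") as f:
        json.dump(variables, f)


def restore_all(variables, path):
    """Overwrite every variable with the value saved under the same name."""
    with open(path) as f:
        checkpoint = json.load(f)
    for name in variables:
        variables[name] = checkpoint[name]  # a missing name raises KeyError


# Hypothetical variable names and values, for illustration only.
trained = {"frame/conv1d/kernel:0": [0.5, -0.2], "global_step:0": 100}
path = os.path.join(tempfile.mkdtemp(), "model.json")
save_all(trained, path)

fresh = {"frame/conv1d/kernel:0": [0.0, 0.0], "global_step:0": 0}
restore_all(fresh, path)
print(fresh)  # {'frame/conv1d/kernel:0': [0.5, -0.2], 'global_step:0': 100}
```

Because every name in the "graph" is looked up, this all-or-nothing restore only works when the checkpoint contains every variable of the current model.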

Steps for restoring only some of the parameters

Restoring a subset of variables in TensorFlow

How TensorFlow restore works

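When restoring a model, TensorFlow assigns values strictly by matching variable names. For example, if the constructed Graph contains a variable named 'net/frame/cnn_1', then during restore TensorFlow looks in the checkpoint file for a value with the same name: if one exists it is assigned, otherwise an exception is raised.
So to restore only part of the Graph we need to (1) list the names of all variables whose values should come from the checkpoint, and (2) tell tf.train.Saver to restore only those variables.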
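The matching rule can be modeled without TensorFlow. In this toy sketch (an assumption-laden model, not TensorFlow's implementation) the Graph and the checkpoint are plain dicts, `NotFoundError` is a stand-in for the error TensorFlow raises, and `var_list=None` plays the role of a default `tf.train.Saver()` that covers every variable.

```python
class NotFoundError(Exception):
    """Stand-in for the error raised when a name is missing from the checkpoint."""


def restore(graph_vars, checkpoint, var_list=None):
    """Assign checkpoint values to graph variables strictly by name.

    var_list=None means "every variable in the graph", like a default Saver.
    Names present in the checkpoint but not in var_list are simply ignored.
    """
    names = list(graph_vars) if var_list is None else list(var_list)
    missing = [n for n in names if n not in checkpoint]
    if missing:  # simplification: check every name before assigning anything
        raise NotFoundError(f"not found in checkpoint: {missing}")
    for n in names:
        graph_vars[n] = checkpoint[n]


# Hypothetical names: model A only ever had a 'frame' branch.
graph = {"frame/conv1d/kernel:0": 0.0, "predict/predict/kernel:0": 0.0}
ckpt_a = {"frame/conv1d/kernel:0": 1.5}

full_restore_failed = False
try:
    restore(graph, ckpt_a)  # restore everything: 'predict/...' is missing
except NotFoundError as err:
    full_restore_failed = True
    print("full restore failed:", err)

restore(graph, ckpt_a, var_list=["frame/conv1d/kernel:0"])  # step (2): restrict the list
print(graph)  # {'frame/conv1d/kernel:0': 1.5, 'predict/predict/kernel:0': 0.0}
```

Steps (1) and (2) correspond to building the name list and passing it as `var_list`; everything outside the list keeps whatever value it already had.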

1. List the names of all variables whose values should be restored from the checkpoint

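In TensorFlow every variable name is unique within the Graph, so a specific variable can be retrieved by its name. Names can have several levels: in the code below, the conv1d layer is built inside the 'frame' scope, so its variables are named 'frame/conv1d/kernel:0' and 'frame/conv1d/bias:0' (nets_frame itself is the layer's output tensor, not a variable):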

with tf.variable_scope('frame', reuse=reuse) as scope:
       nets_frame = tf.layers.conv1d(frame_input, filters=2048, kernel_size=5, name='conv1d')

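Using this property, any variable can be selected by name. The code below collects every variable whose name starts with 'frame' (i.e. whose top-level scope is 'frame') into a list: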

import tensorflow as tf
import tensorflow.contrib.slim as slim
variables = slim.get_variables_to_restore()
variables_to_restore_frame = [v for v in variables if v.name.split('/')[0] == 'frame']
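The filter keeps a variable only when the first path component equals 'frame' exactly. The sketch below runs the same list comprehension on hypothetical name strings (so no TensorFlow is needed) and contrasts it with a naive `startswith('frame')`, which would also grab an unrelated 'frame_v2' scope.

```python
def select_by_top_scope(names, scope):
    """Keep names whose top-level scope is exactly `scope`."""
    return [n for n in names if n.split('/')[0] == scope]


# Hypothetical variable names, for illustration only.
names = [
    "frame/conv1d/kernel:0",
    "frame/BatchNorm/moving_mean:0",
    "frame_v2/fc1/kernel:0",
    "audio/conv1d/kernel:0",
    "global_step:0",
]

exact = select_by_top_scope(names, "frame")
naive = [n for n in names if n.startswith("frame")]
print(exact)  # ['frame/conv1d/kernel:0', 'frame/BatchNorm/moving_mean:0']
print(naive)  # also includes 'frame_v2/fc1/kernel:0'
```

Since slim.get_variables_to_restore() returns global variables rather than only trainable ones, non-trainable entries such as batch-norm moving statistics also pass the filter, which is usually what fine-tuning needs.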

To list every variable whose name contains 'image':

variables_to_restore_image = [v for v in variables if 'image' in v.name]
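Substring matching is looser than scope matching: any name that merely contains 'image' somewhere qualifies. A quick check on hypothetical names shows the difference.

```python
# Hypothetical variable names, for illustration only.
names = [
    "image/conv1/kernel:0",
    "audio/image_proj/kernel:0",  # mentions 'image' but belongs to the 'audio' scope
    "frame/fc1/kernel:0",
]

by_substring = [n for n in names if 'image' in n]
by_scope = [n for n in names if n.split('/')[0] == 'image']
print(by_substring)  # ['image/conv1/kernel:0', 'audio/image_proj/kernel:0']
print(by_scope)      # ['image/conv1/kernel:0']
```

If the pretrained checkpoint only contains the 'image' branch, the substring version would make the restore fail on 'audio/image_proj/kernel:0', because that name is not in the checkpoint.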
2. Tell tf.train.Saver to restore these variables

Pass the variables to be assigned when creating the Saver:

saver_frame = tf.train.Saver(variables_to_restore_frame)  # the Saver only handles the variables passed in
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
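    sess.run(init_op)  # init_op, e.g. tf.global_variables_initializer(); initialize first, then restore overwrites the 'frame' variables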
    saver_frame.restore(sess, frame_checkpoint)
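The order inside the session matters: the initializer assigns every variable, and `restore` then overwrites only the ones in the Saver's list. A toy version with dicts (the `initialize` and `restore_subset` helpers are hypothetical, and the constant 0.0 stands in for random initial values) shows what happens if the two calls are swapped.

```python
def initialize(graph_vars):
    """Stand-in for sess.run(init_op): give every variable a fresh value."""
    for name in graph_vars:
        graph_vars[name] = 0.0


def restore_subset(graph_vars, checkpoint, var_list):
    """Stand-in for saver_frame.restore: overwrite only the listed variables."""
    for name in var_list:
        graph_vars[name] = checkpoint[name]


frame_ckpt = {"frame/fc1/kernel:0": 0.42}  # hypothetical pretrained value
frame_vars = ["frame/fc1/kernel:0"]

right = {"frame/fc1/kernel:0": None, "predict/predict/kernel:0": None}
initialize(right)
restore_subset(right, frame_ckpt, frame_vars)

wrong = {"frame/fc1/kernel:0": None, "predict/predict/kernel:0": None}
restore_subset(wrong, frame_ckpt, frame_vars)
initialize(wrong)  # wipes the value that was just restored

print(right)  # {'frame/fc1/kernel:0': 0.42, 'predict/predict/kernel:0': 0.0}
print(wrong)  # {'frame/fc1/kernel:0': 0.0, 'predict/predict/kernel:0': 0.0}
```

Variables outside the restore list (here the 'predict' layer) are still needed, which is why the initializer has to run at all.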

Example: restoring part of the parameters of models A and B to initialize part of model C

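Below is a complete example that restores two sub-graphs from two different checkpoints, merges them into one large graph, and saves the result to a new checkpoint: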

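import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorlayer as tl

# FRAME_FEAT_LEN, FRAME_FEAT_DIM, AUDIO_FEAT_LEN, AUDIO_FEAT_DIM, TAG_NUM, keep_prob,
# frame_checkpoint, audio_checkpoint and save_checkpoint are not defined in this post;
# they are assumed to be set elsewhere (feature shapes, tag count, dropout keep
# probability, checkpoint paths).
global_step = tf.Variable(0, name="global_step", trainable=False)

#1. Construct the network
# This function builds one large graph with three parts: sub-graph 1 (frame), sub-graph 2 (audio) and the prediction layer
def construct_network(frame_input, audio_input, tags_input, reuse, is_training):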
   """
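   :param frame_input: frame features, shape (batch, FRAME_FEAT_LEN, FRAME_FEAT_DIM)
   :param audio_input: audio features, shape (batch, AUDIO_FEAT_LEN, AUDIO_FEAT_DIM)
   :param tags_input: tag labels, shape (batch, TAG_NUM)
   :param reuse: whether to reuse the variables already created in each scope
   :param is_training: training mode for batch norm and dropout
   :return: dict with 'loss', 'predict', 'confidence' and 'L2'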
   """
   # Sub-graph 1: frame branch
   with tf.variable_scope('frame', reuse=reuse) as scope:
       nets_frame = tf.layers.conv1d(frame_input, filters=2048, kernel_size=5, name='conv1d')
       nets_frame = slim.batch_norm(nets_frame,
                              decay=0.9997,
                              epsilon=0.001,
                              is_training=is_training)
       # global max pooling layer
       nets_frame = tf.reduce_max(nets_frame, axis=1, name='max_pool')

       fc_frame = tf.layers.dense(nets_frame, 2048, name='fc1')
       fc_frame = tf.layers.dropout(fc_frame, rate=1 - keep_prob, training=is_training)  # rate is the drop probability, not keep_prob
       fc_frame = tf.nn.relu(fc_frame)

   # Sub-graph 2: audio branch
   with tf.variable_scope('audio', reuse=reuse) as scope:
       nets_audio = tf.layers.conv1d(audio_input, filters=2048, kernel_size=5, name='conv1d')
       nets_audio = slim.batch_norm(nets_audio,
                              decay=0.9997,
                              epsilon=0.001,
                              is_training=is_training)
       # global max pooling layer
       nets_audio = tf.reduce_max(nets_audio, axis=1, name='max_pool')

       fc_audio = tf.layers.dense(nets_audio, 2048, name='fc1')
       fc_audio = tf.layers.dropout(fc_audio, rate=1 - keep_prob, training=is_training)  # rate is the drop probability, not keep_prob
       fc_audio = tf.nn.relu(fc_audio)

   # Prediction layer: concatenate sub-graphs 1 and 2, apply a fully connected layer, then predict the output
   with tf.variable_scope('predict', reuse=reuse) as scope:
       total_vector = tf.concat([fc_frame, fc_audio], axis=1)
       predict = tf.layers.dense(total_vector, TAG_NUM, name='predict')
       predict_confidence = tf.sigmoid(predict, name='confidence')  # (0,1)
       loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=predict,
                                                                     labels=tags_input)) * 1000

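       # L2 penalty over every variable of the three scopes; it is only returned in result['L2'], not added to the training loss below
       L2 = 0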
       for w in tl.layers.get_variables_with_name('frame', True, True):
           L2 += tf.contrib.layers.l2_regularizer(1.0)(w)
       for w in tl.layers.get_variables_with_name('audio', True, True):
           L2 += tf.contrib.layers.l2_regularizer(1.0)(w)
       for w in tl.layers.get_variables_with_name('predict', True, True):
           L2 += tf.contrib.layers.l2_regularizer(1.0)(w)


   result = dict()
   result['loss'] = loss
   result['predict'] = predict
   result['confidence'] = predict_confidence
   result['L2'] = L2
   return result




#2. Define the network's input placeholders
frame_fea_placeholder = tf.placeholder(dtype=tf.float32,
                                      shape=(None, FRAME_FEAT_LEN, FRAME_FEAT_DIM),
                                      name='frame_feat')
audio_fea_placeholder = tf.placeholder(dtype=tf.float32,
                                      shape=(None, AUDIO_FEAT_LEN, AUDIO_FEAT_DIM),
                                      name='audio_feat')
tags_placeholder = tf.placeholder(dtype=tf.float32,
                                 shape=(None, TAG_NUM),
                                 name='tags')


#3. Build the large graph
train_nets = construct_network(frame_fea_placeholder, audio_fea_placeholder,
                              tags_placeholder, reuse=False, is_training=True)

#4. Define the optimizer and train_opt
learning_rate = tf.train.exponential_decay(0.0001, global_step,
                                          2000, 0.95,
                                          staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-08,
                                  use_locking=False)
train_opt = slim.learning.create_train_op(train_nets['loss'], optimizer, global_step)
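# create_train_op already runs the UPDATE_OPS collection (batch-norm moving averages) by default, so this list is not required for training
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)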



# 5. run session
step = 0
init_op = [tf.global_variables_initializer(),
          tf.local_variables_initializer()]

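#6. Define which parameters to restore from the pretrained models: two groups, each restored from a different model
# Caution: the optimizer was created above, so these lists also contain Adam slot variables such as
# 'frame/conv1d/kernel/Adam:0'. The checkpoints must contain them too, otherwise restore fails;
# if they are absent, also skip names containing '/Adam' in the filters below.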
variables = slim.get_variables_to_restore()
variables_to_restore_frame = [v for v in variables if v.name.split('/')[0] == 'frame']
variables_to_restore_audio = [v for v in variables if v.name.split('/')[0] == 'audio']

# Two Savers, one per parameter group to restore; a third Saver covers every variable for saving
saver_frame = tf.train.Saver(variables_to_restore_frame)
saver_audio = tf.train.Saver(variables_to_restore_audio)
saver = tf.train.Saver()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
   sess.run(init_op)
   # The two Savers restore the two groups from different ckpt files (frame_checkpoint and audio_checkpoint are the two models' checkpoints)
   saver_frame.restore(sess, frame_checkpoint)
   saver_audio.restore(sess, audio_checkpoint)
   saver.save(sess, save_checkpoint)
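Stripped of the network itself, the script does four things: initialize everything, restore the 'frame' group from model A, restore the 'audio' group from model B, and save the combined result as model C. The dict-based sketch below replays that sequence under simplifying assumptions; the names, values and helpers are hypothetical, and saving is modeled as copying the dict.

```python
def vars_in_scope(graph_vars, scope):
    """Names whose top-level scope is `scope` (same rule as the split('/')[0] filter)."""
    return [n for n in graph_vars if n.split('/')[0] == scope]


def restore(graph_vars, checkpoint, var_list):
    """Overwrite the listed variables by name; a missing name raises KeyError."""
    for n in var_list:
        graph_vars[n] = checkpoint[n]


# 'Graph' of model C after sess.run(init_op); 0.0 stands in for random init.
graph = {
    "global_step:0": 0,
    "frame/conv1d/kernel:0": 0.0,
    "audio/conv1d/kernel:0": 0.0,
    "predict/predict/kernel:0": 0.0,
}
ckpt_a = {"frame/conv1d/kernel:0": 1.0, "old_head/fc/kernel:0": 9.0}  # model A
ckpt_b = {"audio/conv1d/kernel:0": 2.0}                               # model B

restore(graph, ckpt_a, vars_in_scope(graph, "frame"))  # saver_frame.restore
restore(graph, ckpt_b, vars_in_scope(graph, "audio"))  # saver_audio.restore
model_c = dict(graph)                                  # saver.save: every variable
print(model_c)
```

'old_head/fc/kernel:0' from model A is ignored because it is not in any restore list, and the 'predict' layer keeps its initial value, so it is the part that still has to be trained when fine-tuning model C.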