【TensorFlow】TensorFlow之Fine tune

                                            TensorFlow之Fine tune

TensorFlow中可以通过加载已训练模型的部分参数对当前训练任务进行微调

一、tf.train.Saver类中恢复部分参数的方法

在创建saver对象时,将模型中需要恢复的参数(名称)以字典的形式,作为参数传递:

with tf.variable_scope('conv_layer1-other'):
    conv1_weights = tf.get_variable("weight", [CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_DEEP], initializer=tf.truncated_normal_initializer(stddev=0.1))
    ......
with tf.variable_scope('conv_layer2-other'):
    conv1_weights = tf.get_variable("weight", [CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP], initializer=tf.truncated_normal_initializer(stddev=0.1))
    ......

variables_to_restore = {'conv_layer1/weight': conv1_weights, 'conv_layer2/weight': conv2_weights}

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(init_op )
    saver.restore(sess, model_name)
    ......

上述程序实现将模型中名称为:'conv_layer1\weight'和'conv_layer2\weight'的参数值,加载给当前网络中的变量conv1_weights和conv2_weights(名称为:'conv_layer1-other\weight'和'conv_layer2-other\weight')

二、slim模块中的方法

exclude = ['conv_layer1/weight', 'conv_layer2/weight']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude )  
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)  
  
# Start training.  
slim.learning.train(train_op, log_dir, init_fn=init_fn, ignore_missing_vars=ignore_missing_vars) 

其中,ignore_missing_vars为True时,将会忽略不存在参数

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值