tensorflow模型参数的保存、恢复(part save and restore)与训练

本文阐述了在深度学习中,对于复杂模型如CNN与LSTM级联,采用分步训练的方法,通过固定部分参数训练其余参数,实现模型各部分的良好收敛。介绍了设定变量范围、获取参数集合、设置不同saver保存和恢复参数、指定训练部分参数等关键步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

  • 模型分步训练方法综述

对于较为复杂(如CNN与LSTM级联),或者网络层数过深的深度学习模型,在训练过程中,如果采用直接训练的方式,会导致训练难度极大,甚至模型收敛性降低。此时需要采用固定部分参数,训练部分参数的方式来分开训练整个庞大模型,在完成模型的分布训练之后,模型的各个部分均能较好的收敛,并保存得到优良的参数。随后在测试过程中,分别恢复保存好的这两部分参数,完成模型的测试工作。下面以CNN与LSTM级联模型为例,介绍分步训练方法。

  • 要点分类
    第一,在构建模型时,需要预先为模型中的各个子模块设定变量范围scope,便于变量管理;以下是设置scope的几种方式。
with tf.variable_scope('inference', reuse = tf.AUTO_REUSE):
       xxxxx
conv6 = slim.conv2d(conv5, 128, [3,3], stride=2, padding='VALID', scope='conv2d_1a_3x3')

第二,因为后续涉及到模型参数的分开保存,需要首先设法获取两类参数各自的集合;

var_list = tf.trainable_variables() #获取所有变量
#根据变量名相关特征获取参数集合
lstm_variables = [v for v in var_list if v.name.split('/')[1].split('_')[0] == 'lstm']    
cnn_variables = [v for v in var_list if v.name.split('/')[1].split('_')[0] != 'lstm']

第三,设置不同的saver与相应检查点路径,用于保存和恢复对应类别的参数

    saver = tf.train.Saver(cnn_variables) #saver用于保存部分参数:cnn_variable
    saver2 = tf.train.Saver(lstm_variables)  #saver2用于保存部分参数lstm_variable
    saver1 = tf.train.Saver() #saver1默认保存神经网络会话中的所有参数
    ckpt = tf.train.get_checkpoint_state(logs_dir_lstm) #此处的log_dir_lstm对应lstm参数保存路径,如果有预训练参数,可以在预训练的基础上fine tune
    if ckpt and ckpt.model_checkpoint_path:
        saver2.restore(sess, ckpt.model_checkpoint_path) #此处调用saver2
        print("Model restored...")

第四,训练过程中保存相关参数

        if index % 1000 == 0 or index == MAX_STEP - 1:
            checkpoint_path = os.path.join(logs_dir_cnn, 'model.ckpt') #设定参数保存路径
            saver.save(sess, checkpoint_path, global_step= index)  #指定相应saver,保存对应参数

-指定训练部分参数
参数的训练与模型参数的保存是两个不同的问题,在训练时同样可以通过指定参数范围来控制需要训练的变量。下边的代码中,通过设定FLAGS.tune_scope来指定训练变量范围。

def _get_training(logits, labels):
    """Set up training ops"""
    with tf.name_scope("train"):

        if FLAGS.tune_scope:
            scope = FLAGS.tune_scope
            
        else:
            scope = r"inference/conv"
#            scope = r"inference/lstm_layer"
#        rnn_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
        cnn_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
#        rnn_vars = var_filter(tf.trainable_variables(), last_layers = range(2))    
        global_step = tf.Variable(0)
        trainable_var_collection.append(cnn_vars)
        loss = losses(logits, labels)
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):

            learning_rate = tf.train.exponential_decay(
                FLAGS.learning_rate,
                global_step,
                FLAGS.decay_steps,
                FLAGS.decay_rate,
                staircase=FLAGS.decay_staircase,
                name='lr')

            optimizer = tf.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=FLAGS.momentum)

            train_op = tf.contrib.layers.optimize_loss(
                loss=loss,
                global_step=tf.train.get_global_step(),
                learning_rate=learning_rate,
                optimizer=optimizer,
                variables=trainable_var_collection)
            tf.summary.scalar('learning_rate', learning_rate)
            tf.summary.scalar('loss', loss)

    return train_op, loss
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值