tf.estimator.Estimator的使用

tf.estimator.Estimator是TF比较高级的接口。

最近在使用bert预训练模型的时候用到了tf.estimator.Estimator。使用该接口的时候需要开发者完成的工作比较少,一共3个步骤:

第一步,设置input_fun,第二步,设置model_fun,第三步,开始训练。

第一步的input_fun完成的功能是数据的输入准备工作,比如读取一个tfrecord文件,然后解析里面的内容,返回dataset;或者读取音频、图像等数据,返回相应的结果,目前来说返回的结果为dataset格式比较好。

第二步的model_fun完成的功能有:创建模型(输入feature,输出predict这种),设置loss,设置优化器,返回结果是tf.estimator.EstimatorSpec。(后续会说明tf.estimator.EstimatorSpec是什么,怎么设置)

第三步的开始训练是:参数准备(比如学习率什么的,就是上面的步骤1-2中需要用到的参数),设置config(用于训练模型是指定模型的保存路径,多长时间保存一次模型,使用GPU的一些情况),开始根据情况调用estimator.train 和 estimator.evaluate 或者 estimator.predict。

 

第一步:input_fun

def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
    """
    每次调用,从TFRecord文件中读取一个大小为batch_size的batch
    Args:
        filenames: TFRecord文件
        batch_size: batch_size大小
        num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
        perform_shuffle: 是否乱序

    Returns:
        tensor格式的,一个batch的数据
    """
    def _parse_fn(record):
        features = {
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string),
        }
        parsed = tf.parse_single_example(record, features)
        # image
        image = tf.decode_raw(parsed["image"], tf.uint8)
        image = tf.reshape(image, [28, 28])
        # label
        label = tf.cast(parsed["label"], tf.int64)
        return {"image": image}, label

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)  # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

 

第二步:model_fun

def model_fn(features, labels, mode, params):
    """
    :param features:
    :param labels:
    :param mode: 指定训练、验证和测试三种模式
    tf.estimator.ModeKeys.TRAIN    tf.estimator.ModeKeys.EVAL  tf.estimator.ModeKeys.PREDICT
    :param params: 包含学习率等超参数的设计
    :return:
    """
    # step1: 构建模型
    logits = create_model(features)
    predict = tf.nn.softmax(logits, axis=-1)

    # step2: 构建loss、optimization等
    loss = get_loss(logits, labels)
    train_op = tf.train.GradientDescentOptimizer(params['lr']).minimize(loss)

    # step3: 根据mode,构建不同情况下的tf.estimator.EstimatorSpec
    # For mode == ModeKeys.TRAIN: 需要的参数是 loss and train_op.
    # For mode == ModeKeys.EVAL:  需要的参数是  loss.
    # For mode == ModeKeys.PREDICT: 需要的参数是 predictions.
    if mode == tf.estimator.ModeKeys.TRAIN:
        # logging_hook是模型训练/测试的工具,主要执行特定的任务,如判断是否需要停止训练的EarlyStopping,
        # 改变学习速率的LearningRateScheduler,共性就是在每个step开始/结束或者每个epoch开始/结束时需要执行某个操作。
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            training_hooks=[logging_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metric_ops=eval_metrics)
    else:
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={"probabilities": predict})

    return output_spec

 

第三步:main

def main_():
    # 1. 设置超参数
    params = {'lr', 0.0001}

    # 2. 设置config,用于控制模型保存的位置,多久保存一次等
    session_config = tf.ConfigProto(log_device_placement=False,
                                    inter_op_parallelism_threads=0,
                                    intra_op_parallelism_threads=0,
                                    allow_soft_placement=True)
    run_config = tf.estimator.RunConfig(model_dir=model_output_dir,
                                        save_checkpoints_steps=5000,
                                        keep_checkpoint_max=3,
                                        session_config=session_config)

    # 3. 开始训练
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config,
        params=params)

    if do_train:
        train_input_fn = input_fun(...)
        estimator.train(input_fn=train_input_fn)
    
    elif do_eval:
        eval_input_fn = input_fun(...)
        estimator.train(input_fn=eval_input_fn)
        
    else:
        predict_input_fn = input_fun(...)
        estimator.train(input_fn=predict_input_fn)

 

 

===未完待续===

之后会更新关于hook等如何设置

参考文献:

https://zhuanlan.zhihu.com/p/129018863

https://zhuanlan.zhihu.com/p/106400162

https://www.jianshu.com/p/5495f87107e7

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值