使用estimator结构训练tf模型

一、使用estimator训练模型的流程

1、构建model_fn

def my_metric_fn(labels, predictions):
   return {'accuracy': tf.metrics.accuracy(labels, predictions)}
   
def model_fn(features, labels, mode, params):
    """ TODO: 模型函数必须有这四个参数
    :param features: # 输入的特征数据
    :param labels: # 输入的标签数据
    :param mode: # train、evaluate或predict
    :param params: #超参数,对应Estimator传来的参数
    :return: TPUEstimatorSpec类型的对象
    """
    eval_metrics=(my_metric_fn, [labels, predictions])
    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
           mode=mode, # "train" or "eval" or "predict"
           loss=total_loss, # double类型
           eval_metrics=eval_metrics, 
           scaffold_fn=scaffold_fn)  # None or fun
    return output_spec

2、定义estimator

run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    master=FLAGS.master,
    model_dir=FLAGS.output_dir,
    save_checkpoints_steps=FLAGS.save_checkpoints_steps,
    keep_checkpoint_max=FLAGS.keep_checkpoint_max,
    tf_random_seed=FLAGS.random_seed,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=FLAGS.save_checkpoints_steps,
        num_shards=FLAGS.num_tpu_cores,
        per_host_input_for_training=is_per_host
    ))
    
# 自定义估算器
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=FLAGS.use_tpu,
    model_fn=model_fn,  # 模型函数
    config=run_config,  # 设置参数对象
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size,
    predict_batch_size=FLAGS.predict_batch_size)

3、训练模型

def train_input_fn(params):
    batch_size = params["batch_size"]
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
        d = d.repeat()
        d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))
    d = d.apply(tf.data.experimental.map_and_batch(
        lambda record: _decode_record(record, name_to_features),
        batch_size=batch_size,
        drop_remainder=drop_remainder
    ))
    return d

estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)

4、验证模型

def eval_input_fn(params): # 部分代码 只看框架即可
    batch_size = params["batch_size"]
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
        d = d.repeat()
        d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))
    d = d.apply(tf.data.experimental.map_and_batch(
        lambda record: _decode_record(record, name_to_features),
        batch_size=batch_size,
        drop_remainder=drop_remainder
    ))
    return d
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)  # type:dict
for key in sorted(result.keys()):
   log_info = "  %s = %s"%(key, str(result[key]))

5、测试模型

def predict_input_fn(params): # 部分代码 只看框架即可
    batch_size = params["batch_size"]
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
        d = d.repeat()
        d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))
    d = d.apply(tf.data.experimental.map_and_batch(
        lambda record: _decode_record(record, name_to_features),
        batch_size=batch_size,
        drop_remainder=drop_remainder
    ))
    return d
result = estimator.predict(input_fn=predict_input_fn)  # type:dict
for key in sorted(result.keys()):
   log_info = "  %s = %s"%(key, str(result[key]))

二、使用estimator训练模型的样例

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值