一、使用estimator训练模型的流程
1、构建model_fn
def my_metric_fn(labels, predictions):
    """Build the metric dict consumed by TPUEstimatorSpec's eval_metrics."""
    accuracy = tf.metrics.accuracy(labels, predictions)
    return {'accuracy': accuracy}
def model_fn(features, labels, mode, params):
    """Model function required by TPUEstimator; must take exactly these four args.

    Args:
        features: input feature tensors.
        labels: input label tensors.
        mode: one of tf.estimator.ModeKeys — train, evaluate or predict.
        params: hyperparameters forwarded by the Estimator
            (TPUEstimator also injects params['batch_size']).

    Returns:
        A tf.contrib.tpu.TPUEstimatorSpec for the given mode.
    """
    # NOTE(review): `predictions`, `total_loss` and `scaffold_fn` must be
    # produced by the model graph built above this point; they are
    # placeholders in this tutorial snippet — confirm in the full source.
    # eval_metrics is a (metric_fn, tensor_list) pair; the metric_fn is
    # called with the tensors on the TPU host.
    eval_metrics = (my_metric_fn, [labels, predictions])
    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
    return output_spec
2、定义estimator
# Run configuration: checkpointing cadence, random seed and TPU sharding.
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    master=FLAGS.master,
    model_dir=FLAGS.output_dir,
    save_checkpoints_steps=FLAGS.save_checkpoints_steps,
    keep_checkpoint_max=FLAGS.keep_checkpoint_max,
    tf_random_seed=FLAGS.random_seed,
    tpu_config=tf.contrib.tpu.TPUConfig(
        # One checkpoint per TPU loop keeps host round-trips to a minimum.
        iterations_per_loop=FLAGS.save_checkpoints_steps,
        num_shards=FLAGS.num_tpu_cores,
        per_host_input_for_training=is_per_host))

# Custom estimator; when use_tpu is False it falls back to CPU/GPU.
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=FLAGS.use_tpu,
    model_fn=model_fn,      # model function defined above
    config=run_config,      # run configuration object
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size,
    predict_batch_size=FLAGS.predict_batch_size)
3、训练模型
def train_input_fn(params):
    """Training input function; TPUEstimator injects params['batch_size'].

    Returns:
        A tf.data.Dataset of decoded, batched TFRecord examples.
    """
    batch_size = params["batch_size"]
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
        # Repeat indefinitely and reshuffle so every epoch sees a new order.
        d = d.repeat()
        d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))
    # Fused map+batch: decode each record, then batch.
    d = d.apply(tf.data.experimental.map_and_batch(
        lambda record: _decode_record(record, name_to_features),
        batch_size=batch_size,
        drop_remainder=drop_remainder))
    return d

estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)
4、验证模型
def eval_input_fn(params):
    """Evaluation input function; same pipeline shape as training input.

    For evaluation `is_training` is normally False, so the dataset is
    neither repeated nor shuffled.
    """
    batch_size = params["batch_size"]
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
        d = d.repeat()
        d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))
    d = d.apply(tf.data.experimental.map_and_batch(
        lambda record: _decode_record(record, name_to_features),
        batch_size=batch_size,
        drop_remainder=drop_remainder))
    return d

# evaluate() returns a dict of metric name -> value.
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
for key in sorted(result.keys()):
    # NOTE(review): log_info is built but never emitted here; presumably it
    # is passed to tf.logging.info elsewhere — confirm in the full source.
    log_info = " %s = %s" % (key, str(result[key]))
5、测试模型
def predict_input_fn(params):
    """Prediction input function; TPUEstimator injects params['batch_size'].

    For prediction `is_training` is normally False, so the dataset is
    neither repeated nor shuffled.
    """
    batch_size = params["batch_size"]
    d = tf.data.TFRecordDataset(input_file)
    if is_training:
        d = d.repeat()
        d = d.shuffle(buffer_size=100, seed=random.randint(1, 10000))
    d = d.apply(tf.data.experimental.map_and_batch(
        lambda record: _decode_record(record, name_to_features),
        batch_size=batch_size,
        drop_remainder=drop_remainder))
    return d

# BUG FIX: unlike evaluate(), predict() returns a *generator* yielding one
# dict per input example — it has no .keys(), so the original
# `sorted(result.keys())` (copy-pasted from the evaluate section) would
# raise AttributeError. Iterate the predictions instead.
result = estimator.predict(input_fn=predict_input_fn)
for prediction in result:
    for key in sorted(prediction.keys()):
        log_info = " %s = %s" % (key, str(prediction[key]))
二、使用estimator训练模型的样例