estimator是tensorflow高度封装的一个类,里面有一些可以直接使用的分类和回归模型,例如tf.estimator.DNNClassifier,但这不是这篇博客的主题,而是怎么使用estimator来实现我们自定义模型的训练。它的步骤主要分为以下几个部分:
- 构建model_fn,在这个方法里面定义自己的模型以及训练和测试过程要做的事情;
- 构建input_fn,在这个方法数据的来源和喂给模型的方式;
- 最后,创建estimator对象,然后开始训练模型了。可以添加一些config,比如:loss的输出频率等。
构建model_fn方法
import tensorflow as tfdef model_fn(features, labels, mode, params): # 必须要有前面三个参数 # feature和labels其实就是`input_fn`方法传输过来的 # mode是用来判断你现在是训练或测试阶段 # params是在创建`estimator`对象的输入参数 lr = params['lr'] try: init_checkpoint = params['init_checkpoint'] except KeyError: init_checkpoint = None x = features['inputs'] y = features['labels'] #####################在这里定义你自己的网络模型################### pre = tf.layers.dense(x, 1) loss = tf.reduce_mean(tf.pow(pre-y, 2), name='loss') ######################在这里定义你自己的网络模型################### # 这里可以加载你的预训练模型 assignment_map = dict() if init_checkpoint: for var in tf.train.list_variables(init_checkpoint): # 存放checkpoint的变量名称和shape assignment_map[var[0]] = var[0] tf.train.init_from_checkpoint(init_checkpoint, assignment_map) # 定义你训练过程要做的事情 if mode == tf.estimator.ModeKeys.TRAIN: optimizer = tf.train.AdamOptimizer(lr) train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step()) output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) # 定义你测试(验证)过程 elif mode == tf.estimator.ModeKeys.EVAL: metrics = {'eval_loss': loss} output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics) # 定义你的预测过程 elif mode == tf.estimator.ModeKeys.PREDICT: predictions = {'predictions': pre} output_spec = tf.estimator.EstimatorSpec(mode, predictions=predictions) else: raise TypeError return output_spec
提几点需要注意的地方:
- model_fn方法返回的是tf.estimator.EstimatorSpec;
- TRAIN、EVAL和PREDICT模式不可缺少的参数是不一样的。
构建input_fn方法
def input_fn_bulider(inputs_file, batch_size, is_training): name_to_features = {'inputs': tf.FixedLenFeature([3], tf.float32), 'labels': tf.FixedLenFeature([], tf.float32)} def input_fn(params): d = tf.data.TFRecordDataset(inputs_file) if is_training: d = d.repeat() d = d.shuffle() # map_and_batch其实就是将map和batch结合起来而已 d = d.apply(tf.contrib.data.map_and_batch(lambda x: tf.parse_single_example(x, name_to_features), batch_size=batch_size)) return d return input_fn
执行estimator
if __name__ == '__main': # 定义日志消息的输出级别,为了获取模型的反馈信息,选择INFO tf.logging.set_verbosity(tf.logging.INFO) # 我在这里是指定模型的保存和loss输出频率 runConfig = tf.estimator.RunConfig(save_checkpoints_steps=1, log_step_count_steps=1) estimator = tf.estimator.Estimator(model_fn, model_dir='your_save_path', config=runConfig, params={'lr': 0.01}) # log_step_count_steps控制的只是loss的global_step的输出 # 我们还可以通过tf.train.LoggingTensorHook自定义更多的输出 # tensor是我们要输出的内容,输入一个字典,key为打印出来的名称,value为你要输出的tensor的name logging_hook = tf.train.LoggingTensorHook(every_n_iter=1, tensors={'loss': 'loss'}) # 其实给到estimator.train是一个dataset对象 input_fn = input_fn_bulider('test.tfrecord', batch_size=1, is_training=True) estimator.train(input_fn, max_steps=1000) # 下面你还可以对模型进行验证和测试,做法是差不多的,我就不列举了