estimator使用流程共有四部
第一步、定义input_fn函数
定义input_fn函数,构建数据集,包括数据预处理、数据增广
第二步、定义model_fn函数
- 构建模型。
- 计算学习率、构建优化器、创建train_op操作。
- 定义性能指标(性能指标在命令行或summary操作中都会用到)
第三部、实例化tf.estimator.Estimator
定义训练过程中相关操作,包括什么时候进行summary/save/logging操作,summary/save操作的保存路径。
设置tf.Session的参数。
传入自定义的hook,进行定制训练。
第四部、对象的train、evaluate、predict方法
通过 tf.estimator.Estimator 对象的train、evaluate、predict方法,传入input_fn函数进行对应的操作。
可以通过传入hooks来实现自定义功能。