最近在使用estimator做项目,发现 官网 对 这个 estimator整体的讲解 和使用 过程中的细节讲的比较少,结合 我是用过程中的问题,对 estimator的使用步骤进行了总结,如下:代码 见github,求star~~
1. estimator主要需要model_fn,input_fn 以及 serving_fn
2. model_fn主要是是用来定义model ,input_fn主要是用来 定义输入(一般情况下只负责用来定义 train和evaluate),serving_fn用来定义 serving过程中的输入
针对代码中estimator_template.ipynb详细说一下,您可以 参照着代码来看说明:
1. 建立模型
def create_model(params):
# 定义网络结构 和 损失 以及 返回值
pass
def model_fn_builder(params):
# 该方法实际 创建 estimator的model_fn
# 可以 有其他操作
def model_fn(features, labels, mode, params,config) #estimator需要的model_fn 参数固定
'''
features: from input_fn的返回 切记返回的顺序
labels: from input_fn 的返回 切记返回的顺序
mode: tf.estimator.ModeKeys实例的一种
params: 在初始化estimator时 传入的参数列表,dict形式,或者直接使用self.params也可以
config:初始化estimator时 的 Runconfig
'''
create_model(params)
if mode==tf.estimator.ModeKeys.PREDICT: # 执行预测
#...
elif mode==tf.estimator.ModeKeys.EVAL: #评估
#...
elif mode=tf.estimator.ModeKeys.TRAIN: # 训练
#...
#......其它操作
# 最后返回
return tf.estimator.EstimatorSpec(......)
return model_fn
在此,我将 model_fn这块,进行了分拆:create_model,model_fn_builder(返回model_fn)
create_model 只负责网络架构的创建,而不包括 后续 损失计算和返回的定义,这个操作 我统一放在了 model_fn中进行定义(为了 让 各个方法 只负责对应的事情),还有一点 需要注意,在create_model中,最后一层的输出,最好不适用 激活函数,而是在model_fn中对model的输出 进行 相应的操作,这样就可以保证 model可以共用
2. 输入方法:
def input_fn_builder(params):
'''
创建 输入函数闭包
'''
# 可以执行其它操作
def input_fn(......):
# 具体操作......
return features,labels # 返回的 顺序要和 model_fn一致 或者 dataset元素 格式为(features,label)元组 也可以
return input_fn
这个函数 返回 输入函数闭包
3. serving_fn
def serving_input_receiver_fn():
'''
定义模型导出后,serving的输入值
'''
#.......各种数据转换
# 在此处 多说一些 关于 batch_features以及 receiver_tensor
# 1. 首先 这两个 参数,相互之间 并没有 直接 的 关系(切记,没有直接关系,说明还是 有间接关系的)
# 2. batch_features这个参数的格式必须 满足 model_fn中features参数格式
# 2.1 关于值的格式,首先他必须是 tensor或者sparseTensor 或者 字典格式(value必须是tensor/sparsetensor),然后features被传给model
# 2.2 如果 features不是字典,则 该方法会自动将其封装为dict(视为一个样本),并使用‘feature’作为key
# 2.3 总结:model必须接受一个形如{'feature':tensor}的字典作为入参
# 3.receiver_tensor 这个参数 是用来接收 请求 的 参数,改参数 一般可以 用一个 placeholder代替,后续经过各种变化,
# 将receiver_tensor的值 转换为model_fn中features格式
# 3.1 必须是 tensor或者sparseTensor 或者 字典格式(value必须是tensor/sparsetensor)
return tf.estimator.export.ServingInputReceiver(batch_features,receiver_tensor)
强调一下,这一块 仔细的看,因为 涉及到 生产部署。这块也很容易出错
3. 模型训练,评估,预测,导出
这块直接看代码 并参照 tf官网 就会明白
4. 部署
具体部署方法还是 看代码吧,没什么 可说的,部署的话 还是 建议 使用 tensorflow serving docker方法,太方便了,后续 对k8s支持的也很好
最后,estimator.ipynb文件 是我按照 这个 步骤 写的一个demo
知乎: https://zhuanlan.zhihu.com/albertwang
微信公众号:AI-Research-Studio
下面是赞赏码