tensorflow estimator使用总结

最近在使用estimator做项目,发现 官网 对 这个 estimator整体的讲解 和使用 过程中的细节讲的比较少,结合 我是用过程中的问题,对 estimator的使用步骤进行了总结,如下:代码 见github,求star~~

1. estimator主要需要model_fn,input_fn 以及 serving_fn

2. model_fn主要是是用来定义model  ,input_fn主要是用来 定义输入(一般情况下只负责用来定义 train和evaluate),serving_fn用来定义 serving过程中的输入

针对代码中estimator_template.ipynb详细说一下,您可以 参照着代码来看说明:

1. 建立模型

def create_model(params):
    # 定义网络结构 和 损失 以及 返回值
    pass


def  model_fn_builder(params):
    # 该方法实际 创建 estimator的model_fn
    # 可以 有其他操作
    def model_fn(features, labels, mode, params,config) #estimator需要的model_fn 参数固定
    '''
    features: from input_fn的返回  切记返回的顺序
    labels: from input_fn 的返回  切记返回的顺序
    mode: tf.estimator.ModeKeys实例的一种
    params: 在初始化estimator时 传入的参数列表,dict形式,或者直接使用self.params也可以
    config:初始化estimator时 的 Runconfig
    
    '''
        create_model(params)
        if mode==tf.estimator.ModeKeys.PREDICT: # 执行预测
            #...
        elif mode==tf.estimator.ModeKeys.EVAL: #评估
            #...
        elif mode=tf.estimator.ModeKeys.TRAIN: # 训练
            #...
        
        #......其它操作
        
        # 最后返回
        return tf.estimator.EstimatorSpec(......)
    return model_fn

在此,我将 model_fn这块,进行了分拆:create_model,model_fn_builder(返回model_fn)

create_model 只负责网络架构的创建,而不包括 后续 损失计算和返回的定义,这个操作 我统一放在了 model_fn中进行定义(为了 让 各个方法 只负责对应的事情),还有一点 需要注意,在create_model中,最后一层的输出,最好不适用 激活函数,而是在model_fn中对model的输出 进行 相应的操作,这样就可以保证 model可以共用

2. 输入方法:

def input_fn_builder(params):
    '''
    创建 输入函数闭包
    '''
    
    # 可以执行其它操作
    
    def input_fn(......):
        # 具体操作......
        return features,labels # 返回的 顺序要和 model_fn一致 或者 dataset元素 格式为(features,label)元组 也可以
    return input_fn

这个函数 返回 输入函数闭包 

3. serving_fn

def serving_input_receiver_fn():
    '''
    定义模型导出后,serving的输入值
    '''
    #.......各种数据转换
    # 在此处 多说一些 关于 batch_features以及 receiver_tensor
    # 1. 首先 这两个 参数,相互之间 并没有 直接 的 关系(切记,没有直接关系,说明还是 有间接关系的)
    # 2. batch_features这个参数的格式必须 满足 model_fn中features参数格式
    # 2.1 关于值的格式,首先他必须是 tensor或者sparseTensor 或者 字典格式(value必须是tensor/sparsetensor),然后features被传给model
    # 2.2 如果 features不是字典,则 该方法会自动将其封装为dict(视为一个样本),并使用‘feature’作为key
    # 2.3 总结:model必须接受一个形如{'feature':tensor}的字典作为入参
    # 3.receiver_tensor 这个参数 是用来接收 请求 的 参数,改参数 一般可以 用一个 placeholder代替,后续经过各种变化,
    # 将receiver_tensor的值 转换为model_fn中features格式
    # 3.1 必须是 tensor或者sparseTensor 或者 字典格式(value必须是tensor/sparsetensor)
    return tf.estimator.export.ServingInputReceiver(batch_features,receiver_tensor)

强调一下,这一块 仔细的看,因为 涉及到 生产部署。这块也很容易出错 

3. 模型训练,评估,预测,导出

这块直接看代码 并参照 tf官网 就会明白

4. 部署

具体部署方法还是 看代码吧,没什么 可说的,部署的话 还是 建议 使用 tensorflow serving docker方法,太方便了,后续 对k8s支持的也很好

最后,estimator.ipynb文件 是我按照 这个 步骤 写的一个demo

 

知乎: https://zhuanlan.zhihu.com/albertwang

微信公众号:AI-Research-Studio

https://i-blog.csdnimg.cn/blog_migrate/5509f60f875d387159a310532cc257dd.png ​​

下面是赞赏码

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值