Tensorflow笔记(八)——Estimator

1、概述

TF1.3开始引入estimator,通过框架图可以看到Estimator是属于High level的API,而Mid-level API分别是:

  • Layers:用来构建网络结构
  • Datasets: 用来构建数据读取pipeline
  • Metrics:用来评估网络性能

Estimator简化和抽象了管理训练、评估和预测,跟Keras类似,Estimator是模型级别的API,我们不再需要关注Session这样的操作,只需要几步简单的定义就能让模型run起来。
在这里插入图片描述在这里插入图片描述

2、使用方法

主要步骤:
(1)实例化一个Estimator,并与model_fn绑定。
(2)定义model_fn函数,包含模型结构、优化方法等,返回。
(3)定义定义input_fn函数,作为参数传递给Estimator的train函数,最后会被model_fn消费。

下面是一个基本的模式:
(1)构建estimator

import tensorflow as tf
# 创建一个名为classifier的estimator
classifier = tf.estimator.Estimator(model_fn, model_dir=None, 
	config=None, params=None, warm_start_from=None)
  • model_fn是模型函数,estimator做的事情就是把获取到的参数传给model_fn。
  • model_dir是存储和加载检查点和事件文件的路径,如果这里不设置就会从config里读取,如果都不设置则保存到临时目录,如果都设置则必须一致。
  • config是estimator.RunConfig对象,主要是分布式训练和模型存储相关配置。
  • params是传入model_fn的参数。
  • warm_start_from是预训练模型路径。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值