关于TensorFlow中Estimator模板的基础理解与使用流程

理解

在TensorFlow 1.4版本之后,官方开始在入门文档中就鼓励使用高层的Estimator API。并且,许多开源代码中也使用了Estimator模板,因此我认为掌握好这个类使用方法,对于能够优雅地书写Tensorflow程序有着重要的意义。

什么是Estimator?在这里先不考虑源码实现细节以及类定义,首先建立一个整体的认识。Estimator,估计器,这个类的核心思想就是把一个网络封装起来,使用类方法中的trainevalpredict等等进行操作。具体的网络细节对于这个类的运行者来说是一个黑盒,只需要提供输入,选择相应的方法,就可以获得输出。

另外,TensorFlow中对于Estimator,不仅有预设好的对象可以直接生成,还可以自己定义。通俗的讲就是使用预先写好的网络框架(例如DNN)还是用户自定义的网络结构。当然了,预先写好的网络框架也不是死的,具体的隐层数目等等参数都是可以在初始化的时候进行设置的。预定义好的Estimator在大部分文档中被称作pre(made) Estimator,具体包含哪些类型的分类器,可以查看这一部分的文档。

Estimator类的主要结构如下,看不明白也没有关系,我们先关注初步的流程框架。需要重点注意的就是其中的model_fn函数,你会在接下来的使用流程中看到这个函数的作用。
在这里插入图片描述

基本使用流程

  1. 首先定义特征列(feature_columns)。这个是之后在Estimator对象初始化时需要接收的必要参数
		my_feature_columns = []
		for key in train_x.keys():
			my_feature_columns.append(
               	tf.feature_column.numeric_column(key=key))
  1. 之后初始化Estimator对象。这里分为两种情况:

    a)如果是使用预定义的Estimator(例如DNNClassifier)则可以直接调用其初始化函数。

		classifier = tf.estimator.DNNClassifier(
			    feature_columns=my_feature_columns,
			    hidden_units=[10, 10],
			    n_classes=3)

b) 如果自定义Estimator(意味着自定义的模型),则首先需要定义model_fn函数,描述模型细节,之后将model_fn与其他params一起传入Estimator的初始化函数中。

		classifier = tf.estimator.Estimator(
		        model_fn=my_model,
		        params={
		            'feature_columns': my_feature_columns,
		            'hidden_units': [10, 10],
		            'n_classes': 3,
		        })

其中params字典中的值将会被传入model_fn中用于定义自定义的模型。

前后进行对比,很容易看出两者的区别和共同点。相比之下,可以说自定义的模型比预设的模型多了一层壳,多传了一次参数

  1. 调用Esitimator对象中的trainevaluate等方法得到结果
	# train the model
	classifier.train(
	        input_fn=lambda:iris_data.train_input_fn(
	                         train_x, train_y, args.batch_size),
	        steps=args.train_steps)
	
	# evalute the model
	eval_result = classifier.evaluate(
	        input_fn=lambda:iris_data.eval_input_fn(
	                         test_x, test_y, args.batch_size))
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值