Tensorflow之Estimator初探

当tensorflow的文档停留在1.4版本的时候,我们的GET STARTED讲的是全连接网络训练
MNIST手写数据,这对于一个机器学习入门玩家来说,显得相当不友好。时过境迁,现在的
GET STARTED讲的是使用Estimator来训练一个Iris识别网络,隐藏层仅为10*10的网络。
这对于新手来说简单了许多。同时,以Estimator作为启蒙,代替了容易出错的底层接口实
现。这简化了对程序进行debug的难度,如有兴趣请自行查阅。
今天让我们简单过一下Estimator的基本概念。

TENSORFLOW GET STARTED教程
如果找不到数据点这里
如果想了解Dataset相关内容点这里

转载请注明出处

Estimator的核心概念

Estimator的核心想法就是:把工作网络封装成一个类,训练、评估、预测都是类方法。

在这种封装里面,首先隐藏了网络结构,对于程序运行者,只需要考虑输入输出。同时,包含了对参数数据的保存、对训练状态的保存,使得训练过程可复现,可追溯。其数据的管理,交给了Dataset,在进一步了解Dataset后,二者的协作会使得编写tensorflow程序变得井井有条。

以下就是一个Estimator的实例化

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns, #特征行向量
    hidden_units=[10, 10], #网络结构
    n_classes=3 #输出类型个数
    )

这是一个DNN分类器,用于对Iris进行分类
特征向量:Iris的特征,这里有四个特征,所以传入长度为4的数组
网络结构:以数组的方式定义网络结构
n_classes:类别个数,这里为3个

Tips:如果你想要下载教程代码并自己进行训练,有几个小坑先填一下:
1、本文提供的数据,需要自己分离训练样本和测试样本
2、本文提供的数据中,花名用spring所表示,当使用Estimator的时候,必须相应的转换为数字
如'Setosa', 'Versicolor', 'Virginica'转换位0、1、2其最大数必须不能大于3,因为类别
是3。否则程序会报错(为什么!贼不爽)
3、在iris_data.py中将train_input_fn和eval_input_fn的返回值改为
dataset.make_one_shot_iterator().get_next()为什么?因为代码本来就错了啊

这样我们就实例化了一个Estimator,这个Estimator是Tensorflow自带的。

Estimator的训练、评估和预测

这部分直接上代码吧。
训练:

classifier.train(
    input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
    steps=args.train_steps)

评估:

eval_result = classifier.evaluate(
    input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))

预测

predictions = classifier.predict(
    input_fn=lambda:iris_data.eval_input_fn(predict_x,
                                            batch_size=args.batch_size))

可以见得使用Estimator之后,我们几乎可以不关心网络长什么样子,而只需要将数据进行合理的组织,就可以运用机器学习工作了。在必要的情况下,还可以使用tf.estimator.RunConfig来设置运行时参数,控制训练、备份的过程。

当训练完成后,参数、tensorboardlog都被自动保存在module_dir下,你可以在初始化、或者config的时候自己指定这个dir,也可以使用Estimator的函数将这个dir打印出来。

classifier.latest_checkpoint()

当然,如果想要自己定义网络,只要继承Estimator的父类就可以了,这是更深的一个议题。

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值