当tensorflow的文档停留在1.4版本的时候,我们的GET STARTED讲的是全连接网络训练
MNIST手写数据,这对于一个机器学习入门玩家来说,显得相当不友好。时过境迁,现在的
GET STARTED讲的是使用Estimator来训练一个Iris识别网络,隐藏层仅为10*10的网络。
这对于新手来说简单了许多。同时,以Estimator作为启蒙,代替了容易出错的底层接口实
现。这简化了对程序进行debug的难度,如有兴趣请自行查阅。
今天让我们简单过一下Estimator的基本概念。
TENSORFLOW GET STARTED教程
如果找不到数据点这里
如果想了解Dataset相关内容点这里
转载请注明出处
Estimator的核心概念
Estimator的核心想法就是:把工作网络封装成一个类,训练、评估、预测都是类方法。
在这种封装里面,首先隐藏了网络结构,对于程序运行者,只需要考虑输入输出。同时,包含了对参数数据的保存、对训练状态的保存,使得训练过程可复现,可追溯。其数据的管理,交给了Dataset,在进一步了解Dataset后,二者的协作会使得编写tensorflow程序变得井井有条。
以下就是一个Estimator的实例化
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns, #特征行向量
hidden_units=[10, 10], #网络结构
n_classes=3 #输出类型个数
)
这是一个DNN分类器,用于对Iris进行分类
特征向量:Iris的特征,这里有四个特征,所以传入长度为4的数组
网络结构:以数组的方式定义网络结构
n_classes:类别个数,这里为3个
Tips:如果你想要下载教程代码并自己进行训练,有几个小坑先填一下:
1、本文提供的数据,需要自己分离训练样本和测试样本
2、本文提供的数据中,花名用spring所表示,当使用Estimator的时候,必须相应的转换为数字
如'Setosa', 'Versicolor', 'Virginica'转换位0、1、2其最大数必须不能大于3,因为类别
是3。否则程序会报错(为什么!贼不爽)
3、在iris_data.py中将train_input_fn和eval_input_fn的返回值改为
dataset.make_one_shot_iterator().get_next()为什么?因为代码本来就错了啊
这样我们就实例化了一个Estimator,这个Estimator是Tensorflow自带的。
Estimator的训练、评估和预测
这部分直接上代码吧。
训练:
classifier.train(
input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
steps=args.train_steps)
评估:
eval_result = classifier.evaluate(
input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))
预测
predictions = classifier.predict(
input_fn=lambda:iris_data.eval_input_fn(predict_x,
batch_size=args.batch_size))
可以见得使用Estimator之后,我们几乎可以不关心网络长什么样子,而只需要将数据进行合理的组织,就可以运用机器学习工作了。在必要的情况下,还可以使用tf.estimator.RunConfig来设置运行时参数,控制训练、备份的过程。
当训练完成后,参数、tensorboardlog都被自动保存在module_dir下,你可以在初始化、或者config的时候自己指定这个dir,也可以使用Estimator的函数将这个dir打印出来。
classifier.latest_checkpoint()
当然,如果想要自己定义网络,只要继承Estimator的父类就可以了,这是更深的一个议题。