【0.4】Tensorflow踩坑记之tf.estimator
澜子我终于重回博客界了,虽然说在协会开设课程中关于Tensorflow的部分已经完全PU JIE,但是(划重点),本篇博客里还是有满满的干货滴。
本篇博客涉及到的几个重要的API分别是:tf.estimator, tf.feature_column, tf.data, tf.metrics, tf.image。
本篇的主要内容就是融合了上述中高阶API,基于MNIST数据集,实现最基本的分类问题,让我们一起愉快滴SAY HELLO TO DEEP LEARNING
先看看一些参考资料
- 先甩出 Tensorflow官网的API文档
- 再甩出 tf.estimator官方教程
- 澜子前期关于 tf.data 和 tf.metrics 的教程
- 本篇博客的所有代码都已经传到 澜子的github 啦,欢迎follow和star啊喂
- 强烈建议大家点开看一下,因为是用jupyter notebook,所以每一个代码块都有对应的输出,还有我自己写的一些注释,应该还是很好理解的
先大致看一眼澜子github上的readme
代码结构
|--tensorflow_estimator_learn
|--data_csv
|--mnist_test.csv
|--mnist_train.csv
|--mnist_val.csv
|--images
|--ZJUAI_2018_AUT
|--ZJUAI_2018_AUT
|--tmp
|--ZJUAI_2018_AUT
|--ZJUAI_2018_AUT
|--CNNClassifier.jpynb
|--CNNClassifier_dataset.jpynb
|--CNN_raw.jpynb
|--DNNClassifier.jpynb
|--DNNClassifier_dataset.jpynb
data_csv
data_csv文件中存放了MNSIT原始csv文件,分为验证、训练、测试三个部分
images
images文件中存放了jupyter notebook中所涉及的一些图片
tmp
tmp 文件中存放了一些临时代码
CNNClassifier.jpynb
未采用tf.data
API的自定义estimator实现
CNNClassifier_dataset.jpynb
采用tf.data
API的自定义estimator实现
CNN_raw.jpynb
未采用高阶API的 搭建CNN实现MNIST分类
DNNClassifier.jpynb
未采用tf.data
API的预制sestimator实现
DNNClassifier_dataset.jpynb
采用tf.data
API的预制estimator实现
简单瞅瞅estimator的地位
tf.estimator是TensorFlow里封装性很好的高级API,之所以要用tf.estimator,是因为运用高级API可以很好地减少我们的代码量。当然也可能因为高级API的高封装性带来一些使用上的不灵活性。
我们具体要怎么用呢
让我们一起来看图说话
- 创建数据集输入函数 input_fn
- 定义特征列 tf.feature_column
- 实例化estimator tf.estimator.DNNClassifier
- 训练 / 验证 / 测试 model.train()/eval()/test()
需要注意的地方
- 本篇博客没有采用图中所示的 Pre-made estimators 而是采用了自定义的 Custom Estimators,不同点在于是否需要定义自己的 model_fn
- 关于 Pre-made estimators的用法,我也放在了github上,大噶也可以去看看。
简单瞅一眼数据流
下面的图表直接显示了本次MNIST例子的数据流向,共有2个卷积层,每一层卷积之后采用最大池化进行下采样(图中并未画出),最后接2个全连接层,实现对MNIST数据集的分类
不说闲话了,CODE走起
STEP 0:前期准备工作
- 导入各种库
%matplotlib inline
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import multiprocessing
from tensorflow import data
from tensorflow.python.feature_column import feature_column
tf.__version__
- 导入MNIST数据集及基本参数定义
TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'
VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'
TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'
MULTI_THREADING = True
RESUME_TRAINING = False
NUM_CLASS = 10
IMG_SHAPE = [28,28]
IMG_WIDTH = 28
IMG_HEIGHT = 28
IMG_FLAT = 784
NUM_CHANNEL = 1
BATCH_SIZE = 128
NUM_TRAIN = 55000
NUM_VAL = 5000
NUM_TEST = 10000
train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)
test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)
val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)
train_values = train_data.values
train_data = train_values[:,1:]/255.0
train_label = train_values[:,0