【0.4】Tensorflow踩坑记之tf.estimator

本文介绍了如何使用Tensorflow的tf.estimator和tf.feature_column API,结合MNIST数据集,实现图像分类。通过自定义的Custom Estimators构建包含卷积层和全连接层的CNN模型,详细讲解了数据处理、模型定义和训练过程。同时提供了相关代码资源。
摘要由CSDN通过智能技术生成

【0.4】Tensorflow踩坑记之tf.estimator

澜子我终于重回博客界了,虽然说在协会开设课程中关于Tensorflow的部分已经完全PU JIE但是(划重点),本篇博客里还是有满满的干货滴。

本篇博客涉及到的几个重要的API分别是:tf.estimator, tf.feature_column, tf.data, tf.metrics, tf.image

本篇的主要内容就是融合了上述中高阶API,基于MNIST数据集,实现最基本的分类问题,让我们一起愉快滴SAY HELLO TO DEEP LEARNING

先看看一些参考资料

先大致看一眼澜子github上的readme

代码结构

|--tensorflow_estimator_learn               
    |--data_csv
        |--mnist_test.csv
        |--mnist_train.csv
        |--mnist_val.csv
        
    |--images
        |--ZJUAI_2018_AUT
        |--ZJUAI_2018_AUT

    |--tmp
        |--ZJUAI_2018_AUT
        |--ZJUAI_2018_AUT
        
    |--CNNClassifier.jpynb
    
    |--CNNClassifier_dataset.jpynb    
    
    |--CNN_raw.jpynb    
    
    |--DNNClassifier.jpynb    
    
    |--DNNClassifier_dataset.jpynb    

data_csv

data_csv文件中存放了MNSIT原始csv文件,分为验证、训练、测试三个部分

images

images文件中存放了jupyter notebook中所涉及的一些图片

tmp

tmp 文件中存放了一些临时代码

CNNClassifier.jpynb

未采用tf.dataAPI的自定义estimator实现

CNNClassifier_dataset.jpynb

采用tf.dataAPI的自定义estimator实现

CNN_raw.jpynb

未采用高阶API的 搭建CNN实现MNIST分类

DNNClassifier.jpynb

未采用tf.dataAPI的预制sestimator实现

DNNClassifier_dataset.jpynb

采用tf.dataAPI的预制estimator实现

简单瞅瞅estimator的地位

tf.estimator是TensorFlow里封装性很好的高级API,之所以要用tf.estimator,是因为运用高级API可以很好地减少我们的代码量。当然也可能因为高级API的高封装性带来一些使用上的不灵活性。

estimator

我们具体要怎么用呢

让我们一起来看图说话
dataflow

  • 创建数据集输入函数 input_fn
  • 定义特征列 tf.feature_column
  • 实例化estimator tf.estimator.DNNClassifier
  • 训练 / 验证 / 测试 model.train()/eval()/test()

需要注意的地方

  • 本篇博客没有采用图中所示的 Pre-made estimators 而是采用了自定义的 Custom Estimators,不同点在于是否需要定义自己的 model_fn
  • 关于 Pre-made estimators的用法,我也放在了github上,大噶也可以去看看。
    estimator

简单瞅一眼数据流

下面的图表直接显示了本次MNIST例子的数据流向,共有2个卷积层,每一层卷积之后采用最大池化进行下采样(图中并未画出),最后接2个全连接层,实现对MNIST数据集的分类
flow

不说闲话了,CODE走起

STEP 0:前期准备工作

  • 导入各种库
%matplotlib inline
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import multiprocessing

from tensorflow import data
from tensorflow.python.feature_column import feature_column

tf.__version__
  • 导入MNIST数据集及基本参数定义
TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'
VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'
TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'

MULTI_THREADING = True
RESUME_TRAINING = False

NUM_CLASS = 10
IMG_SHAPE = [28,28]

IMG_WIDTH = 28
IMG_HEIGHT = 28
IMG_FLAT = 784
NUM_CHANNEL = 1

BATCH_SIZE = 128
NUM_TRAIN = 55000
NUM_VAL = 5000
NUM_TEST = 10000

train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)
test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)
val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)

train_values = train_data.values
train_data = train_values[:,1:]/255.0
train_label = train_values[:,0
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值