如果你苦于Tensorflow官方例程的冗长和数据读取的繁杂,那么这篇博客就是你需要看的啦~
Motivation
这段时间学习TensorFlow走了很多弯路,刚开始的MNIST例程还是比较容易啃的。后来钻研im2txt,大概啃了两三个星期,发现结构比MNIST复杂很多,实在难懂。遂转而看CIFAR10例程,但是光数据读取就啃了一个星期,现在都比较迷糊其中队列的用法。
经过这一个月的折腾,我发现Tensorflow虽然网络结构很好定义,而且用户可以方便灵活的自定义层,但是也有一些坑:
- 例程里有大量函数和方法用于描述代码架构。注意这些东西不是描述网络的,而是仅仅为了让代码更有层次。当然,这也是google工程师的牛逼之处吧。不过对于新手来说却有点迷茫,因为你除了会看到conv、relu、pool,还将会看到一大堆name_space、variable_sapce、tf.collection等等,这些代码量甚至超过模型描述的代码,有喧宾夺主的感觉。
- 数据读取比较复杂,如果深究的话,往往要理解Tensorflow中的“队列”。例如CIFAR10的数据读取函数,返回的是image节点,而不是image数据。若不能理解Tensorflow中的“队列”,你将难以理解image节点的含义,也就无法自如的对其操作了。(感兴趣的读者可自行阅读官方cifar10_input.py)
———————–因此————————-
我毅然决定自己造轮子,用一种清新脱俗的方式重写官方CIFAR10教程。
本代码优点:
- 数据读取采用LiFeiFei在cs231n课程上给出的cifar10读取方法,将cifar10数据整个读成numpy.array,熟悉numpy的你可以在数据层随便预处理。并且本代码给出了用numpy进行预处理方式(包括如何用numpy对数据进行shuffle,如何把label变成onehot形式,以及如何每个epoch之后都进行重新shuffle)。
- 数据预处理借鉴tensorflow里的crop、flip等函数,并解决了tensorflow某些预处理函数只能对一张图片处理,而不能对一个batch处理的问题。
- 官方教程在训练阶段只能输出train loss,而难以输出test loss(我为了观察过拟合,已经在官方教程上尝试各种方法,均以失败告终,这也是我重写CIFAR10数据层的动机)而我的code可以实时输出train loss、test loss、train accuracy、test accuracy。并在tensorboard上显示。
- 本代码风格极简,从上到下依次是 预处理、构造graph、训练和评测。结构清晰,基本不存在来回跳转,尽量避免使用tensorflow里牛逼但不必须的函数,所以叫做“清新脱俗”
- 本代码的模型部分与官方教程一致,运行的accuracy和loss曲线也与官方一致,证明代码不是胡写的。
- 注释比代码还多,而且都是英文注释,方便国际友人阅读(实际上是我的IDE不支持中文输入)。注释中包含了对以上各种trick的详细描述,以及相关trick在stackoverflow上的网址,以及调参遇到过的坑。导致注释比代码还多。。。
本代码缺点:
- 数据读取采用LiFeiFei在cs231n课程上给出的cifar10读取方法,将cifar10数据整个读成numpy.array。如果数据集更大,将很可能out of memory。所以本代码只是提供了一种小数据集的读取方法。大数据集要用tensorflow的queue来实现啦
运行结果:
刚刚运行时:
step=500k时:
代码(包括cifar10_easy_demo.py和data_utils.py)
cifar10_easy_demo.py
import re
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from data_utils import load_CIFAR10
#==============================================================================
# configurations
#==============================================================================
batch=128
INITIAL_LEARNING_RATE=0.1
LEARNING_RATE_DECAY_FACTOR=0.1
NUM_CLASSES=10
decay_steps=130000
MAX_STEPS=500000
log_dir='/home/kcheng/tensorflow_easy_demo/cifar10_log'
checkpoint_dir='/home/kcheng/tensorflow_easy_demo/cifar10_checkpoint'
def load_onehot_shuffle():
#==============================================================================
# this function is to load the data, shufflow it, split it, and transfer label into one-hot label
# Notice that following operations are implemented by numpy. In theory ,this operation is implemented by queue, see detial in tensorflow offical code for cifar10
# But that offical code really puzzled me, because I can't understand how the queue return a batch, so I decide to preprocess data with numpy instead of queue
# This is one of the biggest shortcoming of my code, because my code load all training sets into memory.
# Fortunatly, cifar10 datasets is so small that our memory can hold it.
# If you try our code on bigger datasets(i.e. ImageNet), you may run out of memory.
# All in all, this function shows a easy way to load small datasets as cifar10.
# If you have any suggestions for loading or preprocessiong data in tensorflow, please reply me
#==============================================================================
#==============================================================================
# read data
#==============================================================================
cifar10_dir = '/home/kcheng/tmaster/models-master/tutorials/image/cifar10/winter1516_assignment1/assignment1/cs231n/datasets/cifar-10-batches-py'
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
# As a sanity check, we print out the size of the training and test data.
print 'Training data shape: ', X_train.shape
print 'Training labels shape: ', y_train.shape
print 'Test data shape: ', X_test.shape
print 'Test labels shape: ', y_test.shape
#==============================================================================
# one_hot
#==============================================================================
y_train_one_hot=np.zeros((50000,10))
y_train_one_hot[np.arange(50000), y_train] = 1
print y_train_one_hot.shape
y_test_one_hot=np.zeros((10000,10))
y_test_one_hot[np.arange(10000), y_test] = 1
print y_test_one_hot.shape
y_train=y_train_one_hot
y_test=y_test_one_hot
#==========================================================&