tutorial_cifar10_tfrecord.py解读

把Github上TensorLayer的示例——tutorial_cifar10_tfrecord.py看了一下,做了一些中文注解,在此分享一下,方便学习Tensorflow的童鞋们快速了解之。针对TF r0.12的情况,将源代码做了相应的改动。
win10_x64 + Python3.5 + CUDA8.0 + Cudnn5.1环境下运行正常。

#! /usr/bin/python
# -*- coding: utf8 -*-


import tensorflow as tf
import tensorlayer as tl
# from tensorlayer.layers import set_keep
import numpy as np
import time
from PIL import Image
import os
import io

"""重现 TensorFlow 官方 CIFAR-10 卷积神经网络 指导书:
- 该模型有 1,068,298 个参数, 使用GPU训练数小时后准确率可达86%.

描述
-----------
图片作如下处理:
.. 图片被裁剪为 24 x 24 像素, 集中评估或随机训练.
.. 为使模型对动态范围不敏感,图片被近似白化.

为了改善训练效果,我们还对图片应用了一系列的随机变换来人为地增加数据集的大小:
.. 随机左右翻转.
.. 随机改变图片亮度.
.. 随机改变图片对比度.

加速
--------
从磁盘读取图像并进行变换耗费了不短的处理时间。为减轻这些操作对训练速度的影响,
我们在16个独立的线程运行,不断填补tensorflow队列
"""
model_file_name = r"W:\cifar10\models\model_cifar10_advanced.ckpt" # win10下自定义的model和checkpoint文件保存位置
resume = False  # 载入已存在的模型, 从之前的checkpoint重新开始吗?

## 下载数据集, 并转化为TFRecord格式
X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(
    shape=(-1, 32, 32, 3), plotable=False)

X_train = np.asarray(X_train, dtype=np.float32)
y_train = np.asarray(y_train, dtype=np.int64)
X_test = np.asarray(X_test, dtype=np.float32)
y_test = np.asarray(y_test, dtype=np.int64)

print('X_train.shape', X_train.shape)  # (50000, 32, 32, 3)
print('y_train.shape', y_train.shape)  # (50000,)
print('X_test.shape', X_test.shape)  # (10000, 32, 32, 3)
print('y_test.shape', y_test.shape)  # (10000,)
print('X %s   y %s' % (X_test.dtype, y_test.dtype))


def data_to_tfrecord(images, labels, filename): # 定义格式转化函数(转化为TFRecord格式)
    """ 将数据转化为TFRecord格式 """
    print("Converting data into %s ..." % filename)
    cwd = os.getcwd() # 获取当前目录路径
    writer = tf.python_io.TFRecordWriter(filename) # 创建TFRecord格式文件
    for index, img in enumerate(images):
        img_raw = img.tobytes() # 将numpy数组类型的图像转化为bytes类型
        ## 可视化一张图像
        # tl.visualize.frame(np.asarray(img, dtype=np.uint8), second=1, saveable=False, name='frame', fig_idx=1236)
        label = int(labels[index])
        # print(label)
        ## 将bytes格式文件转换回图像的操作如下:
        # image = Image.frombytes('RGB', (32, 32), img_raw)
        # image = np.fromstring(img_raw, np.float32)
        # image = image.reshape([32, 32, 3])
        # tl.visualize.frame(np.asarray(image, dtype=np.uint8), second=1, saveable=False, name='frame', fig_idx=1236)
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        }))
        writer.write(example.SerializeToString())  # Serialize To String
    writer.close() # 关闭文件


def read_and_decode(filename, is_train=None):
    """ 从TFRecord文件读取并返回tensor """
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })
    # You can do more image distortion here for training data
    img = tf.decode_raw(features['img_raw'], tf.float32)
    img = tf.reshape(img, [32, 32, 3])
    # img = tf.cast(img, tf.float32) #* (1. / 255) - 0.5
    if is_train == True:
        # 1. 随机裁剪图像中[height, width] 的一部分.
        img = tf.random_crop(img, [24, 24, 3])
        # 2. 随机水平翻转.
        img = tf.image.random_flip_left_right(img)
        # 3. 随机改变亮度.
        img = tf.image.random_brightness(img, max_delta=63)
        # 4. 随机改变对比度.
        img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
        # 5. 标准化(去均值,除方差).
        img = tf.image.per_image_standardization(img)
    elif is_train == False:
        # 1. 裁剪图片中央的[height, width]部分.
        img = tf.image.resize_image_with_crop_or_pad(img, 24, 24)
        # 2. 标准化(去均值,除方差).
        img = tf.image.per_image_standardization(img)
    elif is_train == None:
        img = img

    label = tf.cast(features['label'], tf.int32)
    return img, label


data_to_tfrecord(images=X_train, labels=y_train, filename="train.cifar10")
data_to_tfrecord(images=X_test, labels=y_test, filename="test.cifar10")

## 数据可视化举例
# img, label = read_and_decode("train.cifar10", None)
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值