tensorflow 标准数据读取 tfrecords

原创 2018年04月14日 22:03:41

TensorFlow提供了一种TFRecords的格式来统一存储数据。理论上,TFRecords可以存储任何形式的数据 , TFRecords文件的是以tf.train.Example Protocol Buffer的格式存储的。以下的代码给出了tf.train.Example的数据结构:

    message Example {
        Features features = 1;
    };
    message Features {
        map<string, Feature> feature = 1;
    };
    message Feature {
        oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
    };
首先介绍一下我接下来要展示给大家的工程结构(使用的IDE是 pycharm 2017 community):

工程结构

接下来代码分三个文件, 分别是 加载数据 prepare_data.py ,制作tfrecords文件 make_data.py, 读取tfrecords文件read_data.py。

1.prepare_data.py

下面代码中数据增强部分我就略过了,可以参考tensorflow数据增强

# -*- coding: utf-8 -*-
import os
import cv2
import numpy as np

dir = "imgs/"  # 加载jpg, testShape = (333, 500, 3)


def data_augmentation(data):
    """
    数据增强处理
    :param data:
    :return: 
    """
    return data


def get_img_data(file_dir):
    """
    获取图片数据, 返回类型是 list
    :param file_dir: 图片所在目录
    :return: 返回类型是 list
    """
    files = [os.path.join('imgs', x) for x in os.listdir(file_dir)]
    raw_data = [cv2.imread(img) for img in files]
    raw_data = data_augmentation(raw_data)
    return raw_data


if __name__ == "__main__":
    get_img_data(dir)
  1. make_data.py
# _*_ coding: utf-8 _*_

import tensorflow as tf
import numpy as np

from prepare_data import get_img_data

# tfrecords 支持的数据类型
# tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))
# tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))
# tf.train.Feature(bytes_list = tf.train.FloatList(value=[float_scalar]))

# 创建tfrecords文件
file_nums = 2
instance_per_file = 5
dir = "imgs/"

data = get_img_data(dir)  # type(data) list
for i in range(file_nums):
    tfrecords_filename = './tfrecords/train.tfrecords-%.5d-of-%.5d' % (i, file_nums)
    writer = tf.python_io.TFRecordWriter(tfrecords_filename)  # 创建.tfrecord文件

    for j in range(instance_per_file):
        # type(data[i*instance_per_file+j]) numpy.ndarray
        img_raw = np.asarray(data[i*instance_per_file+j]).tostring()

        example = tf.train.Example(features=tf.train.Features(
            feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[j])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
        writer.write(example.SerializeToString())

    writer.close()
  1. read_data.py
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt


# 读取tfrecords文件
# --------------hyperParams--------------------------
batch_size = 2
capacity = 1000 + 3*batch_size
train_rounds = 3
num_epochs = 30
img_h = 333
img_w = 500
# ---------------------------------------------------

tfrecord_files = tf.train.match_filenames_once('./tfrecords/train.tfrecords-*')
queue = tf.train.string_input_producer(tfrecord_files, num_epochs=num_epochs, shuffle=True, capacity=10)

reader = tf.TFRecordReader()
# 从文件中读出一个队列, 也可以使用read_uo_to函数一次性读取多个样例
_, serialized_example = reader.read(queue)

# 读取多个对应tf.parse_example()
# 读取单个对应tf.parse_single_example()

features = tf.parse_single_example(
    serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
)


image = tf.decode_raw(features['img_raw'], tf.uint8)
# image_shape = tf.stack([img_h, img_w, 3])
image = tf.reshape(image, [img_h, img_w, 3])
label = tf.cast(features['label'], tf.int64)


# tf.train.shuffle_batch()
to_train_batch, to_label_batch = tf.train.shuffle_batch(
    [image, label], batch_size=batch_size, capacity=capacity,
    allow_smaller_final_batch=True, num_threads=1, min_after_dequeue=1
)


with tf.Session() as sess:
    sess.run(
        tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())
    )

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(train_rounds):
        train_batch, label_batch = sess.run([to_train_batch, to_label_batch])
        plt.subplot(121)
        plt.imshow(train_batch[0])
        plt.subplot(122)
        plt.imshow(train_batch[1])
        plt.show()
    coord.request_stop()
    coord.join(threads)

print('finish all')
# 下图是read_data.py 读取 tfrecords 的结果:

读取数据后结果还原

Tensorflow中使用TFRecords高效读取数据--结合NLP数据实践

之前一篇博客在进行论文仿真的时候用到了TFRecords进行数据的读取操作,但是因为当时比较忙,所以没有进行深入学习。这两天看了一下,决定写篇博客专门结合该代码记录一下TFRecords的相关操作。 ...
  • liuchonge
  • liuchonge
  • 2017-06-23 20:15:20
  • 10338

由浅入深之Tensorflow(3)----数据读取之TFRecords

由浅入深之Tensorflow(3)----数据读取之TFRecords
  • jacke121
  • jacke121
  • 2017-08-27 12:12:18
  • 283

tensorflow数据读取之tfrecords

掌握一个深度学习框架的用法,从训练一个模型的流程来看,需要掌握以下几个步骤: 1. 数据的处理,包括训练数据转成网络的输入,模型参数的存储与读取 2. 网络结构的定义,包括网络主体的搭建以及los...
  • yaoqi_isee
  • yaoqi_isee
  • 2017-08-24 10:32:59
  • 1496

tensorflow读取SVHN数据集转为TFrecords格式

这里默认将python脚本文件和svhn数据集放在同一目录下,FLAGS.directory参数可以指定数据集的目录,由于svhn没有validation数据集,因此将train分割一部分出来作为va...
  • qikaihuting
  • qikaihuting
  • 2017-05-10 20:02:50
  • 1066

Tensorflow建立与读取TFrecorder文件

Tensorflow建立与读取TFrecorder文件除了直接读取数据文件,比如csv和bin文件,tensorflow还可以建立一种自有格式的数据文件,称之为tfrecorder,这种文件储存类似于...
  • freedom098
  • freedom098
  • 2017-02-20 13:20:13
  • 6837

TensorFlow读取tfrecords数据

因为要用到TensorFlow,自然少不了数据的读取,这里我自己写了一个tfrecords的数据的读取函数 """ Created on Wed Jun 28 13:56:35 2017 @auth...
  • m0_37041325
  • m0_37041325
  • 2017-07-09 16:58:34
  • 421

Tensorflow 中Tfrecords的使用心得

这篇博客主要讲了如何用Tensorflow中的标准数据读取方式的简单的实现对自己数据的读取操作。...
  • u014802590
  • u014802590
  • 2017-03-30 20:10:53
  • 2735

Tensorflow分批量读取tfrecords

Tensorflow分批量读取tfrecords
  • jacke121
  • jacke121
  • 2017-12-17 11:12:20
  • 191

Tensorflow之构建自己的图片数据集TFrecords

学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程。    tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集。并没有过多讲述怎么构建自己的图...
  • csuzhaoqinghui
  • csuzhaoqinghui
  • 2016-05-11 20:23:20
  • 17677

用Tensorflow处理自己的数据:制作自己的TFRecords数据集

转载请注明作者和出处: http://blog.csdn.net/c406495762 运行平台: Windows Python版本: Python3.x IDE: Spyder前言   ...
  • WIinter_FDd
  • WIinter_FDd
  • 2017-06-01 22:28:56
  • 11762
收藏助手
不良信息举报
您举报文章:tensorflow 标准数据读取 tfrecords
举报原因:
原因补充:

(最多只允许输入30个字)