将数据集转化为TFRecords格式

参考文章:How to write into and read from a TFRecords file in TensorFlow
数据集:Dogs vs. Cats

简介

TensorFlow提供了一种统一的格式来存储数据,这个格式是TFRecord。TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer 的格式存储的。以下代码给出了tf.train.Example 的定义。

message Example {
    Features features = 1;
};

message Features {
    map<string,Feature>feature = 1;
}

message Feature {
    oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
    }
};

接着便以猫狗大战数据集为例展示TFRecord的生成和读取。

将图像及其标签列表化

首先,我们需要将图片和标签列表化。我们让猫的label=0、狗的label=1。以下代码列表化所有的图片,赋予合适的标签,并对数据进行shuffle。同时也将数据集划分成训练集(60%)和验证集(20%)以及测试集(20%)。

from random import shuffle
import glob
shuffle_data = True  # shuffle the addresses before saving
cat_dog_train_path = 'Cat vs Dog/train/*.jpg'
# read addresses and labels from the 'train' folder
addrs = glob.glob(cat_dog_train_path)
labels = [0 if 'cat' in addr else 1 for addr in addrs]  # 0 = Cat, 1 = Dog
# to shuffle data
if shuffle_data:
    c = list(zip(addrs, labels))
    shuffle(c)
    addrs, labels = zip(*c)

# Divide the hata into 60% train, 20% validation, and 20% test
train_addrs = addrs[0:int(0.6*len(addrs))]
train_labels = labels[0:int(0.6*len(labels))]
val_addrs = addrs[int(0.6*len(addrs)):int(0.8*len(addrs))]
val_labels = labels[int(0.6*len(addrs)):int(0.8*len(addrs))]
test_addrs = addrs[int(0.8*len(addrs)):]
test_labels = labels[int(0.8*len(labels)):]

生成TFRecords文件

首先我们要读取图片并将其转化为我们想保存在TFRecords文件中的数据的格式(本例中为float32)。以下函数完成了图片的读取和resize,并返回一个合适的数据格式。

def load_image(addr):
    # read an image and resize to (224, 224)
    # cv2 load images as BGR, convert it to RGB
    img = cv2.imread(addr)
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)
    return img

在将数据保存到TFRecords文件之前,我们需要将它放到一个名叫Example的protocol buffer中。接着我们将序列化protocol buffer为string并将它写入TFR文件。Example protocol buffer包含了Features。Feature是一个用于描述数据的protocol,它有三种类型:bytes、float、int64。总而言之,保存你的数据通过以下这些步骤:
1. 使用tf.python_io.TFRecordWriter 打开一个TFRecords文件
2. 使用tf.train.Int64Listtf.train.BytesListtf.train.FloatList 将数据转化为合适类型的feature
3. 使用tf.train.Feature 创建一个feature并将数据传给它
4. 使用tf.train.Example 创建一个Example protocol buffer并将feature传给它
5. 使用example.SerializeToString() 序列化Example为string
6. 将序列化后的example写入:writer.write

本例中我们将使用以下两个函数来创建features:

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

现在讲数据保存到TFRecords文件:

train_filename = 'train.tfrecords'  # address to save the TFRecords file
# open the TFRecords file
writer = tf.python_io.TFRecordWriter(train_filename)
for i in range(len(train_addrs)):
    # print how many images are saved every 1000 images
    if not i % 1000:
        print 'Train data: {}/{}'.format(i, len(train_addrs))
        sys.stdout.flush()
    # Load the image
    img = load_image(train_addrs[i])
    label = train_labels[i]
    # Create a feature
    feature = {'train/label': _int64_feature(label),
               'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}
    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))

    # Serialize to string and write on the file
    writer.write(example.SerializeToString())

writer.close()
sys.stdout.flush()

类似的,生成验证和测试的TFR文件:

# open the TFRecords file
val_filename = 'val.tfrecords'  # address to save the TFRecords file
writer = tf.python_io.TFRecordWriter(val_filename)
for i in range(len(val_addrs)):
    # print how many images are saved every 1000 images
    if not i % 1000:
        print 'Val data: {}/{}'.format(i, len(val_addrs))
        sys.stdout.flush()
    # Load the image
    img = load_image(val_addrs[i])
    label = val_labels[i]
    # Create a feature
    feature = {'val/label': _int64_feature(label),
               'val/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}
    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    # Serialize to string and write on the file
    writer.write(example.SerializeToString())
writer.close()
sys.stdout.flush()
# open the TFRecords file
test_filename = 'test.tfrecords'  # address to save the TFRecords file
writer = tf.python_io.TFRecordWriter(test_filename)
for i in range(len(test_addrs)):
    # print how many images are saved every 1000 images
    if not i % 1000:
        print 'Test data: {}/{}'.format(i, len(test_addrs))
        sys.stdout.flush()
    # Load the image
    img = load_image(test_addrs[i])
    label = test_labels[i]
    # Create a feature
    feature = {'test/label': _int64_feature(label),
               'test/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}
    # Create an example protocol buffer
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    # Serialize to string and write on the file
    writer.write(example.SerializeToString())
writer.close()
sys.stdout.flush()

读取TFRecords文件

TensorFlow文件读取机制参考文章:十图详解tensorflow数据读取机制
为了读取TFR文件,有以下步骤:
1. 创建一个文件名列表:本例中我们只有一个文件data_path='train.tfrecords 因此我们的list应该是[data_path]
2. 创建文件名队列:使用tf.train.string_input_producer 创建一个FIFO队列。它需要传入文件名列表,系统会自动将它转化为一个文件名队列。它还有两个重要的参数,一个是num_epochs 来指定epoch,另一个是shuffle 来指定是否打乱顺序。
3. 定义reader:对于TFR文件我们需要定义一个TFRecordReader–reader = tf.TFRecordReader() 。然后reader返回下一个record–reader.read(filename_queue)
4. 定义decoder:reader读出来的record需要经过decoder的解析。TFR文件的decoder应该是tf.parse_single_example 。它需要传入一个序列化的Example和一个dict(key为feature,value为FixedLenFeature或者VarLenFeature),并返回一个dict(key为feature,value为Tensor)–features = tf.parse_single_example(serialized_example,features=feature)
5. 将数据从string转换回数字:tf.decode_raw(bytes,out_type) 传入一个string类型的Tensor,并将它转换为out_type 类型。当然,对于那些没有转化为string 的label,我们只需要使用tf.cast(x,dtype)
6. 将数据reshape到它原本的shape:image = tf.reshape(image,[224,224,3])
7. 预处理:如果想对数据做预处理请在现在完成
8. Batching:另外的队列用来从examples中创建batches。可以使用tf.train.shuffle_batch([image,label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10)capacity 是队列的最大size,·min_after_dequeue 是出列后队列的最小size,num_threads 是入队example 的线程数目。使用多线程可提高读取速度。
读取TFR文件的代码:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
data_path = 'train.tfrecords'  # address to save the hdf5 file
with tf.Session() as sess:
    feature = {'train/image': tf.FixedLenFeature([], tf.string),
               'train/label': tf.FixedLenFeature([], tf.int64)}
    # Create a list of filenames and pass it to a queue
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)
    # Convert the image data from string back to the numbers
    image = tf.decode_raw(features['train/image'], tf.float32)

    # Cast label data into int32
    label = tf.cast(features['train/label'], tf.int32)
    # Reshape image data into the original shape
    image = tf.reshape(image, [224, 224, 3])

    # Any preprocessing here ...

    # Creates batches by randomly shuffling tensors
    images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10)

初始化全局变量和局部变量
一些函数如tf.traintf.train.shuffle_batch 添加了tf.train.QueueRunner 对象到图中.每个这样的对象都维持了一个列表的入队op。在我们使用tf.train.string_input_producer创建文件名队列后,整个系统其实还是处于“停滞状态”的,也就是说,我们文件名并没有真正被加入到队列中,而使用tf.train.start_queue_runners之后,才会启动填充队列的线程,这时系统就不再“停滞”。此后计算单元就可以拿到数据并进行计算,整个程序也就跑起来了,这就是函数tf.train.start_queue_runners的用处。为了管理线程,需要tf.train.Coordinator 来在合适的时候结束线程。
以下为这部分的代码:

    # Initialize all global and local variables
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init_op)
    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for batch_index in range(5):
        img, lbl = sess.run([images, labels])
        img = img.astype(np.uint8)
        for j in range(6):
            plt.subplot(2, 3, j+1)
            plt.imshow(img[j, ...])
            plt.title('cat' if lbl[j]==0 else 'dog')
        plt.show()
    # Stop the threads
    coord.request_stop()

    # Wait for threads to stop
    coord.join(threads)
    sess.close()
  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值