TFRecords文件的生成和读取

转载自 https://blog.csdn.net/qq_36330643/article/details/77366083

关于TensorFlow读取数据,官网提供了3种方法:
1)Feeding : 在TensorFlow程序运行的每一个epoch,用python代码在线提供数据
2)Reader : 在一个计算图(tf.graph)的开始前,将文件读入到流(queue)中
3)在声明 tf.variable 变量或 numpy 数组时保存数据。受限于内存大小,适用于数据较小的情况

准备图片数据

分成两个文件夹,分别存放两类图片
现在用这两类图片制作TFRecords文件

制作TFRecords文件

1)TFRecords是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在TensorFlow中快速的复制、移动、读取、存储等。
注意:存储数据时,TFRecords会根据你选择输入文件的类,自动给每一类打上同样的标签。在本例中,只有0,1两类。
2)制作TFRecords文件的代码

import tensorflow as tf
from PIL import Image
cwd = 'D:\tfrecords测试\dog\\'
classes = {'husky', 'chihuahua'}

def creat_tfrecords(tfrecords_name):
    writer = tf.python_io.TFRecordWriter(tfrecords_name)

    for index, name in enumerate(classes):

        for img_name in os.listdir(class_path):
            img_path = class_path + img_name

            img = Image.open(img_path)
            img = img.resize((128, 128))
            img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
            writer.write(example.SerializeToString())

    writer.close()

调用以上函数后,会生成一个tfrecords文件,名字为函数的输入参数tfrecords_name
3)tf.train.Example 协议内存块包含了Features字段,通过feature将图片的二进制数据和label进行统一封装, 然后将example协议内存块转化为字符串, tf.python_io.TFRecordWriter 写入到TFRecords文件中。

读取TFRecords文件

1) 在制作完tfrecords文件后,将该文件读入到数据流中,代码如下:

def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={'label': tf.FixedLenFeature([], tf.int64),
                                                 'img_raw': tf.FixedLenFeature([], tf.string),
                                       })

    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [128, 128, 3])
    label = tf.cast(features['label'], tf.int32)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(20):
            example, l = sess.run([image, label])

            # 下面两行代码功能是 将读取到的img存储为.jpg图片 
            img = Image.fromarray(example, 'RGB')
            img.save(cwd + str(i) + '_Label_' + str(l) + '.jpg')

            print(example, l)

        coord.request_stop()
        coord.join(threads)

注意: example, l = sess.run([image, label]) 的功能是将tensor格式的image和label转换为array

2)调用以上函数后,会在文件夹中得到20张以一定规则命名的.jpg文件

读取一个batch的数据

1)我们训练神经网络的时候往往不是将数据一个一个的输入到网络里,而是输入一个batch的数据。将多个输入样例组织成一个batch可以提高模型的训练效率。
TensorFlow提供了 tf.train.batchtf.shuffle_batch 函数来将单个的样例组织成batch的形式输出。它们唯一的区别在于是否会将数据的顺序打乱。
tf.train.batch 有三个输入参数,第一个是数据,第二个是batch_size,第三个是 capacity。
capacity表示组合样例的队列中最多可以存储的样例个数,这个队列如果太大,那么需要占用很多内存资源;如果太小,那么出队操作可能会因为没有数据而被阻碍(block),从而导致训练效率降低。一般来说这个队列的大小会和每一个batch的大小有关,一般可设置为 capacity = 1000 + 3 * batch_size
tf.train.shuffle_batchtf.train.batch 多一个参数 min_after_dequeue.
min_after_dequeue 限制了出队时队列中元素的最少个数。当队列中元素太少时,随机打乱样例顺序的作用就不大了。当出队函数被调用但是队列中元素不够时,出队操作将等待更多的元素入队才会完成,如果 min_after_dequeue 参数被设定, capacity 也应该相应调整来满足性能要求。

下面代码说明读取一个batch的方法:

# 读取tfrecords文件的函数
def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'label':tf.FixedLenFeature([], tf.int64),
        'img_raw':tf.FixedLenFeature([], tf.string),
    })
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [200, 200, 3])
    label = tf.cast(features['label'], tf.int32)

    return image, label

----------

example, label = read_and_decode('train.tfrecords')
# 读取一个batch的数据和标签
example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=3, capacity=1000, min_after_dequeue=500)
# 开启一个会话
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # 取3次batch_size=3的数据,并打印出标签 看取数据的情况。
    for i in range(3):
        cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])
        print(cur_label_batch)

    coord.request_stop()
    coord.join(threads)

将label转换为one hot形式

参考https://blog.csdn.net/a_yangfh/article/details/77911126
详细讲解参考上述博客


以手写数字识别为例,我们需要将0-9共十个数字标签转化成onehot标签。例如:数字标签“6”转化为onehot标签就是[0,0,0,0,0,0,1,0,0,0]
而我们通过tfrecords读取到的label并不是one hot 形式的,而是类似于[1, 2, 3, 4, 5]形式的,其中数字n代表第n类。
通过以下函数可以将我们读取到的label 转换成 one hot 形式:

import tensorflow as tf
# labels为需要转换的标签
# NUM_CLASSES为类别数
# batch_size为batch大小
def turn_to_onehot(labels, NUM_CLASSES, batch_size):
    labels = tf.expand_dims(labels, 1)  # 增加一个维度
    indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)  # 生成索引
    concated = tf.concat([indices, labels], 1)  # 作为拼接
    onehot_labels = tf.sparse_to_dense(concated, tf.stack([batch_size, NUM_CLASSES]), 1.0, 0.0)  # 生成one-hot编码的标签
    return onehot_labels
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值