tensorflow(2)——读取数据TFrecord

  • 学习这个是因为搞tensorflow肯定跳不过这个坑,所以还不如静下心来好好梳理一下。
  • 本文学完理论会优化自己以前的一个分类代码,从原来最古老的placeholder版本做一下优化——启发是来自transformer的源码,它的做法让我觉得我有必要体会一下。

TFrecord

  • 注意,这里他只是一种文件存储格式的改变,前文那些队列的思想是没变的!!!

简单介绍

  • TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。总而言之,这样的文件格式好处多多。

  • TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

  • 从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

  • 其实我是这样理解的,我们可以把存入文件和读取文件看作一种”通信协议“,我首先指定一下我们交互信息的协议,然后我存的时候这么存进去,读的时候也这么读出来,仅此而已!

开篇

def to_tfrecord(file_name,train_data,train_label):

    # 这里准备一个样本一个样本的写入TFRecord file中
    # 先把每个样本中所有feature的信息和值存到字典中,key为feature名,value为feature值。
    # feature值需要转变成tensorflow指定的feature类型中的一个。
    # tensorflow feature类型只接受list数据

    writer = tf.python_io.TFRecordWriter('%s.tfrecord' %file_name)

    for i in range(len(train_data)):

        # 写入字典
        features = {}

        # 写入向量,类型float,本身就是list,所以"value=vectors[i]"没有中括号
        features['data'] = tf.train.Feature(float_list=tf.train.FloatList(value=train_data[i]))
        features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=train_label[i]))

        # 转化为tf_features
        tf_features = tf.train.Features(feature=features)

        # 再将其变成一个样本example
        tf_example = tf.train.Example(features=tf_features)

        # 序列化该样本
        tf_serialized = tf_example.SerializeToString()

        # 写入一个序列化的样本
        writer.write(tf_serialized)

    writer.close()

读取(我感觉我碰到了最玄学的问题)

正常

# 使用TF_record导入数据

# 使用TF_record导入数据

filenames = "test.tfrecord"
filename_queue = tf.train.string_input_producer([filenames], num_epochs=None,
                                                shuffle=True)
# **2.创建一个读取器
reader = tf.TFRecordReader()

_, serialized_example = reader.read(filename_queue)

# **3.根据你写入的格式对应说明读取的格式
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'data': tf.FixedLenFeature(shape=[100], dtype=tf.float32),
                                        'label': tf.FixedLenFeature(shape=[2], dtype=tf.float32)}     # 而标量就不用说明
                                   )
X_out = features['data']
y_out = features['label']

X_batch, y_batch = tf.train.shuffle_batch([X_out, y_out], batch_size=2,
                                          capacity=200, min_after_dequeue=100, num_threads=2)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

# **5.启动队列进行数据读取
# 下面的 coord 是个线程协调器,把启动队列的时候加上线程协调器。
# 这样,在数据读取完毕以后,调用协调器把线程全部都关了。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
y_outputs = list()
for i in range(5):
    _X_batch, _y_batch = sess.run([X_batch, y_batch])
    print('** batch %d' % i)
    print('_X_batch:', _X_batch)
    print('_y_batch:', _y_batch)
    y_outputs.extend(_y_batch.tolist())
print(y_outputs)

# **6.最后记得把队列关掉
coord.request_stop()
coord.join(threads)

在这里插入图片描述

报错代码:

def parse_function(example_proto):
    # 只接受一个输入:example_proto,也就是序列化后的样本tf_serialized

    # 解析规则
    # 也可以把形状信息存入example_proto里,然后在下面用
    dics = {
        'data': tf.FixedLenFeature(shape=[100], dtype=tf.float32, default_value=0.0),
        'label': tf.FixedLenFeature(shape=[2], dtype=tf.float32)
    }

    # 解析样本
    parsed_example = tf.parse_single_example(example_proto,dics)

    # parsed_example['data'] = tf.reshape(parsed_example['data'], (1,100))
    #
    # # 转变tensor形状
    # parsed_example['label'] = tf.reshape(parsed_example['label'], (1,2))

    # 转变特征
    return parsed_example




# 使用TF_record导入数据

filenames = "test.tfrecord"
dataset = tf.data.TFRecordDataset(filenames)

'''由于从tfrecord文件中导入的样本是刚才写入的tf_serialized序列化样本,
所以我们需要对每一个样本进行解析。这里就用dataset.map(parse_function)来对dataset里的每个样本进行相同的解析操作。'''

new_dataset = dataset.map(parse_function)

# 创建迭代器
iterator = new_dataset.make_one_shot_iterator()

# 获取样本
next_element = iterator.get_next()

sess = tf.Session()

sess.run(next_element['data'])

在这里插入图片描述

END

  • 这个报错挖个坑,下篇填。
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值