序
- 学习这个是因为搞tensorflow肯定跳不过这个坑,所以还不如静下心来好好梳理一下。
- 本文学完理论会优化自己以前的一个分类代码,从原来最古老的placeholder版本做一下优化——启发是来自transformer的源码,它的做法让我觉得我有必要体会一下。
TFrecord
- 注意,这里他只是一种文件存储格式的改变,前文那些队列的思想是没变的!!!
简单介绍
-
TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。总而言之,这样的文件格式好处多多。
-
TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。
-
从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。
-
其实我是这样理解的,我们可以把存入文件和读取文件看作一种”通信协议“,我首先指定一下我们交互信息的协议,然后我存的时候这么存进去,读的时候也这么读出来,仅此而已!
开篇
-
基本代码为,目的:把你的数据转化成tf_record文件
def to_tfrecord(file_name,train_data,train_label):
# 这里准备一个样本一个样本的写入TFRecord file中
# 先把每个样本中所有feature的信息和值存到字典中,key为feature名,value为feature值。
# feature值需要转变成tensorflow指定的feature类型中的一个。
# tensorflow feature类型只接受list数据
writer = tf.python_io.TFRecordWriter('%s.tfrecord' %file_name)
for i in range(len(train_data)):
# 写入字典
features = {}
# 写入向量,类型float,本身就是list,所以"value=vectors[i]"没有中括号
features['data'] = tf.train.Feature(float_list=tf.train.FloatList(value=train_data[i]))
features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=train_label[i]))
# 转化为tf_features
tf_features = tf.train.Features(feature=features)
# 再将其变成一个样本example
tf_example = tf.train.Example(features=tf_features)
# 序列化该样本
tf_serialized = tf_example.SerializeToString()
# 写入一个序列化的样本
writer.write(tf_serialized)
writer.close()
读取(我感觉我碰到了最玄学的问题)
正常
# 使用TF_record导入数据
# 使用TF_record导入数据
filenames = "test.tfrecord"
filename_queue = tf.train.string_input_producer([filenames], num_epochs=None,
shuffle=True)
# **2.创建一个读取器
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
# **3.根据你写入的格式对应说明读取的格式
features = tf.parse_single_example(serialized_example,
features={
'data': tf.FixedLenFeature(shape=[100], dtype=tf.float32),
'label': tf.FixedLenFeature(shape=[2], dtype=tf.float32)} # 而标量就不用说明
)
X_out = features['data']
y_out = features['label']
X_batch, y_batch = tf.train.shuffle_batch([X_out, y_out], batch_size=2,
capacity=200, min_after_dequeue=100, num_threads=2)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
# **5.启动队列进行数据读取
# 下面的 coord 是个线程协调器,把启动队列的时候加上线程协调器。
# 这样,在数据读取完毕以后,调用协调器把线程全部都关了。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
y_outputs = list()
for i in range(5):
_X_batch, _y_batch = sess.run([X_batch, y_batch])
print('** batch %d' % i)
print('_X_batch:', _X_batch)
print('_y_batch:', _y_batch)
y_outputs.extend(_y_batch.tolist())
print(y_outputs)
# **6.最后记得把队列关掉
coord.request_stop()
coord.join(threads)
报错代码:
- 使用datasets
参考最最上面那个正常的代码
def parse_function(example_proto):
# 只接受一个输入:example_proto,也就是序列化后的样本tf_serialized
# 解析规则
# 也可以把形状信息存入example_proto里,然后在下面用
dics = {
'data': tf.FixedLenFeature(shape=[100], dtype=tf.float32, default_value=0.0),
'label': tf.FixedLenFeature(shape=[2], dtype=tf.float32)
}
# 解析样本
parsed_example = tf.parse_single_example(example_proto,dics)
# parsed_example['data'] = tf.reshape(parsed_example['data'], (1,100))
#
# # 转变tensor形状
# parsed_example['label'] = tf.reshape(parsed_example['label'], (1,2))
# 转变特征
return parsed_example
# 使用TF_record导入数据
filenames = "test.tfrecord"
dataset = tf.data.TFRecordDataset(filenames)
'''由于从tfrecord文件中导入的样本是刚才写入的tf_serialized序列化样本,
所以我们需要对每一个样本进行解析。这里就用dataset.map(parse_function)来对dataset里的每个样本进行相同的解析操作。'''
new_dataset = dataset.map(parse_function)
# 创建迭代器
iterator = new_dataset.make_one_shot_iterator()
# 获取样本
next_element = iterator.get_next()
sess = tf.Session()
sess.run(next_element['data'])
END
- 这个报错挖个坑,下篇填。