开始
万恶的Tensorflow老是喜欢把训练数据处理成tfrecord的格式,之前踩了tensor2tensor的坑,大坑小坑不断这里主要说TFRecord的事情,就不细说这些。
网上有的方法
搜了下网上的方法,其实是有对tensorflow的TFRecord操作的使用大全,不过感觉都比较冗长(真的是又长又凶残),有的感觉是文档的翻译,所以我这里就记录下简短的TFRecord的读取方式
准备
- 要有准备被读取的TFRecord文件
- 装好的tensorflow框架(废话)
- 我本机上的tensorflow版本:1.15.0
读取代码
import tensorflow as tf
# 如果想直接在shell里执行,记得加上eager
tfe = tf.contrib.eager
tfe.enable_eager_execution()
# 这里是读取TFRecord文件
filenames = ['record.tfrecord']
raw_dataset = tf.data.TFRecordDataset(filenames)
# 这里是真正的读取代码
# 因为这个raw_dataset是iterator,可以用take的方法取出对应的N个
for raw_record in raw_dataset.take(10):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
# 在这里你就可以打印或者取出对应的记录
print(example)
后记
- 其实之前主要是想核对tensor2tensor框架生成的文本记录是否带的,所以就特来开箱检查下对应的TFRecord记录。以前每次每次检查TFRecord都要去翻之前的代码,故而我就直接保留在这里了,方便自己也方便大家。
- 而且我发现大家确实有不少这个问题,主要是tensorflow对数据的封装造成了没有csv,tsv那么直观,估计是为了性能而放弃了部分可视化能力吧。(叹气)