TensorFlow 1.x and TensorFlow 2.x read TFRecord files in slightly different ways. The two snippets below record each approach:
TensorFlow 1.x:
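import tensorflow as tf  # TensorFlow 1.x
from google.protobuf.json_format import MessageToJson
# `path` is the location of the .tfrecord file
for example in tf.python_io.tf_record_iterator(path):
    # print(tf.train.Example.FromString(example))  # print the parsed proto directly
    jsonMessage = MessageToJson(tf.train.Example.FromString(example))  # Example proto -> JSON string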
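To try the 1.x-style loop without a TF 1.x install, here is a minimal sketch, assuming TensorFlow 2.x, where the same record iterator is still available (deprecated) as `tf.compat.v1.io.tf_record_iterator`. The file name `sample.tfrecord` and the features `label` and `name` are hypothetical; the script writes one record first so it has something to read.

```python
import json
import os
import tempfile

import tensorflow as tf
from google.protobuf.json_format import MessageToJson

# Hypothetical sample file in a temp dir so the example is self-contained
path = os.path.join(tempfile.mkdtemp(), "sample.tfrecord")

# Build one tf.train.Example with two hypothetical features and write it out
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    "name": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"cat"])),
}))
with tf.io.TFRecordWriter(path) as writer:
    writer.write(example.SerializeToString())

# TF 2.x alias of tf.python_io.tf_record_iterator: yields raw serialized bytes
messages = []
for raw in tf.compat.v1.io.tf_record_iterator(path):
    messages.append(MessageToJson(tf.train.Example.FromString(raw)))

parsed = json.loads(messages[0])
print(len(messages), sorted(parsed["features"]["feature"]))
```

Note that `MessageToJson` uses camelCase field names by default (`int64List`, `bytesList`), encodes int64 values as JSON strings, and base64-encodes bytes values.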
TensorFlow 2.x:
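import tensorflow as tf
import json
from google.protobuf.json_format import MessageToJson

# TFRecordDataset yields each serialized record as a scalar tf.string tensor
dataset = tf.data.TFRecordDataset("mydata.tfrecord")
for d in dataset:
    ex = tf.train.Example()
    ex.ParseFromString(d.numpy())  # decode the raw bytes into an Example proto
    m = json.loads(MessageToJson(ex))  # proto -> JSON string -> Python dict
    print(m['features']['feature'].keys())  # names of the features in this record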
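If you need the feature values rather than a JSON view, the `Example` proto can also be read directly, skipping the JSON round trip (which turns int64 values into strings and bytes into base64). A minimal sketch, assuming TensorFlow 2.x; the helper `example_to_dict`, the temp file, and the features `label`, `name`, `score` are hypothetical and only there so the example runs on its own. It relies on the `kind` oneof of `tf.train.Feature`, which holds exactly one of `bytes_list`, `float_list`, or `int64_list`.

```python
import os
import tempfile

import tensorflow as tf


def example_to_dict(raw_bytes):
    """Hypothetical helper: decode one serialized Example into {name: list_of_values}."""
    ex = tf.train.Example()
    ex.ParseFromString(raw_bytes)
    out = {}
    for name, feature in ex.features.feature.items():
        # "bytes_list", "float_list", "int64_list", or None if the feature is empty
        kind = feature.WhichOneof("kind")
        out[name] = list(getattr(feature, kind).value) if kind else []
    return out


# Write one hypothetical record so there is something to read
path = os.path.join(tempfile.mkdtemp(), "mydata.tfrecord")
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    "name": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"cat"])),
    "score": tf.train.Feature(float_list=tf.train.FloatList(value=[0.5])),
}))
with tf.io.TFRecordWriter(path) as writer:
    writer.write(example.SerializeToString())

# Same TFRecordDataset loop as above, but decoding values instead of JSON
records = [example_to_dict(d.numpy()) for d in tf.data.TFRecordDataset(path)]
print(records)
```

Values come back as native Python types (ints, bytes, floats); float features are stored as float32, so values that are not exactly representable (unlike 0.5) will differ slightly from what was written.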