1.制作tfrecord文件
#tf_filename:需提前新建一个tfrecord空白文件。然后使用tf.python_io.TFRecordWriter() 建立一个writer
with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
#dataset为字典格式
for i,image_example in enumerate(dataset):
#小技巧,在制作dataset的时候,只是把图片的路径放到里面,这样可以节省内存。
filename = image_example['filename']
image = cv2.imread(filename)
image_data = image.tostring()
class_label = image_example['balel']
#创建example
example = tf.train.Example(features=tf.train.Features(feature={
'image':tf.train.Feature(bytes_list=tf.train.BytesList(value=image_data)),
'label':tf.train.Feature(int64_list=tf.train.Int64List(value=class_label))}))
#把example写入writer
tfrecord_writer.write(example.SerializeToString())
2.读取tfrecord文件
##############读取tfrecord
def read_single_tfrecord(filename):
filename_queue = tf.train.string_input_producer(filename,shuffle=True)
#创建reader
reader = tf.TFRecordReader()
_,serizlized_example = reader.read(filename_queue)
#解析
images_features = tf.parse_single_example(
serizlized_example,
features={
'image':tf.FixedLenFeature([],tf.string),
'label':tf.FixedLenFeature([],tf.int64)
})
image = tf.decode_raw(images_features['image'],tf.unit8)
制作和读取tfrecord
最新推荐文章于 2020-09-24 13:29:43 发布