TFrecord 是 tensorflow 中的数据集存储格式,它的写入和读取相当于序列化和反序列化的过程。
下面内容都是基于 tf2 版本来说明,两个版本中 tfrecord 的核心并没有改变,即序列化和反序列化。但是个人认为 tf2 里面对于读取 tfrecords 文件建立 Dataset 直接用于训练的支持更方便好用了,比如 batch、shuffle、repeat 等方面,所以在这点上基本摒弃了 tf1。
写入/序列化
(1)将数据读到内存,并转换为 tf.train.Example
对象,每个对象由若干个 tf.train.Feature 的字典组成。
(2)将 tf.train.Example
对象序列化为字符串,写入 TFRecord 文件。
读取/反序列化
(1)通过 tf.data.TFRecordDataset
读入原始的 TFRecord 文件,获得一个 tf.data.Dataset
的数据集对象。(tf1 中是要创建一个 reader 来读取 tfrecords 文件中的样例)
(2)通过 Dataset.map
对数据集对象中每个序列化的 tf.train.Example
字符串执行 tf.io.parse_single_example
实现反序列化。
*:map 过程中,无法在 parse 内部进行某些处理,只能 parse 之后在 dataset 中迭代器“拿”出数据之后进行一些转换。
tf.train.Feature
上面多次提到的 tf.train.Feature
支持 3 种数据格式,因此对于各种各样的数据,必须处理成对应这三种格式,才能顺利写入/读取。
3 种格式如下:
tf.train.BytesList
: 字符串或原始 Byte 文件,通过 bytes_list
传入。以图片或者数组等类型数据为例,需要转为字符串类型的再传入 BytesList,后面会有例子。
tf.train.FloatList
:浮点数,通过 float_list
传入。
tf.train.Int64List
:整数,通过 int64_list
传入。
具体实现
#-*-coding:utf-8-*-
import tensorflow as tf
from tensorflow.python.platform import gfile
import cv2
# write
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
img_path = 'path/to/read/img'
tfrd_path = 'path/to/save/tfrecords'
image = cv2.imread(img_path)
h, w, c = image.shape
image_raw_data = gfile.FastGFile(img_path, 'rb').read()
label = 1
tfrd_writer = tf.io.TFRecordWriter(tfrd_path)
feature = {'image': _bytes_feature(image_raw_data),
'label': _int64_feature(label),
'imgH': _int64_feature(h),
'imgW': _int64_feature(w),
'imgC': _int64_feature(c)}
example = tf.train.Example(features=tf.train.Features(feature=feature))
tfrd_writer.write(example.SerializeToString())
tfrd_writer.close()
# read
raw_dataset = tf.data.TFRecordDataset(tfrd_path)
feature_description = {'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64),
'imgH': tf.io.FixedLenFeature([], tf.int64),
'imgW': tf.io.FixedLenFeature([], tf.int64),
'imgC': tf.io.FixedLenFeature([], tf.int64)
}
def parse(record):
features = tf.io.parse_single_example(record, feature_description)
images = tf.io.decode_jpeg(features['image'])
labels = features['label']
imgH, imgW, imgC = features['imgH'], features['imgW'], features['imgC']
shape = [imgH, imgW, imgC]
return images, labels, shape
dataset = raw_dataset.map(parse)
for image, label, shape in dataset:
print(label)
cv2.imshow('img', image.numpy())
cv2.waitKey()
如果遇到需要将 numpy 写入 tfrecord,可以先将 numpy 转为字符串,然后写入;读取的时候再转为 numpy 即可,注意 dtype 的对应。
import numpy as np
gt_row_np = np.array([0, 0, 1, 0], dtype=np.uint8)
gt_row_str = gt_row_np.tostring()
gt_row = np.frombuffer(gt_row_str, dtype=np.uint8)
print('gt_row_np type: {}, dtype: {}, value: {}'.format(type(gt_row_np), gt_row_np.dtype, gt_row_np))
print('gt_row_str type: {}, value: {}'.format(type(gt_row_str), gt_row_str))
print('gt_row type: {}, dtype: {}, value: {}'.format(type(gt_row), gt_row.dtype, gt_row))
# output
gt_row_np type: <class 'numpy.ndarray'>, dtype: uint8, value: [0 0 1 0]
gt_row_str type: <class 'bytes'>, value: b'\x00\x00\x01\x00'
gt_row type: <class 'numpy.ndarray'>, dtype: uint8, value: [0 0 1 0]