目录
1、TFRecord介绍
TFRecord 是 TensorFlow 中的数据集中存储格式,TFRecord是一种二进制文件。
将数据集整理成 TFRecord 格式后,TensorFlow 就可以高效地读取和处理这些数据集,从而更高效地进行大规模的模型训练。
TFRecord 内部使用了二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可。简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个 TFRecord 文件,来提高处理效率。
2、TFRecord格式数据文件处理过程
将形式各样的数据集整理为 TFRecord 格式,可以对数据集中的每个元素进行以下步骤:
(1)读取该数据元素到内存;
(2)将该元素转换为 tf.train.Example 对象(每一个 tf.train.Example 由若干个 tf.train.Feature 的字典组成,因此需要先建立 Feature 的字典);
(3)将该 tf.train.Example 对象序列化为字符串,并通过一个预先定义的 tf.io.TFRecordWriter 写入 TFRecord 文件。
读取 TFRecord 数据可按照以下步骤:
(1)通过 tf.data.TFRecordDataset 读入原始的 TFRecord 文件(此时文件中的 tf.train.Example 对象尚未被反序列化),获得一个 tf.data.Dataset数据集对象;
(2)通过 Dataset.map 方法,对该数据集对象中的每一个序列化的 tf.train.Example 字符串执行 tf.io.parse_single_example 函数,从而实现反序列化。
3、TFRecord格式
TFRecord内部包含了多个tf.train.Example
, 而Example
是protocol buffer
(protobuf) 数据标准的实现,在一个Example
消息体中包含了一系列的tf.train.feature
属性,而每一个feature
是一个key-value
的键值对,其中,key
是string类型,而value
的取值有三种:
bytes_list:
可以存储string
和byte
两种数据类型。float_list:
可以存储float(float32)
与double(float64)
两种数据类型 。int64_list:
可以存储bool, enum, int32, uint32, int64, uint64
。
tf.train.Feature 支持三种数据格式:
- tf.train.BytesList :字符串或原始 Byte 文件(如图片),通过 bytes_list 参数传入一个由字符串数组初始化的 tf.train.BytesList 对象;
- tf.train.FloatList :浮点数,通过 float_list 参数传入一个由浮点数数组初始化的 tf.train.FloatList 对象;
- tf.train.Int64List :整数,通过 int64_list 参数传入一个由整数数组初始化的 tf.train.Int64List 对象。
如果只希望保存一个元素而非数组,传入一个只有一个元素的数组即可
4、生成TFRecord格式数据
import os
import tensorflow as tf
# 读取数据集中图片文件名和标签
def read_image_filenames (data_dir) :
cat_dir = data_dir + "cat/"
dog_dir = data_dir + "dog/"
cat_filenames = [cat_dir + fn for fn in os.listdir(cat_dir)]
dog_filenames = [dog_dir + fn for fn in os.listdir(dog_dir)]
filenames = cat_filenames + dog_filenames
# 将cat类的标签设为0, dog类的标签设为1
labels = [0]* len(cat_filenames) + [1] *len(dog_filenames)
return filenames,labels
# 定义生成TFRecord格式数据文件函数
def write_TFRecord_file(filenames,labels,tfrecord_file):
with tf.io.TFRecordWriter(tfrecord_file) as writer:
for filename,label in zip(filenames,labels) :
# 读取数据集图片到内存,image 为一个 Byte类型的字符串
image = open(filename,"rb").read()
# 建立tf.train.Feature字典
feature = {
# 图片是一个Bytes对象
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
# 标签是一个Int对象
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}
# 通过feature字典建立Example
example = tf.train.Example(features=tf.train.Features(feature=feature))
# 将Example序列化并写入TFRecord 文件
writer.write(example.SerializeToString())
train_data_dir = './data_small/train/' # 数据集路径
tfrecord_file = train_data_dir + 'train.tfrecords' # 生成的tfrecord路径
if not os.path.isfile(tfrecord_file): # 判断train.tfrecord是否存在
train_filenames,train_labels = read_image_filenames(train_data_dir)
write_TFRecord_file(train_filenames,train_labels,tfrecord_file)
print('write TFRecord file:',tfrecord_file)
else:
print(tfrecord_file,'already exists.')
5、TFRecord数据文件解码
1、定义TFRecord数据文件解码函数
# 定义Feature结构,告诉解码器每个Feature的类型是什么,要与生成的TFrecord的类型一致
feature_description = {
"image":tf.io.FixedLenFeature([],tf.string),
"label":tf.io.FixedLenFeature([],tf.int64)
}
# 将TFRecord 文件中的每一个序列化的 tf.train.Example 解码
def parse_example(example_string):
feature_dict = tf.io.parse_single_example(example_string,feature_description)
feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image']) # 解码JPEG图片
feature_dict['image'] = tf.image.resize(feature_dict['image'],[224,224])/ 255.0 # 改变图片尺寸并进行归一化
return feature_dict['image'],feature_dict['label']
2、定义读取TFRecord文件,解码并生成Dataset数据集的函数
def read_TFRecond_file(tfrecord_file):
# 读取TFRecord 文件
raw_dataset = tf.data.TFRecordDataset(tfrecord_file)
# 解码
dataset = raw_dataset.map(parse_example)
return dataset
3、tf.data.TFRecordDataset
tfrecord文件创建一个TFRecordDataset类的实例对象
参数:tf.data.TFRecordDataset(filenames,compression_type=None,
buffer_size=None,num_parallel_reads=None)
一般只传第一个参数filenames即可 ,生成的tfrecord文件
6、解码并生成Dataset数据集
# Dataset的数据缓冲器大小,和数据集大小及规律有关
buffer_size = 20000
# Dataset的数据批次大小,每批次多少个样本数
batch_size = 8
dataset_train = read_TFRecond_file(tfrecord_file) # 解码
dataset_train = dataset_train.shuffle(buffer_size) # 打乱数据
dataset_train = dataset_train.batch(batch_size) # 分批次进行读取
7、查看第一批元素
import matplotlib.pyplot as plt
sub_dataset = dataset_train.take(1) # 读取第一个批次
for images,labels in sub_dataset:
fig,axs = plt.subplots(1, batch_size)
for i in range(batch_size):
axs[i].set_title(labels.numpy()[i])
axs[i].imshow(images.numpy()[i])
axs[i].set_xticks([])
axs[i].set_yticks([])
plt.show()
案例实例地址:Tfrecord介绍以及实例· GitHub
链接:猫狗大战数据集
提取码:kqgt