2.4 TFRecord数据集制作实战
TensorFlow提供了TFRecords格式来统一存储数据,从理论上讲,TFRecords可以存储任何形式的数据。TFRecord是一种二进制文件,具有以下优点:
- 统一各种输入文件的操作
- 更好的利用内存,方便复制和移动
- 将二进制数据和标签(label)存储在同一个文件中。
在本节的内容中,将详细讲解制作并操作TFRecord数据集的知识。
2.4.1 将图片制作为TFRecord数据集
在“img”目录中有两个子目录“0”和“1”,在两个子目录中分别保存了图片。然后编写实例文件data05.py,功能是将上述两个子目录“0”和“1”中的图片制作成TFRecord数据集。文件data05.py的具体实现代码如下所示。
import os
import tensorflow as tf
from PIL import Image
cwd = 'img\\'
classes = {'0', '1'} # 人为 设定 2 类
writer = tf.compat.v1.python_io.TFRecordWriter("dog_train.tfrecords") # 要生成的文件
for index, name in enumerate(classes):
class_path = cwd + name + '\\'
for img_name in os.listdir(class_path):
img_path = class_path + img_name # 每一个图片的地址
img = Image.open(img_path)
img = img.resize((128, 128))
img_raw = img.tobytes() # 将图片转化为二进制格式
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) # example对象对label和image数据进行封装
writer.write(example.SerializeToString()) # 序列化为字符串
writer.close()
执行后会创建TFRecord数据集文件dog_train.tfrecords。
2.4.2 将CSV文件保存为TFRecord文件
请看下面的实例文件data06.py,功能是将著名地鸢尾花数据集文件iris.csv制作成TFRecord数据集。文件data06.py的具体实现代码如下所示。
import pandas as pd
import tensorflow as tf
print(tf.__version__)
input_csv_file = "iris.csv"
iris_frame = pd.read_csv(input_csv_file, header=0)
print(iris_frame)
# label,sepal_length,sepal_width,petal_length,petal_width
print("values shape: ", iris_frame.shape)
row_count = iris_frame.shape[0]
col_count = iris_frame.shape[1]
output_tfrecord_file = "iris.tfrecords"
with tf.io.TFRecordWriter(output_tfrecord_file) as writer:
for i in range(row_count):
example = tf.train.Example(
features=tf.train.Features(
feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[iris_frame.iloc[i, 0]])),
"sepal_length": tf.train.Feature(float_list=tf.train.FloatList(value=[iris_frame.iloc[i, 1]])),
"sepal_width": tf.train.Feature(float_list=tf.train.FloatList(value=[iris_frame.iloc[i, 2]])),
"petal_length": tf.train.Feature(float_list=tf.train.FloatList(value=[iris_frame.iloc[i, 3]])),
"petal_width": tf.train.Feature(float_list=tf.train.FloatList(value=[iris_frame.iloc[i, 4]]))
}
)
)
writer.write(record=example.SerializeToString())
writer.close()
执行后会提取数据集中的信息,打印输出如下信息,并创建TFRecord数据集文件iris.tfrecords。
2.6.0
Unnamed: 0 Sepal.Length ... Petal.Width Species
0 1 5.1 ... 0.2 setosa
1 2 4.9 ... 0.2 setosa
2 3 4.7 ... 0.2 setosa
3 4 4.6 ... 0.2 setosa
4 5 5.0 ... 0.2 setosa
5 6 5.4 ... 0.4 setosa
6 7 4.6 ... 0.3 setosa
7 8 5.0 ... 0.2 setosa
8 9 4.4 ... 0.2 setosa
9 10 4.9 ... 0.1 setosa
10 11 5.4 ... 0.2 setosa
11 12 4.8 ... 0.2 setosa
12 13 4.8 ... 0.1 setosa
13 14 4.3 ... 0.1 setosa
14 15 5.8 ... 0.2 setosa
15 16 5.7 ... 0.4 setosa
16 17 5.4 ... 0.4 setosa
17 18 5.1 ... 0.3 setosa
18 19 5.7 ... 0.3 setosa
19 20 5.1 ... 0.3 setosa
20 21 5.4 ... 0.2 setosa
21 22 5.1 ... 0.4 setosa
22 23 4.6 ... 0.2 setosa
23 24 5.1 ... 0.5 setosa
24 25 4.8 ... 0.2 setosa
25 26 5.0 ... 0.2 setosa
26 27 5.0 ... 0.4 setosa
27 28 5.2 ... 0.2 setosa
28 29 5.2 ... 0.2 setosa
29 30 4.7 ... 0.2 setosa
.. ... ... ... ... ...
120 121 6.9 ... 2.3 virginica
121 122 5.6 ... 2.0 virginica
122 123 7.7 ... 2.0 virginica
123 124 6.3 ... 1.8 virginica
124 125 6.7 ... 2.1 virginica
125 126 7.2 ... 1.8 virginica
126 127 6.2 ... 1.8 virginica
127 128 6.1 ... 1.8 virginica
128 129 6.4 ... 2.1 virginica
129 130 7.2 ... 1.6 virginica
130 131 7.4 ... 1.9 virginica
131 132 7.9 ... 2.0 virginica
132 133 6.4 ... 2.2 virginica
133 134 6.3 ... 1.5 virginica
134 135 6.1 ... 1.4 virginica
135 136 7.7 ... 2.3 virginica
136 137 6.3 ... 2.4 virginica
137 138 6.4 ... 1.8 virginica
138 139 6.0 ... 1.8 virginica
139 140 6.9 ... 2.1 virginica
140 141 6.7 ... 2.4 virginica
141 142 6.9 ... 2.3 virginica
142 143 5.8 ... 1.9 virginica
143 144 6.8 ... 2.3 virginica
144 145 6.7 ... 2.5 virginica
145 146 6.7 ... 2.3 virginica
146 147 6.3 ... 1.9 virginica
147 148 6.5 ... 2.0 virginica
148 149 6.2 ... 2.3 virginica
149 150 5.9 ... 1.8 virginica
[150 rows x 6 columns]
values shape: (150, 6)
2.4.3 读取TFRecord文件的内容
请看下面的实例文件data07.py,功能是将图像保存写入到TFRecord文件,然后读取TFRecord文件里的内容。将使用图像作为输入数据,将数据写入 TFRecord 文件,然后将文件读取回来并显示图像。如果想在同一个输入数据集上使用多个模型,这种做法会很有用。我们可以不以原始格式存储图像,而是将图像预处理为 TFRecord 格式,然后将其用于所有后续的处理和建模中。文件data07.py的具体实现流程如下所示。
(1)为了将标准TensorFlow类型转换为兼容tf.Example的 tf.train.Feature,编写如下所示的函数将值转换为与tf.Example兼容的类型,每个函数会接受标量输入值并返回包含上述三种 list 类型之一的 tf.train.Feature。
# 将值转换为与tf.Example兼容的类型
def _bytes_feature(value):
""" 从字符串/字节返回bytes_list"""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList不会从张量中解包字符串.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""从float/double返回一个float_list"""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""从bool/enum/int/uint返回int64_list"""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
(2)下载两个网络照片,代码如下:
cat_in_snow = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))
display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))
如图2-5所示。
图2-5 两幅网络图片
(3)写入 TFRecord 文件
将特征编码为与 tf.Example 兼容的类型,这将存储原始图像字符串特征,以及高度、宽度、深度和任意 label 特征。后者会在您写入文件以区分猫和桥的图像时使用。将 0 用于猫的图像,将 1 用于桥的图像。代码如下:
image_labels = {
cat_in_snow : 0,
williamsburg_bridge : 1,
}
#这是一个示例,仅使用cat图像。
image_string = open(cat_in_snow, 'rb').read()
label = image_labels[cat_in_snow]
#创建具有相关功能的词典
def image_example(image_string, label):
image_shape = tf.image.decode_jpeg(image_string).shape
feature = {
'height': _int64_feature(image_shape[0]),
'width': _int64_feature(image_shape[1]),
'depth': _int64_feature(image_shape[2]),
'label': _int64_feature(label),
'image_raw': _bytes_feature(image_string),
}
return tf.train.Example(features=tf.train.Features(feature=feature))
for line in str(image_example(image_string, label)).split('\n')[:15]:
print(line)
print('...')
执行后会打印输出TFRecord 文件的结构:
key: "depth"
value {
int64_list {
value: 3
}
}
}
feature {
key: "height"
value {
int64_list {
value: 213
}
...
此时所有的特征都被存储在 tf.Example 消息中,接下来,函数化处理上面的代码,并将示消息写入名为 images.tfrecords 的文件中。代码如下:
# 将原始图像文件写入“images.tfrecords”。
# 首先,将这两个图像处理为`tf.Example`消息。
# 然后,写入一个“.tfrecords”文件.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
for filename, label in image_labels.items():
image_string = open(filename, 'rb').read()
tf_example = image_example(image_string, label)
writer.write(tf_example.SerializeToString())
(4)读取 TFRecord 文件
现在已经创建了了文件 images.tfrecords,并可以迭代其中的记录以将您写入的内容读取回来。因为在此实例中只需重新生成图像,所以只需要原始图像字符串这一个特征。使用上面描述的 getter 方法(即 example.features.feature['image_raw'].bytes_list.value[0])提取该特征。另外还可以使用标签来确定哪个记录是猫,哪个记录是桥。
raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')
#创建描述功能的词典.
image_feature_description = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'label': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
def _parse_image_function(example_proto):
#使用上面的字典解析输入tf.Example proto
return tf.io.parse_single_example(example_proto, image_feature_description)
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
执行后会输出:
<MapDataset shapes: {depth: (), height: (), image_raw: (), label: (), width: ()}, types: {depth: tf.int64, height: tf.int64, image_raw: tf.string, label: tf.int64, width: tf.int64}>
从 TFRecord 文件中恢复图像,代码如下:
for image_features in parsed_image_dataset:
image_raw = image_features['image_raw'].numpy()
display.display(display.Image(data=image_raw))
从TFRecord文件中恢复出来的图像如图2-6所示。
图2-6 从TFRecord文件中恢复出来的图像