(2-4)TensorFlow数据集制作实战:TFRecord数据集制作实战

2.4  TFRecord数据集制作实战

TensorFlow提供了TFRecords格式来统一存储数据,从理论上讲,TFRecords可以存储任何形式的数据。TFRecord是一种二进制文件,具有以下优点:

  1. 统一各种输入文件的操作
  2. 更好的利用内存,方便复制和移动
  3. 将二进制数据和标签(label)存储在同一个文件中。

在本节的内容中,将详细讲解制作并操作TFRecord数据集的知识。

2.4.1  将图片制作为TFRecord数据集

“img”目录中有两个子目录“0”和“1”,在两个子目录中分别保存了图片。然后编写实例文件data05.py,功能是上述两个子目录“0”和“1”中的图片制作成TFRecord数据集。文件data05.py的具体实现代码如下所示。

import os
import tensorflow as tf
from PIL import Image 

cwd = 'img\\'
classes = {'0', '1'}  # 人为 设定 2 类
writer = tf.compat.v1.python_io.TFRecordWriter("dog_train.tfrecords")  # 要生成的文件


for index, name in enumerate(classes):
    class_path = cwd + name + '\\'
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name  # 每一个图片的地址

        img = Image.open(img_path)
        img = img.resize((128, 128))
        img_raw = img.tobytes()  # 将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))  # example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  # 序列化为字符串

writer.close()

执行后会创建TFRecord数据集文件dog_train.tfrecords。

2.4.2  将CSV文件保存为TFRecord文件

请看下面的实例文件data06.py,功能是将著名地鸢尾花数据集文件iris.csv制作成TFRecord数据集。文件data06.py的具体实现代码如下所示。

import pandas as pd
import tensorflow as tf

print(tf.__version__)

input_csv_file = "iris.csv"
iris_frame = pd.read_csv(input_csv_file, header=0)
print(iris_frame)
# label,sepal_length,sepal_width,petal_length,petal_width
print("values shape: ", iris_frame.shape)

row_count = iris_frame.shape[0]
col_count = iris_frame.shape[1]

output_tfrecord_file = "iris.tfrecords"
with  tf.io.TFRecordWriter(output_tfrecord_file) as writer:
    for i in range(row_count):
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[iris_frame.iloc[i, 0]])),
                    "sepal_length": tf.train.Feature(float_list=tf.train.FloatList(value=[iris_frame.iloc[i, 1]])),
                    "sepal_width": tf.train.Feature(float_list=tf.train.FloatList(value=[iris_frame.iloc[i, 2]])),
                    "petal_length": tf.train.Feature(float_list=tf.train.FloatList(value=[iris_frame.iloc[i, 3]])),
                    "petal_width": tf.train.Feature(float_list=tf.train.FloatList(value=[iris_frame.iloc[i, 4]]))

                }
            )
        )
        writer.write(record=example.SerializeToString())
writer.close()

执行后会提取数据集中的信息,打印输出如下信息,并创建TFRecord数据集文件iris.tfrecords。

2.6.0
     Unnamed: 0  Sepal.Length  ...  Petal.Width    Species
0             1           5.1  ...          0.2     setosa
1             2           4.9  ...          0.2     setosa
2             3           4.7  ...          0.2     setosa
3             4           4.6  ...          0.2     setosa
4             5           5.0  ...          0.2     setosa
5             6           5.4  ...          0.4     setosa
6             7           4.6  ...          0.3     setosa
7             8           5.0  ...          0.2     setosa
8             9           4.4  ...          0.2     setosa
9            10           4.9  ...          0.1     setosa
10           11           5.4  ...          0.2     setosa
11           12           4.8  ...          0.2     setosa
12           13           4.8  ...          0.1     setosa
13           14           4.3  ...          0.1     setosa
14           15           5.8  ...          0.2     setosa
15           16           5.7  ...          0.4     setosa
16           17           5.4  ...          0.4     setosa
17           18           5.1  ...          0.3     setosa
18           19           5.7  ...          0.3     setosa
19           20           5.1  ...          0.3     setosa
20           21           5.4  ...          0.2     setosa
21           22           5.1  ...          0.4     setosa
22           23           4.6  ...          0.2     setosa
23           24           5.1  ...          0.5     setosa
24           25           4.8  ...          0.2     setosa
25           26           5.0  ...          0.2     setosa
26           27           5.0  ...          0.4     setosa
27           28           5.2  ...          0.2     setosa
28           29           5.2  ...          0.2     setosa
29           30           4.7  ...          0.2     setosa
..          ...           ...  ...          ...        ...
120         121           6.9  ...          2.3  virginica
121         122           5.6  ...          2.0  virginica
122         123           7.7  ...          2.0  virginica
123         124           6.3  ...          1.8  virginica
124         125           6.7  ...          2.1  virginica
125         126           7.2  ...          1.8  virginica
126         127           6.2  ...          1.8  virginica
127         128           6.1  ...          1.8  virginica
128         129           6.4  ...          2.1  virginica
129         130           7.2  ...          1.6  virginica
130         131           7.4  ...          1.9  virginica
131         132           7.9  ...          2.0  virginica
132         133           6.4  ...          2.2  virginica
133         134           6.3  ...          1.5  virginica
134         135           6.1  ...          1.4  virginica
135         136           7.7  ...          2.3  virginica
136         137           6.3  ...          2.4  virginica
137         138           6.4  ...          1.8  virginica
138         139           6.0  ...          1.8  virginica
139         140           6.9  ...          2.1  virginica
140         141           6.7  ...          2.4  virginica
141         142           6.9  ...          2.3  virginica
142         143           5.8  ...          1.9  virginica
143         144           6.8  ...          2.3  virginica
144         145           6.7  ...          2.5  virginica
145         146           6.7  ...          2.3  virginica
146         147           6.3  ...          1.9  virginica
147         148           6.5  ...          2.0  virginica
148         149           6.2  ...          2.3  virginica
149         150           5.9  ...          1.8  virginica

[150 rows x 6 columns]
values shape:  (150, 6)

2.4.3  读取TFRecord文件的内容

请看下面的实例文件data07.py,功能是将图像保存写入到TFRecord文件,然后读取TFRecord文件里的内容。将使用图像作为输入数据,将数据写入 TFRecord 文件,然后将文件读取回来并显示图像。如果想在同一个输入数据集上使用多个模型,这种做法会很有用。我们可以不以原始格式存储图像,而是将图像预处理为 TFRecord 格式,然后将其用于所有后续的处理和建模中。文件data07.py的具体实现流程如下所示。

(1)为了将标准TensorFlow类型转换为兼容tf.Example的 tf.train.Feature,编写如下所示的函数将值转换为与tf.Example兼容的类型,每个函数会接受标量输入值并返回包含上述三种 list 类型之一的 tf.train.Feature。

# 将值转换为与tf.Example兼容的类型
def _bytes_feature(value):
  """  从字符串/字节返回bytes_list"""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList不会从张量中解包字符串.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """从float/double返回一个float_list"""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """从bool/enum/int/uint返回int64_list"""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

2下载两个网络照片,代码如下:

cat_in_snow  = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')

display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))

如图2-5所示。

2-5  两幅网络图片

(3)写入 TFRecord 文件

将特征编码为与 tf.Example 兼容的类型,这将存储原始图像字符串特征,以及高度、宽度、深度和任意 label 特征。后者会在您写入文件以区分猫和桥的图像时使用。将 0 用于猫的图像,将 1 用于桥的图像。代码如下:

image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}

#这是一个示例,仅使用cat图像。
image_string = open(cat_in_snow, 'rb').read()

label = image_labels[cat_in_snow]

#创建具有相关功能的词典
def image_example(image_string, label):
  image_shape = tf.image.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')

执行后会打印输出TFRecord 文件的结构

    key: "depth"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "height"
    value {
      int64_list {
        value: 213
      }
...

此时所有的特征都被存储在 tf.Example 消息中,接下来,函数化处理上面的代码,并将示消息写入名为 images.tfrecords 的文件中。代码如下:

# 将原始图像文件写入“images.tfrecords”。
# 首先,将这两个图像处理为`tf.Example`消息。
# 然后,写入一个“.tfrecords”文件.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for filename, label in image_labels.items():
    image_string = open(filename, 'rb').read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())

4读取 TFRecord 文件

现在已经创建了了文件 images.tfrecords,并可以迭代其中的记录以将您写入的内容读取回来。因为在此实例中只需重新生成图像,所以只需要原始图像字符串这一个特征。使用上面描述的 getter 方法(即 example.features.feature['image_raw'].bytes_list.value[0])提取该特征。另外还可以使用标签来确定哪个记录是猫,哪个记录是桥。

raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

#创建描述功能的词典.
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  #使用上面的字典解析输入tf.Example proto
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset

执行后会输出:

<MapDataset shapes: {depth: (), height: (), image_raw: (), label: (), width: ()}, types: {depth: tf.int64, height: tf.int64, image_raw: tf.string, label: tf.int64, width: tf.int64}>

从 TFRecord 文件中恢复图像,代码如下:

for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))

从TFRecord文件中恢复出来的图像如图2-6所示。

2-6  从TFRecord文件中恢复出来的图像

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

码农三叔

感谢鼓励

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值