TFRecord and tf.Example

提示:tf.TFRecordReader, tf.TFRecordWriter等方法在tf2中已经出现变化,tf2的写入TFRecord文件为:tf.io.TFRecordWriter(); 读取采用tf.data.TFRecordDataset

Table of Contents

Setup

tf.Example

tf.Example的数据类型

创造一个  tf.Example  的信息流

TFRecords 文件格式化细节

在TFRecord文件中用tf.data

写一个TFRecord文件

读取TFRecord文件

python中的TFRecord文件

写一个TFRecord文件

读取TFRecord文件

练习:读取和写入图片数据

获取图片

写入TFRecord文件

读取TFRecord文件


本文来自:Tensorflow官方参考文档

为了有效地读取数据,将数据序列化并存储在一组文件(每个文件100-200MB)中(每个文件都可以线性读取)会很有帮助。如果数据是通过网络传输的,这一点尤其正确。这对于缓存任何数据预处理也很有用。

TFRecord格式是一种用于存储二进制记录序列的简单格式。
协议缓冲区是一个跨平台、跨语言的库,用于高效序列化结构化数据。
协议消息由.proto文件定义,这些文件通常是理解消息类型的最简单方法。
TF.Example信息(或protobuf)是一个灵活的信息类型,表示{“string”:值}映射。它设计用于TensorFlow,并在更高级别的api(如TFX)中使用。
本笔记本将演示如何创建、解析和使用tf.Example消息,然后在.tfrecord文件之间序列化、写入和读取tf.Example消息。

注意:虽然这些结构很有用,但它们是可选的。不需要将现有代码转换为使用TFREST,除非您正在使用tf.data,而读取数据仍然是训练的瓶颈。有关数据集性能提示,请参见数据输入管道性能。

Setup

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import numpy as np
import IPython.display as display

tf.Example

tf.Example的数据类型

从根本上说,一个tf.Example 就是一个{"string": tf.train.Feature}的字典映射。

tf.train.Feature消息类型可以接受以下三种类型之一(请参阅.proto文件)。大多数其他泛型类型可以强制为以下类型之一:

  1. tf.train.BytesList  (可以强制以下类型):  string    byte
  2. tf.train.FloatList  (可以强制以下类型):  float(float32)  double(float64)
  3. tf.train.Int64List  (可以强制以下类型):  bool  enum  int32  uint32  int64  uint64

为了将标准TensorFlow类型转换为tf.Example-兼容 tf.train.Feature,可以使用下面的快捷函数。注意,每个函数都接受一个标量输入值,并返回一个tf.train.Feature,其中包含上述三种列表类型之一:

# The following functions can be used to convert a value to a type compatible
# with tf.Example.

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

注意:为了保持简单,这个例子只使用标量输入。处理非标量特性的最简单方法是使用tf.serialize_tensor将张量转换为二进制字符串。字符串是tensorflow中的标量。使用tf.parse_tensor将二进制字符串转换回张量。

下面是这些函数如何工作的一些示例。注意不同的输入类型和标准化的输出类型。如果函数的输入类型与上述可强制类型之一不匹配,则函数将引发异常(例如,由于1.0是浮点,因此应与浮点函数一起使用,因此1.0将出错):

print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))

print(_float_feature(np.exp(1)))

print(_int64_feature(True))
print(_int64_feature(1))

bytes_list {
  value: "test_string"
}

bytes_list {
  value: "test_bytes"
}

float_list {
  value: 2.7182817459106445
}

int64_list {
  value: 1
}

int64_list {
  value: 1
}

所有的协议信息能够使用 .SerializeToString的方法序列化为二进制字符串:

feature = _float_feature(np.exp(1))

feature.SerializeToString()


b'\x12\x06\n\x04T\xf8-@'

创造一个  tf.Example  的信息流

假设您想从现有数据创建TF.Example信息流。实际上,数据集可能来自任何地方,但是从一个单一的监视器中创建tf.Example信息流的方法是相同的:

  1. 在每个监视器中,所有的值都需要用上面的方法转化为包含以上三个兼容格式之一的 tf.train.Feature 对象。
  2. 你能够创建一个特征量名字字符串与用示例一编码方式编码的值想对应的映射(字典)。
  3. 在第二步中构造的映射已经被转为了一个Feature信息

在这一个笔记中你将会用Numpy创建一个数据集。

这个数据集有以下4个特征:

  • 一个布尔特征值,Falase或者是True的概率相同。
  • 从[0,5]中均匀随机选择的整数特征值
  • 从字符串表中产生的字符串特征值,这些特征值由整数特征值索引。
  • 一个按照标准正态分布的浮点型特征值

考虑一个由10000个独立且相同分布的观测值组成的样本,这些观测值来自上述每个分布:

# The number of observations in the dataset.
n_observations = int(1e4)

# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)

# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)

# String feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature, from a standard normal distribution
feature3 = np.random.randn(n_observations)

这些特性中的每一个都可以使用_bytes_feature_float_feature_int64_feature强制转换为tf.Example兼容类型。然后,您可以从这些编码功能创建tf.Example消息:

def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.Example message ready to be written to a file.
  """
  # Create a dictionary mapping the feature name to the tf.Example-compatible
  # data type.
  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }

  # Create a Features message using tf.train.Example.

  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

例如,假设您有一个来自数据集的观察结果,[False,4,bytes('goat'),0.9876]。您可以使用create_message()创建并打印此观察的tf.Example消息。每一个单独的观察都将按照上述内容作为一个特征信息写入。请注意,tf.Example消息只是Features消息的包装:

# This is an example observation from the dataset.

example_observation = []

serialized_example = serialize_example(False, 4, b'goat', 0.9876)
serialized_example

b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'

使用tf.train.Example.FromString()来反编译信息流

example_proto = tf.train.Example.FromString(serialized_example)
example_proto
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}

TFRecords 文件格式化细节

TFRecord文件包含一系列记录。只能按顺序读取文件。
每个记录包含一个字节字符串,用于数据有效负载,加上数据长度,以及CRC32C(使用Castagnoli多项式的32位CRC)散列用于完整性检查。每一个都是按下面的形式记录的。

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

这些记录连接在一起生成文件。这里描述CRC,CRC的掩码是:

masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul

注意:在TFRecord文件中不需要使用tf.Example。Example只是一种将字典序列化为字节字符串的方法。文本行、编码图像数据或序列化张量(加载时使用tf.io.serialize_tensor和tf.io.parse_tensor)。有关更多选项,请参阅tf.io模块。

在TFRecord文件中用tf.data

在Tensorflow中, tf.data 模块也提供了一个读写数据的方法

写一个TFRecord文件

使用from_tensor_slices 方法可以非常简单的将数据获取到数据集中。其他方法还有:tf.data.Dataset.from_generator ; tf.data.Dataset.from_tensors

对于数组,这个方法会返回一个标量数据集:

tf.data.Dataset.from_tensor_slices(feature1)

<TensorSliceDataset shapes: (), types: tf.int64>

对于一个元组数组将会返回一个元组:

features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
<TensorSliceDataset shapes: ((), (), (), ()), types: (tf.bool, tf.int64, tf.string, tf.float64)>
# Use `take(1)` to only pull one example from the dataset.
for f0,f1,f2,f3 in features_dataset.take(1):
  print(f0)
  print(f1)
  print(f2)
  print(f3)
tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(b'chicken', shape=(), dtype=string)
tf.Tensor(0.7951810656934285, shape=(), dtype=float64)

使用tf.data.Dataset.map方法将函数应用于数据集的每个元素。

映射函数必须在TensorFlow图模式下操作,并返回tf.Tensors。非张量函数(如serialize_example)可以用tf.py_function包装,以使其兼容。

使用tf.py_function需要指定可用的形状和类型信息.

def tf_serialize_example(f0,f1,f2,f3):
  tf_string = tf.py_function(
    serialize_example,
    (f0,f1,f2,f3),  # pass these args to the above function.
    tf.string)      # the return type is `tf.string`.
  return tf.reshape(tf_string, ()) # The result is a scalar
tf_serialize_example(f0,f1,f2,f3)
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xfc\x90K?'>

为数据集应用这一函数:

serialized_features_dataset = features_dataset.map(tf_serialize_example)
serialized_features_dataset
<MapDataset shapes: (), types: tf.string>

用generator来创建dataset

def generator():
  for features in features_dataset:
    yield serialize_example(*features)
#*feature是可变参数
#Creates a Dataset whose elements are generated by generator.
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator, output_types=tf.string, output_shapes=())
serialized_features_dataset

<FlatMapDataset shapes: (), types: tf.string>

将其写入TFRecord文件

filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)

读取TFRecord文件

可以使用tf.data.TFRecordDataset类来读取TFRecord文件

有关使用tf.data使用TFRecord文件的详细信息,请参见API。

使用TFRecordDataset对于标准化输入数据和优化性能非常有用。

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>

此时,数据集包含序列化的tf.train.Example消息。当对其进行迭代时,它将这些作为标量字符串张量返回。

使用 .take()方法可以只显示前面的10个数据据

Note:在tf.data.Dataset上迭代只能在启用了紧急执行的情况下工作。

for raw_record in raw_dataset.take(10):
  print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xfc\x90K?'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04,\x87\xbd?'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04WPC\xbf\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04#\x0fc?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04e\xc6\x9e\xbf'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x9fAL?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x17\x07"@'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04=\x91n>\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xe7N\x99<\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x08\xe88\xbf'>

可以使用下面的函数解析这些张量。请注意,此处需要feature_description,因为数据集使用图形执行,并且需要此描述来构建其形状和类型签名:

# Create a description of the features.
feature_description = {
    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}
#这里是做句法分析
def _parse_function(example_proto):
  # Parse the input `tf.Example` proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, feature_description)

或者,使用tf.parse示例一次解析整个喂入的数据。使用tf.data.dataset.map方法将此函数应用于数据集中的每个项:

parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset

<MapDataset shapes: {feature0: (), feature1: (), feature2: (), feature3: ()}, types: {feature0: tf.int64, feature1: tf.int64, feature2: tf.string, feature3: tf.float32}>

 使用“紧急执行”在数据集中显示观察结果。此数据集中有10000个观测值,但仅显示前10个。数据显示为功能字典。每个项都是一个tf.Tensor,该Tensor的numpy元素显示特征值:

for parsed_record in parsed_dataset.take(10):
  print(repr(parsed_record))
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.79518104>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=1.4806876>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.76294464>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.8869497>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.2404295>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.7978763>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=2.5316827>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.23297592>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.018714381>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.7222905>}

在这里,tf.parse_example函数将tf.example字段解压为标准张量。

python中的TFRecord文件

py.io模块还包含用于读取和写入TFRecord文件的纯Python函数。

写一个TFRecord文件

接下来,将10000个观测值写入test.ftRecord文件。将每一个写入文件的观测值都转化成tf.Example信息流。最后你可以确认test.tfrecord这个文件已经建立了。

# Write the `tf.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)
!du -sh {filename}

984K    test.tfrecord

读取TFRecord文件

这些序列化的文件可以用tf.train.Example.ParseFromString 来进行语法、句法检测。

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>
for raw_record in raw_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())
  print(example)
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 2
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "chicken"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.7951810359954834
      }
    }
  }
}

练习:读取和写入图片数据

这是一个端到端的示例,说明如何使用TFRecords读写图像数据。使用图像作为输入数据,您将数据作为TFRecord文件写入,然后将文件读回并显示图像。


例如,如果要在同一个输入数据集上使用多个模型,则这可能很有用。它不需要存储原始图像数据,而是可以预处理为TFRecords格式,并可用于所有进一步的处理和建模。


首先,让我们下载这张雪地里猫的照片和这张正在建设中的纽约威廉斯堡大桥的照片。

获取图片

cat_in_snow  = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg
16384/15477 [===============================] - 0s 0us/step
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

Image cc-by: Von.grzanka

display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))

From Wikimedia

写入TFRecord文件

和以前一样,将这些特性编码为与tf.Example兼容的类型。这将存储原始图像字符串功能,以及高度、宽度、深度和任意标签功能。后者用于在编写文件时区分cat映像和bridge映像。使用0表示cat映像,使用1表示桥接映像:

image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}

# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()

label = image_labels[cat_in_snow]

# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
  image_shape = tf.image.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')
features {
  feature {
    key: "depth"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "height"
    value {
      int64_list {
        value: 213
      }
...

注意,所有特性现在都存储在tf.Example消息中。接下来,对上面的代码进行功能化,并将示例消息写入名为images.tfrecords的文件:

# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for filename, label in image_labels.items():
    image_string = open(filename, 'rb').read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())

!du -sh {record_file}

36K images.tfrecords

读取TFRecord文件

现在您有了images.tfrecords文件,现在可以遍历其中的记录来读回您所写的内容。考虑到在这个例子中,您将只复制图像,您将需要的唯一特性是原始图像字符串。使用上面描述的getter提取它,即example.features.feature['image_raw'].bytes_list.value[0]。您还可以使用标签来确定哪条记录是cat,哪条是网桥:

raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Create a dictionary describing the features.
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset


<MapDataset shapes: {depth: (), height: (), image_raw: (), label: (), width: ()}, types: {depth: tf.int64, height: tf.int64, image_raw: tf.string, label: tf.int64, width: tf.int64}>

从TFRecord文件中恢复这张图片

for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
`tf.placeholder` 和 `tf.Variable` 都是 TensorFlow 中的重要概念,但在使用方式、作用和特点上有所不同。 `tf.placeholder` 是一个占位符,用于在 TensorFlow 的计算图中定义输入数据的位置。它在定义计算图的时候并不需要给定具体的数值,而是在计算图运行时,通过 `feed_dict` 参数传入具体的数值。它通常用于传入训练数据和标签等变量,例如: ``` x = tf.placeholder(tf.float32, shape=[None, 784]) y = tf.placeholder(tf.float32, shape=[None, 10]) ``` 在这个例子中,我们定义了两个 `tf.placeholder`,`x` 和 `y`,分别用于输入训练数据和标签。其中,`shape=[None, 784]` 表示输入数据的形状是一个二维张量,第一个维度可以是任意大小,第二个维度是 784。 相比之下,`tf.Variable` 则是用于在 TensorFlow 的计算图中定义需要优化的变量。它在定义时需要给定初始值,通常是一个随机数值或者全零的数组。在 TensorFlow 的计算过程中,`tf.Variable` 的值会不断地被优化,以使得算法达到更好的结果。例如: ``` W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) ``` 在这个例子中,我们定义了两个 `tf.Variable`,`W` 和 `b`,分别表示权重和偏置。它们的初始值都是全零的数组。 总的来说,`tf.placeholder` 用于传入数据,`tf.Variable` 用于定义需要优化的变量。它们在 TensorFlow 中都有着重要的作用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值