tf.data: TensorFlow Input Pipeline
- Extract:
  - read data from memory/storage
  - parse file format
- Transform:
  - text vectorization
  - image transformations
  - video temporal sampling
  - shuffling, batching, …
- Load:
  - transfer data to the accelerator
Simple example:
# one method
import tensorflow as tf

def preprocess(record):
    pass

dataset = tf.data.TFRecordDataset("../*.tfrecord")  # reads data from storage
dataset = dataset.map(preprocess, num_parallel_calls=Y)  # applies user-defined preprocessing
dataset = dataset.batch(batch_size=32)  # batches 32 consecutive elements together
dataset = dataset.prefetch(buffer_size=X)  # overlaps preprocessing with model execution

model = ...
model.fit(dataset, epochs=10)
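Instead of hand-tuning the X/Y placeholders, recent TF 2.x releases can pick them at runtime. A minimal sketch, assuming tf.data.AUTOTUNE is available (TF 2.4+; older versions expose it as tf.data.experimental.AUTOTUNE):
import tensorflow as tf

def preprocess(record):
    pass

# AUTOTUNE lets the tf.data runtime tune parallelism and buffer sizes dynamically
dataset = tf.data.TFRecordDataset("../*.tfrecord")
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(batch_size=32)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)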
# another method
import tensorflow as tf

def preprocess(record):
    pass

dataset = tf.data.Dataset.list_files("../*.tfrecord")  # lists files matching the pattern
dataset = dataset.interleave(tf.data.TFRecordDataset, num_parallel_calls=Z)  # reads records from multiple files in parallel
dataset = dataset.map(preprocess, num_parallel_calls=X)  # applies user-defined preprocessing
dataset = dataset.batch(batch_size=32)  # batches 32 consecutive elements together
dataset = dataset.prefetch(buffer_size=X)  # overlaps preprocessing with model execution

model = ...
model.fit(dataset, epochs=10)
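The preprocess function above is just a stub; for TFRecord input it usually parses each serialized example. A hedged sketch, assuming records with a JPEG-encoded "image" bytes feature and an int64 "label" feature (the feature keys, image size, and scaling are assumptions for illustration):
import tensorflow as tf

def preprocess(record):
    # assumed feature spec; adjust keys/types to match how the TFRecords were written
    features = tf.io.parse_single_example(record, {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    })
    image = tf.io.decode_jpeg(features["image"], channels=3)
    image = tf.image.resize(image, [224, 224]) / 255.0  # assumed target size and scaling
    label = tf.cast(features["label"], tf.int32)
    return image, label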
tf.data Options
- tf.data.Options
  - statistics aggregation
  - optimizations (autotuning, fusion, vectorization, parallelization, determinism, …)
  - threading (private thread pool, intra-op parallelism)
dataset = ...
options = tf.data.Options()
options.experimental_optimization.map_parallelization = True
dataset = dataset.with_options(options)
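The same Options object also covers the threading and determinism knobs listed above. A sketch, assuming a recent TF 2.x release (older versions expose these under options.experimental_threading and options.experimental_deterministic):
import tensorflow as tf

dataset = tf.data.Dataset.range(10).map(lambda x: x + 1)
options = tf.data.Options()
options.experimental_optimization.map_fusion = True   # fuse consecutive map transformations
options.threading.private_threadpool_size = 8         # dedicated thread pool for this pipeline
options.threading.max_intra_op_parallelism = 1        # limit intra-op parallelism per element
options.deterministic = False                         # allow out-of-order results for throughput
dataset = dataset.with_options(options)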
TFDS: TensorFlow Datasets
- https://www.tensorflow.org/datasets/datasets
- canned datasets ready to be used with the rest of TensorFlow
import tensorflow as tf
import tensorflow_datasets as tfds
# see available datasets
print(tfds.list_builders())
# construct a tf.data.Dataset
dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN)
# customize your input pipeline
dataset = dataset.shuffle(1024).batch(32)
for features in dataset.take(1):
    image, label = features["image"], features["label"]
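To feed a Keras model directly, tfds.load can also return (image, label) tuples plus dataset metadata. A sketch, assuming the MNIST builder shown above; the small model is only an illustrative placeholder:
import tensorflow as tf
import tensorflow_datasets as tfds

# as_supervised=True yields (image, label) tuples; with_info=True also returns metadata
dataset, info = tfds.load(name="mnist", split="train", as_supervised=True, with_info=True)
dataset = dataset.shuffle(1024).batch(32).prefetch(tf.data.AUTOTUNE)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(info.features["label"].num_classes, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.fit(dataset, epochs=10)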