tf.data模块
import tensorflow as tf
tf.__version__
‘2.0.0’
tf.data模块用于构建输入,且可以处理大量数据、不同的数据格式,以及数据转换。
tf.data.Dateset表示一组数据。tf.data.Dateset中一个元素包含一个或者多个Tensor对象,例如一个元素代表单个训练样本或者代表一对训练数据和标签
创建Dataset
tf.data.Dataset.from_tensor_slices(tensors)
tensors可以为列表,字典,元组,numpy的ndarray,tensor
Dataset对象中每个元素的结构必须相同,一个元素可以包含一个或多个tensor对象,这些tensor对象被称为组件。
Dataset对象可以对其直接进行迭代使用
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
print(dataset)
<TensorSliceDataset shapes: (), types: tf.int32>
TensorSliceDataset shapes 表示Dataset中元素的shape,这里元素为数字,所以shapes为()
for element in dataset:
print(element)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
可以用.numpy()的方法,对Tensor进行转换
for element in dataset:
print(element.numpy())
1
2
3
用一个二维列表创建
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
print(dataset)
print()
for element in dataset:
print(element)
print()
for element in dataset:
print(element.numpy())
<TensorSliceDataset shapes: (2,), types: tf.int32>
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
[1 2]
[3 4]
用字典类型创建
Dataset的每个元素就是一个字典
dict = {
'a': [[1], [2]], 'b': [[3], [4]]}
dataset = tf.data