数据集创建tf.data(tensorflow2)

tf.data API允许使用简单,可复用的代码创建一个数据输入流。比如它可以从图像分布式文件系统中创建数据输入,在此过程中可以为每张图像添加随机噪声,随机抽取图像当作本次batch进行训练。

tf.data API引入了tf.data.Dataset对象,它包含了一系列的元素(element),每一个元素有多个或一个components

有两种方法创建数据集dataset:

1.从内存或文件中的数据源构建Datadet
2.Dataset数据库中转化得到

基础创建Dataset

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
#<TensorSliceDataset shapes: (), types: tf.int32>

for elem in dataset:
  print(elem.numpy())
#8
#3
#0
#8
#2
#1

dataset其实是一个iter:

it = iter(dataset)
print(next(it).numpy())
#8

我们可以看到dataset是一个Dataset对象,下面看下reduce方法的使用:

print(dataset.reduce(0, lambda state, value: state + value).numpy())
#22

数据集结构

数据中每个元素都是相同的类型。
类型如下:
Tensor, SparseTensor, RaggedTensor, TensorArray, or Dataset。它们都被包含在 tf.TypeSpec

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec

#TensorSpec(shape=(10,), dtype=tf.float32, name=None)

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec

#(TensorSpec(shape=(), dtype=tf.float32, name=None),
# TensorSpec(shape=(100,), dtype=tf.int32, name=None))

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec

#(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
# (TensorSpec(shape=(), dtype=tf.float32, name=None),
#  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))

# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec

#SparseTensorSpec(TensorShape([3, 4]), tf.int32)

# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type

#tensorflow.python.framework.sparse_tensor.SparseTensor

创建生成器

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1

for n in count(5):
  print(n)
#0
#1
#2
#3
#4
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )

for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())

#[0 1 2 3 4 5 6 7 8 9]
#[10 11 12 13 14 15 16 17 18 19]
#[20 21 22 23 24  0  1  2  3  4]
#[ 5  6  7  8  9 10 11 12 13 14]
#[15 16 17 18 19 20 21 22 23 24]
#[0 1 2 3 4 5 6 7 8 9]
#[10 11 12 13 14 15 16 17 18 19]
#[20 21 22 23 24  0  1  2  3  4]
#[ 5  6  7  8  9 10 11 12 13 14]
#[15 16 17 18 19 20 21 22 23 24]

再看下,其它例子:

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1

for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
"""
0 : [ 0.9475 -0.6361  0.9765]
1 : [-0.555   1.3723  0.1027 -1.0957  0.141 ]
2 : [-0.5906 -1.2747  0.5064  1.104   0.1396 -0.1937 -0.3695 -0.5508]
3 : [ 0.2029  0.7422  1.3038  1.0698  1.7587 -0.7051]
4 : [ 0.4777  0.568  -0.7713 -0.0322 -1.0875]
5 : [ 0.2634 -0.3093  0.6087]
6 : [-0.1843  0.6568  0.2268  2.1317 -0.2758 -0.4531]
"""
ds_series = tf.data.Dataset.from_generator(
    gen_series, 
    output_types=(tf.int32, tf.float32), 
    output_shapes=((), (None,)))

ds_series

#<FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>

需要注意一点,就是上述数据长度是不相等。可以使用padded_batch。

ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
"""
[10  2 13  4 19 12 25 14  8  0]

[[ 0.6505  0.9339 -0.6796  0.      0.      0.      0.      0.    ]
 [ 1.4865 -0.5334  0.      0.      0.      0.      0.      0.    ]
 [ 0.9648  1.0677 -1.2092 -0.4564  0.9524 -0.5516 -0.8149  1.1307]
 [-1.2593  0.8061  0.7738 -0.6441 -1.3384  1.2362  0.      0.    ]
 [ 0.6877  1.9626 -1.0171 -0.7908  0.      0.      0.      0.    ]
 [ 0.2111  1.687  -0.0555 -0.0242 -1.2556  1.1843 -0.509   1.5797]
 [-0.6607  0.      0.      0.      0.      0.      0.      0.    ]
 [-0.3114 -0.6608  0.      0.      0.      0.      0.      0.    ]
 [ 0.8224 -1.2478 -0.9483  0.6411 -0.9707  1.659  -0.642   0.    ]
 [ 1.7515  1.3955 -0.9958 -0.1844  0.5085 -0.1619  1.0888  1.7601]]
 """

再看下对图像数据的读取

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)

images, labels = next(img_gen.flow_from_directory(flowers))

#Found 3670 images belonging to 5 classes.

print(images.dtype, images.shape)
print(labels.dtype, labels.shape)

#float32 (32, 256, 256, 3)
#float32 (32, 5)

ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers], 
    output_types=(tf.float32, tf.float32), 
    output_shapes=([32,256,256,3], [32,5])
)

ds

#<FlatMapDataset shapes: ((32, 256, 256, 3), (32, 5)), types: (tf.float32, #tf.float32)>

读取并使用 TFRecord data

读取TFRecord data当作输入流:

# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
'''
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7905280/7904079 [==============================] - 1s 0us/step
'''
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset

#<TFRecordDatasetV2 shapes: (), types: tf.string>
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值