tf.data模块

目标:

tf.data模块构建输入


内容:

tf.data模块用于构建输入,且可以处理大量数据、不同的数据格式,以及复杂的数据转换。tf.data.Dateset表示一组数据。tf.data.Dateset中一个元素包含一个或者多个Tensor对象,例如一个元素代表单个训练样本或者代表一对训练数据和标签。


过程:

直接从 Tensor 创建 Dataset
tf.data.Dataset.from_tensor_slices(tensors)
tensors可以为列表,字典,元组,numpy的ndarray,tensor
Dataset对象中每个元素的结构必须相同,每个元素可以包含一个或多个tensor对象,这些tensor对象被称为组件。
在tensorflow 2.0 环境下可以直接对 Datasett对象直接进行迭代处理
1、一维数据创建

import tensorflow as tf
tf.__version__

dataset = tf.data.Dataset.from_tensor_slice([1, 2, 3, 4, 5, 6, 7])
dataset
<TensorSliceDataset shapes: (), types: tf.int32>

TensorSliceDataset shapes是指dataset中每个元素的shape,这里元素为数字,所以shapes为(),可以迭代

for ele in dataset:
	print(ele)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(5, shape=(), dtype=int32)
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(7, shape=(), dtype=int32)

当使用tf.data.Dataset.from_tensor_slice会把每个元素[1, 2, 3, 4, 5, 6, 7]变成一个组件,但是会转化成tf.tensor数据类型,所以结果是tf.Tensor(1, shape=(), dtype=int32)
可以用.numpy()的方法,对Tensor进行转换成numpy数据类型

for ele in dataset:
    print(ele.numpy())
1
2
3
4
5
6
7

2、用一个二维列表创建,二维形状要相同

dataset = tf.data.Dataset.from_tensor_slice([[1, 2], [3, 4]])
dataset
<TensorSliceDataset shapes: (2,), types: tf.int32>
for ele in dataset:
	print(ele)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
for ele in dataset:
    print(ele.numpy())
[1 2]
[3 4]

3、用字典类型创建,Dataset的每个元素就是一个字典

dataset_dict =  tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3],  'b': [4,5,6], 'c':[7, 8, 9]})
<TensorSliceDataset shapes: {a: (), b: (), c: ()}, types: {a: tf.int32, b: tf.int32, c: tf.int32}>
for ele in dataset_dict:
    print(ele)
{'a': <tf.Tensor: shape=(), dtype=int32, numpy=1>, 'b': <tf.Tensor: shape=(), dtype=int32, numpy=4>, 'c': <tf.Tensor: shape=(), dtype=int32, numpy=7>}
{'a': <tf.Tensor: shape=(), dtype=int32, numpy=2>, 'b': <tf.Tensor: shape=(), dtype=int32, numpy=5>, 'c': <tf.Tensor: shape=(), dtype=int32, numpy=8>}
{'a': <tf.Tensor: shape=(), dtype=int32, numpy=3>, 'b': <tf.Tensor: shape=(), dtype=int32, numpy=6>, 'c': <tf.Tensor: shape=(), dtype=int32, numpy=9>}
for ele in dataset_dict:
   print(ele['a'].numpy(),ele['b'].numpy(),ele['c'].numpy())
1 4 7
2 5 8
3 6 9

使用Dataset

import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3, 4]))
for ele in dataset.take(4):
    print(ele.numpy())
1
2
3
4

数据乱序
Dataset.shuffle(buffer_size) 将当前Dataset中buffer_size个元素填充缓冲区,在缓冲区进行随机采样。

dataset = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3, 4, 5, 6, 7]))
for ele in dataset.shuffle(7)print(ele.numpy())
7
6
1
3
4
2
5

数据重复,将Dataset中数据重复count次,count默认为None,表示一直重复

for ele in dataset.repeat(2):
    print(ele.numpy())
1
2
3
4
5
6
7
1
2
3
4
5
6
7

Dataset.batch(batch_size) 将Dataset中batch个元素合并为一个元素

for ele in dataset.batch(2):
    print(ele.numpy())
[1 2]
[3 4]
[5 6]
[7]

Dataset.map(map_func) 将Dataset的每个元素都用map_func处理

for ele in dataset.map(tf.square):
    print(ele.numpy())
1
4
9
16
25
36

Dataset.zip(datasets) 将datasets中各个dataset的对应元素合并为元组作为新的Dataset的元素

a = tf.data.Dataset.from_tensor_slices([[1], [2], [3]])
b = tf.data.Dataset.from_tensor_slices([4, 5, 6])
print(tf.data.Dataset.zip((a, b)))
for ele in tf.data.Dataset.zip((a, b)):
    print(ele[0], ele[1])
<ZipDataset shapes: ((1,), ()), types: (tf.int32, tf.int32)>
tf.Tensor([1], shape=(1,), dtype=int32) tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor([2], shape=(1,), dtype=int32) tf.Tensor(5, shape=(), dtype=int32)
tf.Tensor([3], shape=(1,), dtype=int32) tf.Tensor(6, shape=(), dtype=int32)

示例 mnist数据集

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images.shape
(60000, 28, 28)
train_labels.shape
(60000,)
train_images[0]
array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,
         18,  18,  18, 126, 136, 175,  26, 166, 255, 247, 127,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  30,  36,  94, 154, 170,
        253, 253, 253, 253, 253, 225, 172, 253, 242, 195,  64,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  49, 238, 253, 253, 253, 253,
        253, 253, 253, 253, 251,  93,  82,  82,  56,  39,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  18, 219, 253, 253, 253, 253,
        253, 198, 182, 247, 241,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  80, 156, 107, 253, 253,
        205,  11,   0,  43, 154,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,  14,   1, 154, 253,
         90,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 139, 253,
        190,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  11, 190,
        253,  70,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  35,
        241, 225, 160, 108,   1,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         81, 240, 253, 253, 119,  25,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,  45, 186, 253, 253, 150,  27,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,  16,  93, 252, 253, 187,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0, 249, 253, 249,  64,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,  46, 130, 183, 253, 253, 207,   2,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  39,
        148, 229, 253, 253, 253, 250, 182,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  24, 114, 221,
        253, 253, 253, 253, 201,  78,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  23,  66, 213, 253, 253,
        253, 253, 198,  81,   2,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,  18, 171, 219, 253, 253, 253, 253,
        195,  80,   9,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,  55, 172, 226, 253, 253, 253, 253, 244, 133,
         11,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0, 136, 253, 253, 253, 212, 135, 132,  16,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0]], dtype=uint8)
plt.imshow(train_images[0])
![在这里插入图片描述](https://img-blog.csdnimg.cn/20210428095946836.png)

图片数据归一化

train_images = train_images/255
test_images = test_images/255

创建image的dataset

dataset_images = tf.data.Dataset.from_tensor_slices(train_images)
dataset_images
<TensorSliceDataset shapes: (28, 28), types: tf.float64>

创建label的dataset

dataset_labels = tf.data.Dataset.from_tensor_slices(train_labels)
dataset_labels
<TensorSliceDataset shapes: (), types: tf.uint8>

两个dataset压缩在一起, 用元组的方式

dataset = tf.data.Dataset.zip((dataset_images, dataset_labels))
dataset
<ZipDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>
batch_size = 256
dataset = dataset.shuffle(train_images.shape[0]).repeat().batch(batch_size)

建立模型

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

模型编译

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

模型训练

steps_per_epoch = train_images.shape[0]/batch_size
model.fit(dataset, epochs=5, steps_per_epoch=steps_per_epoch)
Epoch 1/5
235/234 [==============================] - 1s 6ms/step - loss: 0.4398 - accuracy: 0.8820
Epoch 2/5
235/234 [==============================] - 1s 5ms/step - loss: 0.1984 - accuracy: 0.9438
Epoch 3/5
235/234 [==============================] - 1s 4ms/step - loss: 0.1468 - accuracy: 0.9592
Epoch 4/5
235/234 [==============================] - 1s 4ms/step - loss: 0.1167 - accuracy: 0.9670
Epoch 5/5
235/234 [==============================] - 1s 4ms/step - loss: 0.0976 - accuracy: 0.9720
<tensorflow.python.keras.callbacks.History at 0x1af323831d0>

参考文献:

https://study.163.com/course/introduction/1004573006.htm

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值