Tensorflow - Dataset 之 repeat(), shuffle(), batch()作用

本文介绍了 TensorFlow 中 Dataset 的关键操作:repeat() 用于使数据集可无限重复,shuffle() 打乱数据顺序,batch() 设置批量处理的样本数。通过示例展示了这些函数如何改变数据流的处理方式,对于理解和构建深度学习模型的输入流水线非常有帮助。
摘要由CSDN通过智能技术生成

Tensorflow - Dataset 之 repeat(), shuffle(), batch()作用

repeat(): 该函数让数据集重复的次数, 如没有参数,则数据集可以任意获取

shuffle(): 打乱数据集的顺序

batch(): 设置一次操作允许获取的数据个数

import tensorflow as tf
import numpy as np

feature = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], np.float32)

label = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1])

train_data = tf.data.Dataset.from_tensor_slices((feature, label))  # 定义9个数据的数据集


def print_train_data(data, cnt):
    it = data.__iter__()

    for i in range(cnt):
        x, y = it.next()

        print(f"{i}: {x} - {y}")


print_train_data(train_data, 9)
# print_train_data(train_data, 10)  # 出错

print("=== after repeat ====")

train_data = train_data.repeat()  # 调用该函数后后面可以无限使用该数据集

print_train_data(train_data, 9)  # print_train_data(train_data, 10)  # 可以无限循环读取: repeat 留空为无限

print("=== after shuffle ====")

train_data = train_data.shuffle(buffer_size=2)  # 打乱数据集的顺序( 如果为 1 的话, 打乱顺序则无用)

print_train_data(train_data, 9)

print("=== after batch ====")

dataset_batch = train_data.batch(batch_size=3)  # 设置每次回去数据集的数据条数 

it = dataset_batch.__iter__()

print_train_data(it, 9)

打印的数据如下:

0: 1.0 - 0
1: 2.0 - 0
2: 3.0 - 0
3: 4.0 - 0
4: 5.0 - 1
5: 6.0 - 1
6: 7.0 - 1
7: 8.0 - 1
8: 9.0 - 1
=== after repeat ====
0: 1.0 - 0
1: 2.0 - 0
2: 3.0 - 0
3: 4.0 - 0
4: 5.0 - 1
5: 6.0 - 1
6: 7.0 - 1
7: 8.0 - 1
8: 9.0 - 1
=== after shuffle ====
0: 2.0 - 0
1: 3.0 - 0
2: 1.0 - 0
3: 5.0 - 1
4: 6.0 - 1
5: 7.0 - 1
6: 4.0 - 0
7: 9.0 - 1
8: 8.0 - 1
=== after batch ====
0: [2. 3. 4.] - [0 0 0]
1: [5. 1. 6.] - [1 0 1]
2: [8. 7. 9.] - [1 1 1]
3: [1. 3. 4.] - [0 0 0]
4: [2. 5. 6.] - [0 1 1]
5: [8. 9. 7.] - [1 1 1]
6: [2. 1. 4.] - [0 0 0]
7: [5. 3. 6.] - [1 0 1]
8: [7. 8. 9.] - [1 1 1]

reference

@online{BibEntry2022May,
title = {{Tensorflow - Dataset 之 repeat(), shuffle(), batch()作用_aaronychen的博客-CSDN博客_train_dataset.shuffle}},
year = {2022},
month = may,
date = {2022-05-13},
urldate = {2022-05-13},
language = {chinese},
hyphenation = {chinese},
note = {[Online; accessed 13. May 2022]},
url = {https://blog.csdn.net/aaronychen/article/details/122879141},
keywords = {train_dataset.shuffle},
abstract = {{该文章简要描述了tensorflow 下 DataSet 一些函数的基本操作}}
}

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

或许,这就是梦想吧!

如果对你有用,欢迎打赏。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值