tf.data.Dataset关于batch,repeat,shuffle的讲解

本篇博客主要讲解tf.data.Dataset的batch,repeat,shuffle函数,最后着重讲它们之间的顺序使用问题。
tf.data.Dataset官方文档

batch函数讲解

以下先创建一个Dataset类型

import tensorflow as tf
t = tf.range(10.)[:, None]
t = tf.data.Dataset.from_tensor_slices(t)
#<TensorSliceDataset shapes: (1,), types: tf.float32>
for i in t:
    print(i.numpy())
#[0.]
#[1.]
#[2.]
#[3.]
#[4.]
#[5.]
#[6.]
#[7.]
#[8.]
#[9.]

然后我们看下batch函数

batch_t = t.batch(2)
#<BatchDataset shapes: (None, 1), types: tf.float32>
for i in batch_t:
    print(i.numpy())
"""
[[0.]
 [1.]]
[[2.]
 [3.]]
[[4.]
 [5.]]
[[6.]
 [7.]]
[[8.]
 [9.]]
"""

以上可以清楚的看到,batch将2个数据分为一组了

repeat函数讲解

repeat的参数:count=None,表示将数据重复count次。

repeat_t = t.repeat(2)
for i in repeat_t:
    print(i.numpy())
#输出太长,可以看下结果

shuffle函数讲解

该函数的作用就是打乱数据顺序,其中有个参数需要着重说明下。
buffer_size:该函数的作用就是先构建buffer,大小为buffer_size,然后从Dataset中提取数据将它填满。batch操作,从buffer中提取。如果buffer_size小于Dataset的大小,每次提取buffer中的数据,会再次从Dataset中抽取数据将它填满(当然是之前没有抽过的)。所以一般最好的方式是buffer_size=Dataset_size。

shuffle_t = t.shuffle(10)
for i in shuffle_t:
    print(i.numpy())
#[8.]
#[1.]
#[6.]
#[5.]
#[9.]
#[4.]
#[3.]
#[7.]
#[0.]
#[2.]

顺序对比

t1 = t.shuffle(10).batch(2)
#这个是先打乱t的顺序,然后batch
t2 = t.batch(2).shuffle(10)
#这个是打乱batch的顺序
t3 = t.batch(2).repeat(2)
#重复batch,而不是数据
t4 = t.repeat(2).batch(2)
#重复数据,再batch

可以自己print出来看一下,顺序不同每一个数据的意义是不相同的

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值