本篇博客主要讲解tf.data.Dataset的batch,repeat,shuffle函数,最后着重讲它们之间的顺序使用问题。
tf.data.Dataset官方文档
batch函数讲解
以下先创建一个Dataset类型
import tensorflow as tf
t = tf.range(10.)[:, None]
t = tf.data.Dataset.from_tensor_slices(t)
#<TensorSliceDataset shapes: (1,), types: tf.float32>
for i in t:
print(i.numpy())
#[0.]
#[1.]
#[2.]
#[3.]
#[4.]
#[5.]
#[6.]
#[7.]
#[8.]
#[9.]
然后我们看下batch函数
batch_t = t.batch(2)
#<BatchDataset shapes: (None, 1), types: tf.float32>
for i in batch_t:
print(i.numpy())
"""
[[0.]
[1.]]
[[2.]
[3.]]
[[4.]
[5.]]
[[6.]
[7.]]
[[8.]
[9.]]
"""
以上可以清楚的看到,batch将2个数据分为一组了
repeat函数讲解
repeat的参数:count=None,表示将数据重复count次。
repeat_t = t.repeat(2)
for i in repeat_t:
print(i.numpy())
#输出太长,可以看下结果
shuffle函数讲解
该函数的作用就是打乱数据顺序,其中有个参数需要着重说明下。
buffer_size:该函数的作用就是先构建buffer,大小为buffer_size,然后从Dataset中提取数据将它填满。batch操作,从buffer中提取。如果buffer_size小于Dataset的大小,每次提取buffer中的数据,会再次从Dataset中抽取数据将它填满(当然是之前没有抽过的)。所以一般最好的方式是buffer_size=Dataset_size。
shuffle_t = t.shuffle(10)
for i in shuffle_t:
print(i.numpy())
#[8.]
#[1.]
#[6.]
#[5.]
#[9.]
#[4.]
#[3.]
#[7.]
#[0.]
#[2.]
顺序对比
t1 = t.shuffle(10).batch(2)
#这个是先打乱t的顺序,然后batch
t2 = t.batch(2).shuffle(10)
#这个是打乱batch的顺序
t3 = t.batch(2).repeat(2)
#重复batch,而不是数据
t4 = t.repeat(2).batch(2)
#重复数据,再batch
可以自己print出来看一下,顺序不同每一个数据的意义是不相同的