Tensorflow.Dataset中map,shuffle,repeat,batch的总结

Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服务于数据读取,构建输入数据的pipeline。

Google官方给出的Dataset API中的类图:

在这里插入图片描述
我们本文只关注Dataset的一类特殊的操作:Transformation,即map,shuffle,repeat,batch等。

在正式介绍之前,我们再回忆一下深度学习中的一些基本概念。

  • batch size指的就是更新梯度中使用的样本数。如果把batch_size设置为数据集的长度,就成了批量梯度下降算法,batch_size设置为1就是随机梯度下降算法
  • 一次epoch=所有训练数据forward+backward后更新参数的过程。
  • 一次iteration=[batch
    size]个训练数据forward+backward后更新参数过程。即每跑完一个batch都要更新参数,这个过程叫一个iteration

一、batch

最简单的情况如下:

# 创建0-10的数据集,每个batch取个数6。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

结果为:

[0 1 2 3 4 5]
[6 7 8 9]

但是如果我们把循环次数设置成3(即for i in range(2)),那么就会报错。

二、repeat

repeat方法可以解决上述问题,repeat的功能就是将整个数据重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(2)就可以将之变成2个epoch:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

结果如下:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

当然,如果觉得每次都需要设置repeat的次数麻烦,我们也可以不设置repeat,代码如下:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(6):
        value = sess.run(next_element)
        print(value)

结果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

三、shuffle

仔细看可以知道上面所有输出结果都是有序的,在机器学习中训练模型需要将数据打乱,这样可以保证每批次训练的时候所用到的数据集是不一样的,可以提高模型训练效果。

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小,不设置会报错,

  • buffer_size=1:不打乱顺序,既保持原序
  • buffer_size越大,打乱程度越大,演示效果见如下代码:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

结果如下:

[1 0 2 4 3 5]
[7 8 9 6]
[1 2 3 4 0 6]
[7 8 9 5]

注意:shuffle的顺序很重要,应该先shuffle再batch,如果先batch后shuffle的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。

ataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

结果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

可以看到实际并没有shuffle

四、map

map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加10

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.map(lambda x: x + 10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

结果


[16 17 18 19]
[10 11 12 13 14 15]

参考文献

【1】Tensorflow datasets.shuffle repeat batch方法
【2】TensorFlow全新的数据读取方式:Dataset API入门教程
【3】Module: tf.data

  • 30
    点赞
  • 71
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值