tensorflow dataset.shuffle dataset.batch dataset.repeat

Reposted from: https://blog.csdn.net/qq_16234613/article/details/81703228

batch is easy to understand: it is just the batch size. Note that the last batch of an epoch may contain fewer samples than the batch size.
dataset.repeat corresponds to what is commonly called epochs, but in TF its order relative to dataset.shuffle can cause samples from different epochs to mix.
dataset.shuffle maintains a shuffle buffer of buffer_size samples; every sample the graph consumes is drawn at random from this buffer, and each time a sample is drawn, one sample from the source dataset is loaded into the buffer to replace it.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator (TF 1.x API)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # 11 samples / batch size 4 -> 3 batches per epoch; repeat(2) -> 6 batches total
    try:
        while True:
            print(sess.run(el))
    except tf.errors.OutOfRangeError:
        pass  # both epochs exhausted
# source dataset
[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]
 [ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]
 [ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]

# batches obtained after shuffle + batch
[[ 0.4236548   0.64589411]
 [ 0.60276338  0.54488318]
 [ 0.43758721  0.891773  ]
 [ 0.5488135   0.71518937]]
[[ 0.96366276  0.38344152]
 [ 0.56804456  0.92559664]
 [ 0.0202184   0.83261985]
 [ 0.79172504  0.52889492]]
[[ 0.07103606  0.0871293 ]
 [ 0.97861834  0.79915856]
 [ 0.77815675  0.87001215]]  # the last batch has only 3 samples
[[ 0.60276338  0.54488318]
 [ 0.5488135   0.71518937]
 [ 0.43758721  0.891773  ]
 [ 0.79172504  0.52889492]]
[[ 0.4236548   0.64589411]
 [ 0.56804456  0.92559664]
 [ 0.0202184   0.83261985]
 [ 0.07103606  0.0871293 ]]
[[ 0.77815675  0.87001215]
 [ 0.96366276  0.38344152]
 [ 0.97861834  0.79915856]] # the last batch has only 3 samples

1. Per the buffer size passed to shuffle, first load three samples from the source dataset:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2. Draw one sample at random from the buffer into the batch:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3. The shuffle buffer is now one short of three samples, so pull one more sample from the source dataset:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4. Draw another sample from the buffer into the batch:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
5. And so on. This means that with a shuffle buffer size of 1 the dataset is not shuffled at all, while a buffer size equal to the number of samples shuffles the entire dataset uniformly.
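
The buffer mechanics above can be simulated in a few lines of plain Python. This is a minimal illustrative sketch of the algorithm as just described, not TensorFlow's actual implementation; the helper name buffered_shuffle is made up for the illustration:

import random

def buffered_shuffle(source, buffer_size, seed=0):
    """Simulate tf.data-style buffered shuffling (illustrative sketch only)."""
    rng = random.Random(seed)
    it = iter(source)
    buf = []
    # 1. fill the buffer with the first buffer_size samples
    for item in it:
        buf.append(item)
        if len(buf) == buffer_size:
            break
    while buf:
        # 2. emit a random element from the buffer
        idx = rng.randrange(len(buf))
        yield buf[idx]
        # 3. refill that slot from the source if samples remain
        try:
            buf[idx] = next(it)
        except StopIteration:
            buf.pop(idx)

print(list(buffered_shuffle(range(11), buffer_size=1)))   # unchanged order
print(list(buffered_shuffle(range(11), buffer_size=3)))   # only locally shuffled
print(list(buffered_shuffle(range(11), buffer_size=11)))  # full uniform shuffle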

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)   # buffer of size 1: no shuffling at all
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator (TF 1.x API)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(el))
    except tf.errors.OutOfRangeError:
        pass  # both epochs exhausted

# source dataset
[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]
 [ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]
 [ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]

# batches obtained with shuffle(1): identical to the source order
[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]]
[[ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]]
[[ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]
[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]]
[[ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]]
[[ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]
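
For the opposite extreme mentioned in point 5 above, here is the same pipeline with the buffer sized to the whole dataset (11 samples), which shuffles the entire epoch uniformly; only the shuffle line changes relative to the script above:

import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(11)  # buffer holds the entire dataset: full uniform shuffle
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(el))  # each epoch still contains every sample exactly once
    except tf.errors.OutOfRangeError:
        pass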

Note what happens if repeat is applied before shuffle:
The official docs say that using repeat before shuffle improves performance, but it blurs the epoch boundaries between samples.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)    # repeat FIRST: the two epochs are concatenated ...
dataset = dataset.shuffle(11)  # ... and then shuffled together
dataset = dataset.batch(4)

# create the iterator (TF 1.x API)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(el))
    except tf.errors.OutOfRangeError:
        pass  # 22 samples total -> 5 batches of 4 plus a final batch of 2

# source dataset
[[ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]
 [ 0.43758721  0.891773  ]
 [ 0.96366276  0.38344152]
 [ 0.79172504  0.52889492]
 [ 0.56804456  0.92559664]
 [ 0.07103606  0.0871293 ]
 [ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.97861834  0.79915856]]

# batches obtained with repeat before shuffle: samples from the two epochs are mixed
[[ 0.56804456  0.92559664]
 [ 0.5488135   0.71518937]
 [ 0.60276338  0.54488318]
 [ 0.07103606  0.0871293 ]]
[[ 0.96366276  0.38344152]
 [ 0.43758721  0.891773  ]
 [ 0.43758721  0.891773  ]
 [ 0.77815675  0.87001215]]
[[ 0.79172504  0.52889492]   # the same sample can appear twice in one batch
 [ 0.79172504  0.52889492]
 [ 0.60276338  0.54488318]
 [ 0.4236548   0.64589411]]
[[ 0.07103606  0.0871293 ]
 [ 0.4236548   0.64589411]
 [ 0.96366276  0.38344152]
 [ 0.5488135   0.71518937]]
[[ 0.97861834  0.79915856]
 [ 0.0202184   0.83261985]
 [ 0.77815675  0.87001215]
 [ 0.56804456  0.92559664]]
[[ 0.0202184   0.83261985]
 [ 0.0202184   0.83261985]]          # the last batch has 2 samples, while all earlier batches have 4
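
If you want a different shuffle order each epoch without mixing samples across epochs, a common pattern (a sketch, assuming a TF version where shuffle accepts the reshuffle_each_iteration argument) is to keep shuffle before repeat and rely on reshuffle_each_iteration, which re-shuffles the buffer on every pass:

import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
dataset = tf.data.Dataset.from_tensor_slices(x)
# shuffle BEFORE repeat: each epoch is a complete pass over the data;
# reshuffle_each_iteration=True (the default) gives a fresh order per epoch
dataset = dataset.shuffle(11, reshuffle_each_iteration=True)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(el))
    except tf.errors.OutOfRangeError:
        pass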