Tensorflow datasets.shuffle repeat batch方法的妙用,以及batch是否会重复取值

最普通的情况

# 创建0-30的数据集,每个batch取个8数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

输出结果为:

[0 1 2 3 4 5 6 7]
[ 8  9 10 11 12 13 14 15]
[16 17 18 19 20 21 22 23]
[24 25 26 27 28 29]

可以看到,最后一行的输出只有6个数,说明batch方法还是很智能的,并且在这种情况下batch方法取出来的数据并不会存在重复的情况。

如果将迭代次数增加为5次呢?

# 创建0-30的数据集,每个batch取个8数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(5):
        value = sess.run(next_element)
        print(value)

# 输出结果:
# outOfRangeError (see above for traceback): End of sequence

此时可以看到会出现outOfRangeError错误,但是我们有时候又不想去计算每次到底循环多少次才能恰好的用batch方法取完数据。这个时候repeate方法就有用了。

# 创建0-30的数据集,每个batch取个数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8)
dataset=dataset.repeat(count=None)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(5):
        value = sess.run(next_element)
        print(value)

输出结果为:

[0 1 2 3 4 5 6 7]
[ 8  9 10 11 12 13 14 15]
[16 17 18 19 20 21 22 23]
[24 25 26 27 28 29]
[0 1 2 3 4 5 6 7]

可以看到batch方法已经正确的重复取出数据了,repeate方法当中的参数count代表的是最多将数据重复几次,如果设置了count参数,但是循环的次数依然过大,那么还是会报错,所以为了保险起见,最好是不传入参数,或者count=None。

dataset.shuffle方法的使用

shuffle方法用于打乱数据之后取出,它的可设置参数buffer_size代表的是缓冲区的大小。buffer_size的大小越大代表取出的数据混乱程度越高,其中buffer_size=1,代表有序状态,此时和batch方法一样。

# 创建0-30的数据集,每个batch取个数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).shuffle(buffer_size=100).batch(8)
dataset=dataset.repeat(count=None)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(5):
        value = sess.run(next_element)
        print(value)

输出结果:

[ 6  4 11 24 15 26 18 27]
[12 28  5 25  0 19 14 20]
[17 10  7  8  3 21 13  2]
[22 16 29  9 23  1]
[11 12 14  5  3  0 26 22]

注意:shuffle的顺序很重要,一般建议是先执行shuffle方法,接着采用batch方法,这样是为了保证在整体数据打乱之后再取出batch_size大小的数据。如果先采取batch方法再采用shuffle方法,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。示例如下:

将上一段代码的batch方法放到shuffle方法的前面

# 创建0-30的数据集,每个batch取个8数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8).shuffle(buffer_size=100)
dataset=dataset.repeat(count=None)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(5):
        value = sess.run(next_element)
        print(value)

输出:

[16 17 18 19 20 21 22 23]
[ 8  9 10 11 12 13 14 15]
[24 25 26 27 28 29]
[0 1 2 3 4 5 6 7]
[24 25 26 27 28 29]

可以看出,数据依然是有序的,起不到打乱的作用。

也可参考博客:参考博客

### 回答1: tensorflow.keras.datasets.mnist是一个内置的数据集,用于识别手写数字的机器学习任务。该数据集包含了60000张28x28像素的训练图像和10000张测试图像,每张图像都代表一个手写数字(0-9之间)。这个数据集常用于深度学习的图像分类任务。 使用tensorflow.keras.datasets.mnist,可以很方便地加载和使用这个数据集。通过调用load_data()函数,可以将训练和测试数据分别加载到变量。这些数据已经划分好了训练集和测试集,可以直接用于模型的训练和评估。 加载数据后,可以对图像进行预处理和准备,并构建机器学习模型来识别手写数字。通常,经典的深度学习模型,如卷积神经网络(CNN),在这个任务上表现良好。 在训练模型时,可以使用训练集来调整模型的参数,使其可以准确地预测手写数字。训练集的标签提供了每个图像对应的真实数字,可以用于监督学习。 在模型训练完成后,可以使用测试集来评估模型的性能和准确度。测试集的标签提供了每个测试图像的真实数字,可以与模型的预测结果进行比较,从而得到模型的准确率。 总的来说,tensorflow.keras.datasets.mnist提供了一个方便的方式来获取和使用手写数字数据集,可以用于构建和训练机器学习模型,实现手写数字识别任务。 ### 回答2: tensorflow.keras.datasets.mnist是一个常用的数据集,用于机器学习数字识别的训练和测试。该数据集包含了60,000个用于训练的手写数字图像和10,000个用于测试的手写数字图像。 这个数据集可以通过tensorflow.keras.datasets模块的mnist.load_data()函数来加载。这个函数返回两个元组,分别是训练集和测试集。每个元组都包括了两个numpy数组,一个是图像数组,另一个是对应的标签数组。 训练集包括了60,000个28x28像素的灰度图像,用于训练模型。每个图像数组都是一个形状为(28, 28)的二维numpy数组,表示一个手写数字图像。对应的标签数组是一个形状为(60000,)的一维numpy数组,包含了0到9之间的整数,表示了对应图像的真实数字。 测试集包括了10,000个用于测试模型的手写数字图像,和训练集相似,每个图像数组是一个形状为(28, 28)的二维numpy数组。对应的标签数组是一个形状为(10000,)的一维numpy数组,包含了0到9之间的整数,表示了对应图像的真实数字。 使用这个数据集可以帮助我们训练和评估模型的性能,比如使用卷积神经网络对手写数字进行分类。加载mnist数据集并将其拆分为训练集和测试集后,我们可以使用这些数据来训练模型,并使用测试集来评估模型在未见过的数据上的表现。 总之,tensorflow.keras.datasets.mnist提供了一个方便且广泛使用的手写数字识别数据集,供机器学习研究和实践使用。 ### 回答3: tensorflow.keras.datasets.mnist是一个常用的数据集,用于机器学习领域的手写数字识别任务。该数据集包含了60000张28x28像素的训练图像和10000张测试图像。 这个数据集可以通过以下代码导入: ``` (train_images, train_labels), (test_images, test_labels) = tensorflow.keras.datasets.mnist.load_data() ``` 其train_images和train_labels是训练图像和对应的标签,test_images和test_labels是测试图像和对应的标签。 train_images和test_images都是三维数组,表示图像的像素值。每张图像都由28行28列的像素组成,像素值范围为0-255。 train_labels和test_labels是一维数组,表示图像对应的真实数字标签。标签范围为0-9,分别表示数字0到9。 加载完数据集后,我们可以进行数据预处理,例如将像素值缩放到0-1之间: ``` train_images = train_images / 255.0 test_images = test_images / 255.0 ``` 然后可以使用这些数据来训练机器学习模型,例如使用卷积神经网络进行手写数字识别的训练: ``` model = tensorflow.keras.models.Sequential([ tensorflow.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)), tensorflow.keras.layers.MaxPooling2D((2, 2)), tensorflow.keras.layers.Flatten(), tensorflow.keras.layers.Dense(64, activation='relu'), tensorflow.keras.layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(train_images, train_labels, epochs=10) ``` 通过这个数据集和训练示例,我们可以建立一个手写数字识别模型,并用测试集进行评估和预测。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值