Tensorflow 数据对象Dataset.shuffle()、repeat()、batch() 等用法

1.Dataset数据对象

Dataset可以用来表示输入管道元素集合(张量的嵌套结构)和“逻辑计划“对这些元素的转换操作。在Dataset中元素可以是向量,元组或字典等形式。
另外,Dataset需要配合另外一个类Iterator进行使用,Iterator对象是一个迭代器,可以对Dataset中的元素进行迭代提取。

2.Dataset方法

2.1 产生数据集
2.1.1. from_tensor_slices

from_tensor_slices 用于创建dataset,其元素是给定张量的切片的元素。

函数形式:from_tensor_slices(tensors)

参数tensors:张量的嵌套结构,每个都在第0维中具有相同的大小。

import tensorflow as tf
#创建一个Dataset对象
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8,9,10,11,12])
'''合成批次'''
dataset=dataset.batch(5)
#创建一个迭代器
iterator = dataset.make_one_shot_iterator()
#get_next()函数可以帮助我们从迭代器中获取元素
element = iterator.get_next()
 
#遍历迭代器,获取所有元素
with tf.Session() as sess:   
    for i in range(9):
       print(sess.run(element))  

输出

[1 2 3 4 5]
[ 6  7  8  9 10]
[11 12]

2.1.2 .from_tensors

创建一个Dataset包含给定张量的单个元素。

函数形式:from_tensors(tensors)

参数tensors:张量的嵌套结构。

dataset = tf.data.Dataset.from_tensors([1,2,3,4,5,6,7,8,9])
 
iterator = concat_dataset.make_one_shot_iterator()
 
element = iterator.get_next()
 
with tf.Session() as sess:   
    for i in range(1):
       print(sess.run(element))

区别:

  • from_tensors是将tensors作为一个整体进行操纵,而from_tensor_slices可以操纵tensors里面的元素。

2.1.3 from_generator(具体实践不太了解)

创建Dataset由其生成元素的元素generator。

函数形式:from_generator(generator,output_types,output_shapes=None,args=None)

参数generator:一个可调用对象,它返回支持该iter()协议的对象 。如果args未指定,generator则不得参数; 否则它必须采取与有值一样多的参数args。
参数output_types:tf.DType对应于由元素生成的元素的每个组件的对象的嵌套结构generator。
参数output_shapes:tf.TensorShape 对应于由元素生成的元素的每个组件的对象 的嵌套结构generator
参数args:tf.Tensor将被计算并将generator作为NumPy数组参数传递的对象元组。

 

2.2 数据转换Transformation
2.2.1 batch

# 创建0-10的数据集,每个batch取个数6。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

但是如果我们把循环次数设置成3(即for i in range(2)),那么就会报错。

或者将for循环改为:while True:。就不用设置循环次数了。

 

2.2.2 shuffle


上面所有输出结果都是有序的,在机器学习中训练模型需要将数据打乱,这样可以保证每批次训练的时候所用到的数据集是不一样的,可以提高模型训练效果。

注意:shuffle的顺序很重要,应该先shuffle再batch,如果先batch后shuffle的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱(实际并未shuffle)。

随机混洗数据集的元素。

函数形式:shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)

参数buffer_size:表示新数据集将从中采样的数据集中的元素数。

buffer_size=1:不打乱顺序,既保持原序
buffer_size越大,打乱程度越大
参数seed:(可选)表示将用于创建分布的随机种子。
参数reshuffle_each_iteration:(可选)一个布尔值,如果为true,则表示每次迭代时都应对数据集进行伪随机重组。(默认为True。)

在这里buffer_size:该函数的作用就是先构建buffer,大小为buffer_size,然后从dataset中提取数据将它填满。batch操作,从buffer中提取。

如果buffer_size小于Dataset的大小,每次提取buffer中的数据,会再次从Dataset中抽取数据将它填满(当然是之前没有抽过的)。所以一般最好的方式是buffer_size= Dataset_size。

 

2.2.3 map


map可以将map_func函数映射到数据集.
map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,

函数形式:flat_map(map_func,num_parallel_calls=None)

参数map_func:映射函数
参数num_parallel_calls:表示要并行处理的数字元素。如果未指定,将按顺序处理元素。如果使用值tf.data.experimental.AUTOTUNE,则根据可用的CPU动态设置并行调用的数量。

对dataset中每个元素的值加10

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.map(lambda x: x + 10)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

[16 17 18 19]
[10 11 12 13 14 15]


2.2.4 repeat

重复此数据集次数,主要用来处理机器学习中的epoch,假设原先的数据训练一个epoch,使用repeat(2)就可以将之变成2个epoch,默认空是无限次。

函数形式:repeat(count=None)

参数count:(可选)表示数据集应重复的次数。默认行为(如果count是None或-1)是无限期重复的数据集。

 


————————————————
版权声明:本文为CSDN博主「rrr2」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_35608277/article/details/116333888

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值