看到tf好多预处理数据的时候都会使用dataset类及其一些方法, 现在解释说明如下:
这里先举一个连续处理的小案例, 然后一步步说明:
创建dataset方法很多例如:
这里选择一个简单易懂的方法
ds_train = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3, 4, 5, 6]) \
.map(add) \
.shuffle(buffer_size=9)\
.batch(100) \
.prefetch(-1)\
.repeat(1)
list(ds_train.unbatch())
结果:
下面一步步讲解:先定义处理函数, 这里简单地定义为加1了
def add(x):
return x + 1
第一步:创建数据集, 根据自己的数据采用合适的方法
d1 = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9])
for element in d1:
print(element)
第二步: 进行预处理操作
d2 = d1.map(add)
for num in d2:
print(num)
第三步:打乱顺序, 这里的参数适当的大一些, 当取值为1 时 并没有打乱顺序
d3 = d2.shuffle(buffer_size=1)
for nums in d3:
print(nums)
增大buffer_size的值
第四步: 确定batch的大小:
d4 = d3.batch(2)
for nums in d4:
print(nums)
第五步: prefetch()这步是在内部进行处理需要的数据, 是随着使用实时生成的, 这样提高效率参数可以参考cpu的数量, 当我们两次取值是不一样的组合
d5 = d4.prefetch(-1)
for nums in d5:
print(nums)
print("=============")
for nums in d5:
print(nums)
第六步:repeat()重复
d6 = d5.repeat(3)
for nums in d6:
print(nums)
结果: