Dataset 使用基础
这一部分主要是对Tensorflow中Dataset的使用基础进行描述
Dataset语法
假设m,… = data.shape
tf.data.Dataset.from_tensor_slices(data)这里from_tensor_slices主要是将data进行切片,将data分为m个矩阵,注意这里分割为m个矩阵只和data的第一维数据相关。
tf.data.Dataset.from_tensor_slices(data).batch(num)这里的batch(num)是指通过切片之后的数据分为num个一组
测试样例如下:
import tensorflow as tf
import numpy as np
data = np.random.randint(0,10, (6, 6, 3))
dataset = tf.data.Dataset.from_tensor_slices(data)
print("Original Data: ")
print(dataset)
print("---------------------------")
print("After slices")
for _ in range(2):
for record, data in enumerate(dataset):
print(record, "------", data)
print("---------------------------")
print("After batch")
dataset = tf.data.Dataset.from_tensor_slices(data).batch(2)
for record, data in enumerate(dataset):
print(record, "-----", data)
print("-----")
data可以是矩阵,也可以是元组(array1, array2),也可以是字典{“a”: array1, “b”: array2}。这里array1, array2代表的是矩阵,当然也可以是其他形式的数据,只不过一般所有数据都可以采用矩阵的形式进行处理。
Dataset中还有一些API用于数据转换、打乱、组成batch、生成epoch,主要的有以下4种:
map、batch、 shuffle、 repeat
map主要用于给Dataset中的每个数据执行函数, 即 如果想给Dataset中的数据执行某些改变操作,就可以先写一个函数,函数的参数是Data,通过map函数调用在Dataset中执行函数.
batch主要用于给Dataset中的数据进行打包,通过batch(num)将数据分为num个一组。
shuffle主要用于打乱Dataset中数据的顺序,shuffle(buffer_size), buffer_size代表打乱时使用的buffer(缓冲区)大小,shuffle不会改变原来的batch中的数据的顺序。需要说的是这里的打乱顺序的理论是将按原始dataset中顺序放buffer_size个batch放入buffer中,随机选取一个batch出缓冲区(确定了当前batch的在新dataset中的顺序,由原始dataset中的batch按顺序填入buffer中,一直进行下去,就能得到shuffle后的新Dataset.
repeat主要用于重复Dataset中数据及的数目,repeat(num)
测试样例如下:
‘’’
import tensorflow as tf
import numpy as np
data = np.random.randint(0, 10, (4, 3))
dataset = tf.data.Dataset.from_tensor_slices(data)
print("--------------------------------------------------")
print("Original Data:")
for record, data in enumerate(dataset):
print(record, "-----", data)
def test_map(data):
data = data * 10
return data
dataset = dataset.map(test_map)
print("--------------------------------------------------")
print("MAP test Data:")
for record, data in enumerate(dataset):
print(record, "-----", data)
print("--------------------------------------------------")
print("Batch test Data:")
dataset = dataset.batch(2)
for record, data in enumerate(dataset):
print(record, "-----", data)
print("--------------------------------------------------")
print("repeat test Data:")
dataset = dataset.repeat(2)
for record, data in enumerate(dataset):
print(record, "-----", data)
print("--------------------------------------------------")
print("shuffle test Data:")
dataset = dataset.shuffle(2)
for record, data in enumerate(dataset):
print(record, "-----", data)