# TensorFlow 2.0: using the tf.data.Dataset API
import tensorflow as tf
import numpy as np
def GenerateData(data_size=100):
    """Generate noisy samples of the line ``y = 2 * x``.

    Args:
        data_size: Number of evenly spaced x values to take from [-1, 1].

    Returns:
        A ``(train_x, train_y)`` tuple of 1-D NumPy arrays, each of length
        ``data_size``. ``train_y`` is ``2 * train_x`` plus standard-normal
        noise scaled by 0.2.
    """
    train_x = np.linspace(-1, 1, data_size)
    # Echo the number of generated points (the single entry of the shape).
    print(*train_x.shape)
    noise = np.random.randn(*train_x.shape) * 0.2
    train_y = 2 * train_x + noise
    return train_x, train_y
# Build the sample data once with the default size; the result is the
# (train_x, train_y) tuple returned by GenerateData.
train_data = GenerateData()
def get_one(dataset):
    """Return the first element obtained by iterating ``dataset``.

    Args:
        dataset: Any iterable (the script passes ``tf.data.Dataset`` objects).

    Returns:
        The first element yielded, or ``None`` if ``dataset`` yields nothing.
    """
    # A fresh iterator is created on every call, so repeated calls each
    # start from the beginning of the iteration.
    return next(iter(dataset), None)
def show_elment(elment):
    """Print the shape and the values of both halves of an ``(x, y)`` element.

    Args:
        elment: A two-item iterable ``(x, y)``. Each item must expose a
            ``shape`` attribute and a ``numpy()`` method (presumably eager
            TensorFlow tensors, as produced by the datasets in this script).
    """
    x, y = elment
    print("x shape:", x.shape)
    print("x:", x.numpy())
    print("y shape:", y.shape)
    print("y:", y.numpy())
def show_head(dataset, size=5):
    """Print the first ``size`` elements of ``dataset`` via ``show_elment``.

    Args:
        dataset: An object providing ``enumerate()`` that yields
            ``(step, element)`` pairs, such as a ``tf.data.Dataset``.
        size: Maximum number of elements to print. Values of zero or less
            print nothing.
    """
    # Without this guard the loop below would still print one element for
    # size <= 0, because the stop check runs only after printing.
    if size <= 0:
        return
    for step, elment in dataset.enumerate():
        show_elment(elment)
        # step is zero-based, so the size-th element has index size - 1.
        if step >= size - 1:
            break
# --- Demo 1: dataset built directly from the (train_x, train_y) tuple ---
batch_size = 10
dataset_tuple = tf.data.Dataset.from_tensor_slices(train_data)
# Shuffle with a 100-element buffer, then group the samples into batches.
db_tuple = (
    dataset_tuple
    .shuffle(buffer_size=100)
    .batch(batch_size=batch_size)
)
# Pull one batch for inspection, then print the first batch of a new pass.
elment_tuple = get_one(db_tuple)
show_head(db_tuple, size=1)
# --- Demo 2: dataset built from a dict of named columns ---
dataset_dict = tf.data.Dataset.from_tensor_slices(
    {"x": train_data[0], "y": train_data[1]}
)
# Turn each {"x": ..., "y": ...} record back into an (x, y) pair, shuffle
# with a 100-element buffer, repeat indefinitely, and batch.
db_dict = (
    dataset_dict
    .map(lambda sample: (sample["x"], sample["y"]))
    .shuffle(buffer_size=100)
    .repeat()
    .batch(batch_size=batch_size)
)
# Fetch and print one batch, then print the first batch of a new pass.
elment_dict = get_one(db_dict)
show_elment(elment_dict)
show_head(db_dict, size=1)