import tensorflow as tf
from tensorflow import keras
def preprocess(x, y):
    """Cast one (image, label) pair to the dtypes used by the pipeline.

    Args:
        x: Image data. Cast to ``tf.float32`` and divided by 255.
        y: Label data. Cast to ``tf.int32``.

    Returns:
        A tuple ``(x, y)`` holding the scaled float image and the
        integer label.
    """
    # Assumes 8-bit pixel values (0-255), so the division scales them
    # into [0, 1] -- confirm against the dataset being mapped.
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    return x, y
# Load MNIST, already split into training and test sets.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Training pipeline: slice into per-sample elements, apply preprocess,
# shuffle with a 10000-element buffer, then group into batches of 128.
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess)
train_db = train_db.shuffle(10000)
train_db = train_db.batch(128)

# Test pipeline: the same stages applied to the test split.
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess)
test_db = test_db.shuffle(10000)
test_db = test_db.batch(128)

# Inspect the training dataset: print the shape of the image tensor
# from the first batch.
train_iter = iter(train_db)
print(next(train_iter)[0].shape)
# TensorFlow 2.0: working with common datasets
# (Web-page footer) Latest recommended article published 2024-07-12 15:12:34