tf.data.Dataset.from_tensor_slices()函数的参数是tensor。
该函数的作用是接收tensor,对tensor的第一维度进行切分,并返回一个表示该tensor的切片数据集
以minist训练集为例:
x的shape为(60000,28,28),将x作为参数传递给tf.data.Dataset.from_tensor_slices(),
将返回一个含有60000个切片的数据集,每个切片为 28*28 的图像(但数据集不知道里面有多少个切片)。
代码如下:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
(x, y), _ = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32)/255.
print(x.shape)
train_db = tf.data.Dataset.from_tensor_slices(x)
print(train_db)
输出的结果为: