tensorflow2中的tf.data.Dataset.from_tensor_slices()

tf.data.Dataset.from_tensor_slices()函数的参数是tensor。
该函数的作用是接收tensor,对tensor的第一维度进行切分,并返回一个表示该tensor的切片数据集
以minist训练集为例:
x的shape为(60000,28,28),将x作为参数传递给tf.data.Dataset.from_tensor_slices(),
将返回一个含有60000个切片的数据集,每个切片为 28*28 的图像(但数据集不知道里面有多少个切片)。
代码如下:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets

(x, y),  _ = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32)/255.
print(x.shape)

train_db = tf.data.Dataset.from_tensor_slices(x)
print(train_db)

输出的结果为:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值