一文讲解TensorFlow数据接口 tf.data.Dataset

导入数据
X = pd.read_csv('./datasets/housing/housing.csv')
X = X.sample(n=10)
X.drop(columns = X.columns.difference(['longitude']), inplace=True)

为了避免报错,先进行格式转换:

X = np.asarray(X).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(X)
for _ in dataset:
    print(_)
tf.Tensor([-118.75], shape=(1,), dtype=float32)
tf.Tensor([-119.25], shape=(1,), dtype=float32)
tf.Tensor([-118.18], shape=(1,), dtype=float32)
tf.Tensor([-118.13], shape=(1,), dtype=float32)
tf.Tensor([-118.2], shape=(1,), dtype=float32)
tf.Tensor([-117.25], shape=(1,), dtype=float32)
tf.Tensor([-117.93], shape=(1,), dtype=float32)
tf.Tensor([-122.96], shape=(1,), dtype=float32)
tf.Tensor([-121.77], shape=(1,), dtype=float32)
tf.Tensor([-121.24], shape=(1,), dtype=float32)
dataset = dataset.repeat(3).batch(10)
for _ in dataset:
    print(_)

图解:

repeat(3)将数据集重复3次,batch(10)每次输出一个包括10个元素的batch。

tf.Tensor(
[[-118.75]
 [-119.25]
 [-118.18]
 [-118.13]
 [-118.2 ]
 [-117.25]
 [-117.93]
 [-122.96]
 [-121.77]
 [-121.24]], shape=(10, 1), dtype=float32)
tf.Tensor(
[[-118.75]
 [-119.25]
 [-118.18]
 [-118.13]
 [-118.2 ]
 [-117.25]
 [-117.93]
 [-122.96]
 [-121.77]
 [-121.24]], shape=(10, 1), dtype=float32)
tf.Tensor(
[[-118.75]
 [-119.25]
 [-118.18]
 [-118.13]
 [-118.2 ]
 [-117.25]
 [-117.93]
 [-122.96]
 [-121.77]
 [-121.24]], shape=(10, 1), dtype=float32)

如果不能刚好等分,例如

dataset = dataset.repeat(3).batch(9)
for _ in dataset:
    print(_)

最后一个batch将包含剩下的元素

tf.Tensor(
[[-122.08]
 [-121.37]
 [-118.32]
 [-122.38]
 [-122.09]
 [-122.1 ]
 [-122.27]
 [-121.49]
 [-120.68]], shape=(9, 1), dtype=float64)
tf.Tensor(
[[-118.2 ]
 [-122.08]
 [-121.37]
 [-118.32]
 [-122.38]
 [-122.09]
 [-122.1 ]
 [-122.27]
 [-121.49]], shape=(9, 1), dtype=float64)
tf.Tensor(
[[-120.68]
 [-118.2 ]
 [-122.08]
 [-121.37]
 [-118.32]
 [-122.38]
 [-122.09]
 [-122.1 ]
 [-122.27]], shape=(9, 1), dtype=float64)
tf.Tensor(
[[-121.49]
 [-120.68]
 [-118.2 ]], shape=(3, 1), dtype=float64)
map函数
dataset = dataset.map(lambda x: abs(x))
for _ in dataset:
    print(_)
tf.Tensor(
[[118.75]
 [119.25]
 [118.18]
 [118.13]
 [118.2 ]
 [117.25]
 [117.93]
 [122.96]
 [121.77]
 [121.24]], shape=(10, 1), dtype=float32)
tf.Tensor(
[[118.75]
 [119.25]
 [118.18]
 [118.13]
 [118.2 ]
 [117.25]
 [117.93]
 [122.96]
 [121.77]
 [121.24]], shape=(10, 1), dtype=float32)
tf.Tensor(
[[118.75]
 [119.25]
 [118.18]
 [118.13]
 [118.2 ]
 [117.25]
 [117.93]
 [122.96]
 [121.77]
 [121.24]], shape=(10, 1), dtype=float32)
filter函数

使用filter函数前需要先unbatch

dataset = dataset.unbatch()
dataset = dataset.filter(lambda x: x < 120)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值