(13)tensorflow数据集操作

经典数据集操作

功能函数代码
加载数据集datasets.Dataset_name.load_data()
构建 Dataset 对象tf.data.Dataset_name.from_tensor_slices((x, y))
随机打散Dataset_name.shuffle(buffer_size)
批训练Dataset_name.batch(size)
数据预处理Dataset_name.map(func_name)
数据集Datatset_name类型
Boston housing波士顿房价趋势
CIFAR10/100图片数据集
MNIST/Fashion_MNIST手写数字
IMDB文本分类

数据集缓存在用户目录下的.keras/datasets 文件夹

加载数据集

数据集缓存在用户目录下的.keras/datasets 文件夹(有则加载,无则自动下载)

import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
print(x.shape)
print(y.shape)
print(x_text.shape)
print(y_text.shape)

out:
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)

数据加载进入内存后,需要转换成 Dataset 对象, 才能利用 TensorFlow 提供的各种操作

import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
print(x.shape)
print(y.shape)
print(x_text.shape)
print(y_text.shape)
train_db = tf.data.Dataset.from_tensor_slices((x, y))
print(train_db)

out:
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)
<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>

随机打散

  • Dataset_name.shuffle(buffer_size)
  • buffer_size为缓冲池大小,设置一个较大常数
import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
print(x.shape)
print(y.shape)
print(x_text.shape)
print(y_text.shape)
train_db = tf.data.Dataset.from_tensor_slices((x, y))
td = train_db.shuffle(500)
print(td)

out:
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)
<ShuffleDataset shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>

批训练

  • Dataset_name.batch(size)
  • 同时并行计算多个样本为批训练,size即为并行计算数目,尽量根据显卡性能配置
import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.batch(100)
print(train_db)

out:
<BatchDataset shapes: ((None, 28, 28), (None,)), types: (tf.uint8, tf.uint8)>

预处理

  • Dataset_name.map(func_name)
import tensorflow as tf
from tensorflow.keras import datasets
(x,y),(x_text,y_text) =  datasets.mnist.load_data()
train_db = tf.data.Dataset.from_tensor_slices((x, y))
def func_name(x,y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [-1, 28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x , y
train_db = train_db.map(func_name)
print(train_db)

out:
<MapDataset shapes: ((1, 784), (10,)), types: (tf.float32, tf.float32)>

循环训练

  •   for step, (x,y) in enumerate(train_db):
    
  •   for x,y in train_db:
    
  •   for epoch in range(20):
      	for step, (x,y) in enumerate(train_db):
    
  •   train_db = train_db.repeat(20)
    
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小蜗笔记

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值