Working with Classic Datasets

Operation | Code |
---|---|
Load a dataset | datasets.Dataset_name.load_data() |
Build a Dataset object | tf.data.Dataset.from_tensor_slices((x, y)) |
Shuffle | Dataset_name.shuffle(buffer_size) |
Batch | Dataset_name.batch(size) |
Preprocess | Dataset_name.map(func_name) |

Dataset (Dataset_name) | Type |
---|---|
Boston housing | Boston house prices (regression) |
CIFAR10/100 | Color image dataset |
MNIST/Fashion_MNIST | Handwritten digits / clothing item images |
IMDB | Text classification |
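Loading a dataset
Datasets are cached in the .keras/datasets folder under the user's home directory (loaded from the cache if present, downloaded automatically otherwise).

    import tensorflow as tf
    from tensorflow.keras import datasets

    # load_data() returns the training split and the test split
    (x, y), (x_test, y_test) = datasets.mnist.load_data()
    print(x.shape)
    print(y.shape)
    print(x_test.shape)
    print(y_test.shape)

Output:

    (60000, 28, 28)
    (60000,)
    (10000, 28, 28)
    (10000,)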
Once the data is loaded into memory as arrays, it has to be converted into a Dataset object before the operations that TensorFlow provides for datasets can be used.

    import tensorflow as tf
    from tensorflow.keras import datasets

    (x, y), (x_test, y_test) = datasets.mnist.load_data()
    print(x.shape)
    print(y.shape)
    print(x_test.shape)
    print(y_test.shape)
    # Slice along the first axis: one (image, label) pair per element
    train_db = tf.data.Dataset.from_tensor_slices((x, y))
    print(train_db)

Output:

    (60000, 28, 28)
    (60000,)
    (10000, 28, 28)
    (10000,)
    <TensorSliceDataset shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>
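To make the "one (image, label) pair per element" behaviour concrete, here is a minimal sketch. It uses small synthetic arrays as a stand-in for MNIST (an assumption, so the snippet runs without a download); the names `x_demo`, `y_demo` and `demo_db` are illustrative only.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for MNIST (assumption): 6 blank 28x28 "images" and 6 labels.
x_demo = np.zeros((6, 28, 28), dtype=np.uint8)
y_demo = np.arange(6, dtype=np.uint8)

demo_db = tf.data.Dataset.from_tensor_slices((x_demo, y_demo))

# The arrays are sliced along their first axis, so the dataset
# holds 6 elements, each an (image, label) pair.
num_elements = sum(1 for _ in demo_db)

# take(1) yields only the first element of the dataset.
for image, label in demo_db.take(1):
    first_image_shape = tuple(image.shape)
    first_label = int(label.numpy())

print(num_elements, first_image_shape, first_label)
```

The first axis of the input arrays becomes the sample axis of the dataset, which is why the element shape printed for MNIST is (28, 28) rather than (60000, 28, 28).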
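Shuffling
- Dataset_name.shuffle(buffer_size)
- buffer_size is the size of the shuffle buffer; set it to a fairly large constant. Elements are drawn at random from this buffer, so a fully uniform shuffle needs buffer_size to be at least the number of samples.

Example:

    import tensorflow as tf
    from tensorflow.keras import datasets

    (x, y), (x_test, y_test) = datasets.mnist.load_data()
    print(x.shape)
    print(y.shape)
    print(x_test.shape)
    print(y_test.shape)
    train_db = tf.data.Dataset.from_tensor_slices((x, y))
    # Shuffle with a buffer of 500 elements
    td = train_db.shuffle(500)
    print(td)

Output:

    (60000, 28, 28)
    (60000,)
    (10000, 28, 28)
    (10000,)
    <ShuffleDataset shapes: ((28, 28), ()), types: (tf.uint8, tf.uint8)>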
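Why the buffer size matters is easier to see with a toy model. The sketch below is a pure-Python simulation of a shuffle buffer, not TensorFlow code; `buffered_shuffle` is a hypothetical helper written for illustration, and it only approximates what `shuffle` does internally (fill a buffer, emit a random buffered element, refill the freed slot).

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Toy model of a shuffle buffer (illustration only, not TensorFlow)."""
    rng = random.Random(seed)
    source = iter(items)
    buffer = []
    output = []
    # Fill the buffer with the first buffer_size elements.
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    # Emit a random buffered element, then refill its slot from the source.
    for item in source:
        idx = rng.randrange(len(buffer))
        output.append(buffer[idx])
        buffer[idx] = item
    # Source exhausted: emit what is left in random order.
    rng.shuffle(buffer)
    output.extend(buffer)
    return output

data = list(range(20))
no_shuffle = buffered_shuffle(data, buffer_size=1)        # order unchanged
local = buffered_shuffle(data, buffer_size=4)             # only nearby elements mix
full = buffered_shuffle(data, buffer_size=len(data))      # uniform shuffle
print(no_shuffle == data)
print(local)
print(full)
```

With a buffer of size B, the element emitted at output position k can only come from the first k + B input elements, so a small buffer mixes samples only locally. That is the reason for choosing a large buffer_size.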
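Batching
- Dataset_name.batch(size)
- Computing several samples in parallel in a single step is called batch training. size is the number of samples per batch; choose it according to the capacity of the GPU where possible.

Example:

    import tensorflow as tf
    from tensorflow.keras import datasets

    (x, y), (x_test, y_test) = datasets.mnist.load_data()
    train_db = tf.data.Dataset.from_tensor_slices((x, y))
    # Group 100 consecutive samples into one batch
    train_db = train_db.batch(100)
    print(train_db)

Output:

    <BatchDataset shapes: ((None, 28, 28), (None,)), types: (tf.uint8, tf.uint8)>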
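The `None` in the batch dimension above reflects that the last batch may be smaller than the others. A minimal sketch of this, using synthetic stand-in data (an assumption, chosen so the sample count is not a multiple of the batch size) and the `drop_remainder` argument of `batch`:

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in data (assumption): 250 samples, so the last batch is partial.
x_demo = np.zeros((250, 28, 28), dtype=np.uint8)
y_demo = np.zeros((250,), dtype=np.uint8)
demo_db = tf.data.Dataset.from_tensor_slices((x_demo, y_demo))

# Default behaviour: the final batch keeps the 50 leftover samples.
sizes_keep = [int(xb.shape[0]) for xb, yb in demo_db.batch(100)]

# drop_remainder=True discards the partial batch, so every batch has 100 samples.
sizes_drop = [int(xb.shape[0]) for xb, yb in demo_db.batch(100, drop_remainder=True)]

print(sizes_keep)  # [100, 100, 50]
print(sizes_drop)  # [100, 100]
```

MNIST's 60000 training samples divide evenly by 100, so the difference does not show up in the example above, but it does for batch sizes that do not divide the dataset size.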
Preprocessing
- Dataset_name.map(func_name)

Example:

    import tensorflow as tf
    from tensorflow.keras import datasets

    (x, y), (x_test, y_test) = datasets.mnist.load_data()
    train_db = tf.data.Dataset.from_tensor_slices((x, y))

    def func_name(x, y):
        # Scale pixel values from [0, 255] to [0, 1]
        x = tf.cast(x, dtype=tf.float32) / 255.
        # Applied per sample here, so each image becomes shape (1, 784)
        x = tf.reshape(x, [-1, 28 * 28])
        # Convert the integer label to a one-hot vector of length 10
        y = tf.cast(y, dtype=tf.int32)
        y = tf.one_hot(y, depth=10)
        return x, y

    train_db = train_db.map(func_name)
    print(train_db)

Output:

    <MapDataset shapes: ((1, 784), (10,)), types: (tf.float32, tf.float32)>
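The operations in this section are usually chained. The sketch below is one possible ordering (map, then shuffle, then batch), not the only valid one. It runs on synthetic stand-in data (an assumption) and uses a variant of the preprocessing function, here called `preprocess`, that flattens each image to shape (784,) instead of (1, 784) so that batches come out as (batch, 784).

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for MNIST (assumption): 200 random images and labels.
x_demo = np.random.randint(0, 256, size=(200, 28, 28), dtype=np.uint8)
y_demo = np.random.randint(0, 10, size=(200,), dtype=np.uint8)

def preprocess(x, y):
    # Scale pixels to [0, 1] and flatten one image to a 784-vector.
    x = tf.cast(x, tf.float32) / 255.0
    x = tf.reshape(x, [28 * 28])
    # One-hot encode the label into 10 classes.
    y = tf.one_hot(tf.cast(y, tf.int32), depth=10)
    return x, y

pipeline_db = (
    tf.data.Dataset.from_tensor_slices((x_demo, y_demo))
    .map(preprocess)   # per-sample preprocessing
    .shuffle(200)      # buffer covers the whole (small) dataset
    .batch(50)         # 4 batches of 50 samples
)

# Pull one batch to inspect the resulting shapes and value range.
xb, yb = next(iter(pipeline_db))
batch_x_shape = tuple(xb.shape)
batch_y_shape = tuple(yb.shape)
x_min = float(tf.reduce_min(xb))
x_max = float(tf.reduce_max(xb))
label_row_sums = tf.reduce_sum(yb, axis=1).numpy()

print(batch_x_shape, batch_y_shape)
```

Mapping before batching keeps the function simple because it sees one sample at a time; a function written for whole batches could instead be mapped after `batch`.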
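Training loops
- Iterate over the dataset once, with a step counter:

      for step, (x, y) in enumerate(train_db):
          ...  # one training step on the batch (x, y)

- Iterate over the dataset once, without a step counter:

      for x, y in train_db:
          ...  # one training step on the batch (x, y)

- Train for several epochs with a nested loop:

      for epoch in range(20):
          for step, (x, y) in enumerate(train_db):
              ...  # one training step on the batch (x, y)

- Or let the dataset repeat itself, so that a single loop over it covers 20 epochs:

      train_db = train_db.repeat(20)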
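The nested epoch loop and `repeat` should visit the same number of batches. A minimal sketch that checks this on synthetic stand-in data (an assumption), with 2 epochs instead of 20 to keep it quick; the counters `steps_nested` and `steps_repeat` are illustrative names:

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in data (assumption): 250 samples in batches of 100 -> 3 batches per epoch.
x_demo = np.zeros((250, 28, 28), dtype=np.uint8)
y_demo = np.zeros((250,), dtype=np.uint8)
base_db = tf.data.Dataset.from_tensor_slices((x_demo, y_demo)).batch(100)

# Style 1: explicit epoch loop around the batch loop.
steps_nested = 0
for epoch in range(2):
    for step, (x, y) in enumerate(base_db):
        steps_nested += 1  # a real loop would run one training step here

# Style 2: repeat(2) makes a single loop run through the data twice.
steps_repeat = 0
for step, (x, y) in enumerate(base_db.repeat(2)):
    steps_repeat += 1  # step keeps counting across epochs

print(steps_nested, steps_repeat)  # 6 6
```

With `repeat`, the step index runs continuously across epochs, so per-epoch bookkeeping (such as evaluating after each pass) has to be derived from the step count; the nested form gives the epoch index directly.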