5.数据构建
tf.data简介
面对一堆格式不一的原始数据文件 ?
读入程序的过程往往十分繁琐 ?
运行的效率上不尽如人意 ?
T e n so r F l ow 提供了 tf.data 这一模块,包括了一套灵活的数据集构建 API,能够帮助我们快速 、高效地构建数据输入的流水线 ,尤其适用于数据量巨大的场景。
tf.data包含三个类:
• tf.data.Dataset类
• tf.data.TFRecordDataset类
• tf.data.TextLineDataset类
5.1.Dataset类
tf.data 的核心是 tf.data.Dataset 类,提供了对数据集的高层封装。
tf.data.Dataset 由一系列的可迭代访问的元素(element)组成,每个元素包含一个或多个张量。 Dataset可以看作是相同类型“元素”的有序列表。
注: Dataset可以看作是相同类型“元素”的有序列表。在实际使用时,单个“元素”
可以是向量,也可以是字符串、图片,甚至是tuple或者dict。
比如说,对于一个由图像组成的数据集,每个元素可以是一个形状为 长×宽×通道数的图片张量,也可以是由图片张量和图片标签张量组成的元组(Tuple)。
常用创建tf.data.Dataset数据集的方法有:
tf.data.Dataset.from_tensors() :创建Dataset对象,返回具有单个元素的数据集。
tf.data.Dataset.from_tensor_slices() :创建一个Dataset对象,但是会将第0维切分
tf.data.Dataset. from_generator() :迭代生成所需的数据集,一般数据量较大时使用。
from_tensors() 函数会把传入的tensor当做一个元素,但是from_tensor_slices() 会把传入的tensor除开第0维之后的大小当做一个元素
最基础的建立 tf.data.Dataset 的方法是使用 tf.data.Dataset.from_tensor_slices() ,适用于数据量较小(能够整个装进内存)的情况
具体而言,如果我们的数据集中的所有元素通过张量的第 0 维,拼接成一个大的张量(例如,前节的 MNIST 数据集的训练集即为一个 [60000, 28, 28, 1] 的张量,表示了 60000 张 28*28 的单通道灰度图像),那么我们提供一个这样的张量或者第 0 维大小相同的多个张量作为输入,即可按张量的第 0 维展开来构建数据集,数据集的元素数量为张量第 0 位的大小。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
mnist = np.load("mnist.npz")
x_train, y_train = mnist['x_train'],mnist['y_train']
x_train.shape,y_train.shape
((60000, 28, 28), (60000,))
x_train = np.expand_dims(x_train, axis=-1)
mnist_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
for image, label in mnist_dataset:
plt.title(label.numpy())
plt.imshow(image.numpy()[:, :,0])
plt.show()
break
Pandas数据读取
import pandas as pd
df = pd.read_csv('heart.csv')
df.head()
|
age |
sex |
cp |
trestbps |
chol |
fbs |
restecg |
thalach |
exang |
oldpeak |
slope |
ca |
thal |
target |
0 |
63 |
1 |
1 |
145 |
233 |
1 |
2 |
150 |
0 |
2.3 |
3 |
0 |
fixed |
0 |
1 |
67 |
1 |
4 |
160 |
286 |
0 |
2 |
108 |
1 |
1.5 |
2 |
3 |
normal |
1 |
2 |
67 |
1 |
4 |
120 |
229 |
0 |
2 |
129 |
1 |
2.6 |
2 |
2 |
reversible |
0 |
3 |
37 |
1 |
3 |
130 |
250 |
0 |
0 |
187 |
0 |
3.5 |
3 |
0 |
normal |
0 |
4 |
41 |
0 |
2 |
130 |
204 |
0 |
2 |
172 |
0 |
1.4 |
1 |
0 |
normal |
0 |
df.dtypes
age int64
sex int64
cp int64
trestbps int64
chol int64
fbs int64
restecg int64
thalach int64
exang int64
oldpeak float64
slope int64
ca int64
thal object
target int64
dtype: object
df['thal'] = pd.Categorical(df['thal']).codes
df.head()
|
age |
sex |
cp |
trestbps |
chol |
fbs |
restecg |
thalach |
exang |
oldpeak |
slope |
ca |
thal |
target |
0 |
63 |
1 |
1 |
145 |
233 |
1 |
2 |
150 |
0 |
2.3 |
3 |
0 |
2 |
0 |
1 </ |