Tf.data中API使用
1.学习预览
- Dataset基础API使用
- Dataset读取csv文件
- Dataset读取和存储tfrecord文件,tensorflow中自带的文件存储格式,更快
2.使用API列表
- Dataset基础使用
tf.data.Dataset.from_tensor_slices ==> 用来构建Dataset
Data构建后的具体使用方法有 ==> repeat, batch, interleave, map, shuffle, list_files - csv
tf.data.TextLineDataset ==> 用于读取文本文件;tf.io.decode_csv ==>用来解析csv文件 - Tfrecord
tf.train.Floatlist, tf.train.Int64List, tf.train.BytesList
tf.train.Feature, tf.train.Features, tf.train.Example
example.SerializeToString
tf.io.ParseSingleExample
tf.io.VarLenFeature, tf.io.FixedLenFeature
tf.data.TFRecordDataset, tf.io.TFRecordOptions
2.具体使用案列
2.1 Dataset的基础使用
import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
# 从内存中构建数据集
# 传入的参数可以是ndarry,可以是list,甚至是dict
for item in dataset:
print(item)
# 常用操作1. repeat epoch
# 常用操作2. get batch
data = data.repeat(3).batch(7) # 注意是又返回一个新的dataset
for item in data:
print(item)
# interleace: 对现有dataset中每一个item进行处理,产生新的结果,然后interleave再合并起来,形成新的数据集。。有点难懂
data2 = data.inerleave(map_func=lambda x: tf.data.Dataset.from_tensor_slices(x),
cycle_length=5, #并行程度
block_length=5, )
# 但是在分割数据每个batch时,传入的往往是样本和标签
x = np.array([[1., 2.], [3., 4.], [5., 6.]])
y = np.array(['cat', 'dog', 'fox'])
dataset = tf.data.Dataset.from_tensor_slices((x, y)) # 传入的形式唯一,不能是list
dataset2 = tf.data.Dataset.from_tensor_slices({'feature': x, 'lable': y}) # 此时的每一个item都是一个小字典
for item in dataset2:
print(item['feature'].numpy(), item['lable'].numpy())
2.2 csv在tensorflow中的使用
这部分学起来比较吃力,API较多且“复杂“,以至于后面看视频越看越糊涂,以后心情好了再学吧。