前言
这个是我学习网易云课堂上对于tensorflow2.0数据加载内容的笔记
小型的经典数据通过keras.datasets就可以使用
keras.datasets可以加载的数据集
- boston houseing
- mnist/fashion mnist
- cifar10/100
- imdb
tf.data.Dataset.from_tensor_slices
切分传入的 Tensor 的第一个维度,生成相应的 dataset。可以用迭代器来取出dataset
a = tf.random.normal([28,28,28])
aa = tf.data.Dataset.from_tensor_slices(a)
print(next(iter(aa)).shape)
b = tf.random.normal([25,27,28,28])
bb = tf.data.Dataset.from_tensor_slices(b)
print(next(iter(bb)).shape)
c = tf.random.normal([28,2])
cc = tf.data.Dataset.from_tensor_slices(c)
print(next(iter(cc)).shape)
##########输出#################
(28, 28)
(27, 28, 28)
(2,)
如果是tf.data.Dataset.from_tensor_slices((x,y)),那么返回的会是元组,这样可以做到图片与label相对应(假设x是传入的图像,y是label),可以用next(iter(db))[0].shape来查看形状
db = tf.data.Dataset.from_tensor_slices((x,y))
print(next(iter(db))[0].shape)
print(next(iter(db))[1].shape)
###########输出#########
(28, 28)
()
shuffle
原理可以看这个嘞大佬的博客
打乱数据顺序,可以防止过拟合这样子,在训练数据时用。参数就给一个比较大的就好了。
buffer_size = 1 数据集不会被打乱
buffer_size = 数据集样本数量,随机打乱整个数据集
buffer_size > 数据 集样本数量,随机打乱整个数据集
db = db.shuffle(buffer_size )
map
对数据进行预处理
def preproess(x,y):
#tf.cast()函数的作用是执行 tensorflow 中张量数据类型转换,比如读入的图片如果是int8类型的,一般在要在训练前把图像的数据格式转换为float32。
x = tf.cast(x,dtypt=float32) / 255.
y = tf.cast(y,dtype=int32)
y = tf.one_hot(y,depth=10)
return x,y
db2 = db.map(preproess)
batch
#用于迭代器每次取出多少图
db3 = db2.batch(32)
res = next(iter(db3))
repeat
就是重复操作,防止用whie True时出错
#一直迭代
db4 = db3.repeat()
#迭代全部的数据4次退出
db4 = db3.repeat(4)
######暂时就这些了#######19.7.18