tensorflow快速入门文档地址:
https://tensorflow.google.cn/tutorials/quickstart/advanced
下面是阅读文档过程中的记录:
1、数据预处理
x_train = x_train[..., tf.newaxis]
tf.newaxis 的 作用与 np.newaxis 的作用相同,都是添加张量的维度。
参考:https://blog.csdn.net/u013841196/article/details/84260631
2、将数据集切分为 batch 以及混淆数据集:
参考: https://tensorflow.google.cn/api_docs/python/tf/data/Dataset
train_ds = tf.data.Dataset.from_tensor_slices( (x_train, y_train)).shuffle(10000).batch(32)
2.1基本用法:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]) for element in dataset: print(element)
代码运行结果:
2.2设置batch:
dataset = tf.data.Dataset.range(8) dataset = dataset.batch(3) list(dataset.as_numpy_iterator())
代码运行结果:
2.3乱序:
dataset = tf.data.Dataset.range(3) dataset = dataset.shuffle(3, reshuffle_each_iteration=True) dataset = dataset.repeat(2) # doctest: +SKIP
代码运行结果: