【tensorflow2.0】tf.data输入模块实例

准备工作:

import tensorflow as tf
#加载手写数字数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
#将数据集归一化
train_images = train_images / 255
test_images = test_images / 255

创建Dataset:

ds_train_img = tf.data.Dataset.from_tensor_slices(train_images)
ds_train_lab = tf.data.Dataset.from_tensor_slices(train_labels)

当显示shapes:()时,说明该数据为一个数字。

让创建的两个Dataset对应起来:

ds_train = tf.data.Dataset.zip((ds_train_img, ds_train_lab))

以元组的形式合并在一起,所以zip()函数里面还得再加一个括号。
此时,ds_train为一个ZipDataset,它的形状为((28, 28), ()),前面为一个28×28的图片,对应于train_images;后面是一个单独的数字,对应于train_labels。

然后再对数据统一做变换,就不用担心变换之后数据和标签不对应了。

ds_train = ds_train.shuffle(10000).repeat().batch(64)

取10000个组件进行乱序,然后无限重复,希望batch size为64,即每次输出64张图片以及对应的标签。

建立模型:

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input=(28,28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

编译模型:

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

训练数据:
因为ds_train数据既包括图片,又包括对应的标签。所以直接训练ds_train数据即可。

model.fit(ds_train, epochs=5, steps_per_epoch = steps_per_epoch)

因为本数据是无限循环的,所以得告诉它循环多少次算是一个epoch。因为每次迭代是64张图片,一共有60000张图片,所以在此代码前加一段代码steps_per_epoch = train_images.shape[0] // 64,取整(因为这个数必须是整数)。

添加验证数据:
首先建立test的数据集

ds_test = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

对model进行预测时,乱序变换并没有用处,训练时才有用。预测时也不需要重复变换,因为预测时默认是无限循环的。仅batch变换对预测有用。下面对数据进行batch变换设置

ds_test = ds_test.batch(64)

然后

model.fit(ds_train, epochs=5,steps_per_epoch = steps_per_epochs, 
          validation_data = ds_test, validation_steps =  10000//64)

设置了validation_data = ds_test, validation_steps = 10000//64,其中10000为ds_test的大小,64为之前设置的batch大小。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值