# 目录
# 导入依赖（import）
import numpy as np
import tensorflow as tf
import datetime
import os
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D
from tensorflow.keras import Model
# 1. 加载数据集并预处理
def load_dataset(path="mnist.npz"):
    """Load train/test splits from a NumPy ``.npz`` archive.

    The image arrays are divided by 255.0, so 8-bit (0-255) pixel
    intensities are mapped to [0, 1]; the label arrays are returned as-is.

    Args:
        path: Location of the ``.npz`` file. It must contain the arrays
            ``x_train``, ``y_train``, ``x_test`` and ``y_test``.
            Defaults to ``"mnist.npz"`` in the current working directory.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)`` of NumPy arrays.

    Raises:
        FileNotFoundError: if *path* does not exist.
        KeyError: if one of the four expected arrays is missing.
    """
    # For .npz files np.load returns an NpzFile that keeps the archive open;
    # the context manager closes it once every array has been read into memory.
    with np.load(path) as mnist:
        return (
            mnist["x_train"] / 255.0,
            mnist["y_train"],
            mnist["x_test"] / 255.0,
            mnist["y_test"],
        )
x_train, y_train, x_test, y_test = load_dataset()
print(x_train.shape)

# Append a trailing channel axis to get channels-last image input (the
# Conv2D default), e.g. (N, H, W) -> (N, H, W, 1). This assumes the stored
# images are single-channel (N, H, W) arrays; TODO confirm.
# Also cast to float32. Dividing integer pixels by 255.0 in load_dataset
# produces float64, which doubles memory, and float32 Keras layers would
# cast it down on every call anyway.
x_train = x_train[..., np.newaxis].astype("float32")
x_test = x_test[..., np.newaxis].astype("float32")

# Labels stay as integer class ids (no one-hot encoding). Presumably the
# training code pairs them with a sparse categorical loss; confirm downstream.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    # With more than 10000 training examples, a 10000-element buffer
    # only approximates a full uniform shuffle.
    .shuffle(10000)
    .batch(32)
)
test_ds = tf.data.Dataset.from_tensor_slices