# tensorflow2.0 初学者
import tensorflow as tf
import numpy as np
# Load the MNIST dataset as (train, test) splits of images and labels.
mnist = tf.keras.datasets.mnist
(xtrain, ytrain), (xtest, ytest) = mnist.load_data()
print(xtrain.shape, ytrain.shape, xtest.shape, ytest.shape)

# Normalize pixel values by dividing by 255 so they fall in [0, 1].
xtrain, xtest = xtrain / 255.0, xtest / 255.0

# Build the model by stacking layers: flatten each 28x28 input, one hidden
# ReLU layer, dropout for regularization, and a 10-way softmax output.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# sparse_categorical_crossentropy takes integer class labels directly
# (no one-hot encoding needed).
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Checkpointing: save only the weights; the callback fills in {epoch:04d},
# producing files such as cp-0002.ckpt.
check_path = 'my_mnist_checkpoint/ckpt/cp-{epoch:04d}.ckpt'
# period=2 -> write a checkpoint every 2 epochs.
# NOTE(review): `period` is deprecated in later TF releases in favor of
# `save_freq`, whose integer meaning (samples vs. batches) changed between
# 2.x versions — confirm against the installed TensorFlow before migrating.
save_mode_cp = tf.keras.callbacks.ModelCheckpoint(
    check_path, verbose=1, save_weights_only=True, period=2)

# Train on the TRAINING split so the test split stays unseen and can give
# an honest estimate of generalization below.
model.fit(xtrain, ytrain, epochs=5, callbacks=[save_mode_cp])

# Measure loss/accuracy on the held-out test split.
model.evaluate(xtest, ytest, verbose=2)
# SDUWH2019-2020寒假python实训--my_tf_1
# 最新推荐文章于 2024-05-20 22:54:06 发布