本文主要使用jupyter notebook进行模型的搭建与训练。
1.导入包
import tensorflow as tf
from tensorflow.keras.layers import Dense,Flatten
from tensorflow.keras import Model
# Set the Keras backend's default float type to float64.
# NOTE(review): presumably chosen so layer dtypes match the result of the
# `x / 255.0` scaling done on the image arrays below — confirm.
tf.keras.backend.set_floatx('float64')
2.下载手写数字数据集
# Fetch the MNIST train/test splits through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide the images by 255.0 (presumably mapping 0-255 pixel intensities
# into [0, 1] — confirm against the dataset's dtype).
x_train = x_train / 255.0
x_test = x_test / 255.0

# Bare expressions: a notebook shows the value only for the last expression
# of a cell; they have no effect when run as a plain script.
x_train.shape
x_test.shape
3.生成Dataset
# Training pipeline: slice the arrays into (image, label) pairs, shuffle with
# a 10000-element buffer, then group into batches of 32.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(32)
)

# Evaluation pipeline: same slicing and batch size, but no shuffling.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(32)
)
4.使用子类化(Subclassing)API搭建模型
class MyModel(Model):
    """Small fully connected classifier built with the Keras subclassing API.

    The network flattens its input, applies a 128-unit ReLU hidden layer and
    finishes with a 10-unit softmax layer, one unit per digit class (0-9).
    """

    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        # The handwritten-digit task has exactly 10 classes; the previous
        # value of 110 added 100 output units that no label can ever select.
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        """Run the forward pass and return per-class probabilities.

        Args:
            x: A batch of input images.

        Returns:
            The softmax output of the final dense layer.
        """
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


model = MyModel()
5.选择优化器与损失函数
# Loss function: sparse categorical cross-entropy, called below as
# loss_object(labels, predictions).
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

# Optimizer used to update the model parameters (Adam, default settings).
optimizer = tf.keras.optimizers.Adam()
6.使用tf.GradientTape来训练模型
# Quantities tracked during training: the running mean of the batch loss and
# the running classification accuracy.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy',
)
# Functions decorated with tf.function are converted to a graph by AutoGraph,
# which speeds up execution.
@tf.function
def train_step(images, labels):
    """Run one optimisation step on a single batch.

    The forward pass and loss are recorded under a gradient tape, the
    gradients are applied to the model's trainable variables, and the
    running training metrics are updated.

    Args:
        images: Batch of input images passed to the model.
        labels: Target labels for the same batch.
    """
    with tf.GradientTape() as tape:
        batch_predictions = model(images)
        batch_loss = loss_object(labels, batch_predictions)

    # Differentiate outside the tape context so the gradient computation
    # itself is not recorded.
    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    train_loss.update_state(batch_loss)
    train_accuracy.update_state(labels, batch_predictions)
# Number of full passes over the training dataset.
EPOCHS = 5

for epoch in range(EPOCHS):
    # Clear the running metrics so each epoch reports only its own numbers.
    # Both metrics use reset_state(); the older reset_states() spelling is a
    # deprecated alias and was inconsistent with the call on the next line.
    train_loss.reset_state()
    train_accuracy.reset_state()

    for images, labels in train_ds:
        train_step(images, labels)

    print(f'Epoch {epoch+1},Loss:{train_loss.result()},Accuracy:{train_accuracy.result()*100}')
7.测试评估
# Running mean of the batch loss and running accuracy over the test set.
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

for test_images, test_labels in test_ds:
    # Call the model directly in inference mode. model.predict() is meant for
    # whole-dataset prediction; invoking it once per batch inside a loop adds
    # per-call overhead and prints a progress bar for every batch.
    predictions = model(test_images, training=False)
    t_loss = loss_object(test_labels, predictions)
    test_loss.update_state(t_loss)
    test_accuracy.update_state(test_labels, predictions)

# Report the aggregated metrics once, after every test batch has been seen.
print(f'Loss:{test_loss.result()},Accuracy:{test_accuracy.result()*100}')