全文参考《简单粗暴 TensorFlow 2.0》
import tensorflow as tf
import numpy as np
使用 SavedModel 完整导出模型
在部署模型时,我们的第一步往往是将训练好的整个模型完整导出为一系列标准格式的文件,然后即可在不同的平台上部署模型文件。这时,TensorFlow 为我们提供了 SavedModel 这一格式。与前面介绍的 Checkpoint 不同,SavedModel 包含了一个 TensorFlow 程序的完整信息: 不仅包含参数的权值,还包含计算的流程(即计算图) 。
当模型导出为 SavedModel 文件时,无需建立模型的源代码即可再次运行模型,这使得 SavedModel 尤其适用于模型的分享和部署。TensorFlow Serving(服务器端部署模型)、TensorFlow Lite(移动端部署模型)以及 TensorFlow.js 都会用到这一格式。
Keras 模型均可方便地导出为 SavedModel 格式。不过需要注意的是,因为 SavedModel 基于计算图,所以对于使用继承 tf.keras.Model 类建立的 Keras 模型,其需要导出到 SavedModel 格式的方法(比如 call )都需要使用 @tf.function 修饰。然后,假设我们有一个名为 model 的 Keras 模型,使用下面的代码即可将模型导出为 SavedModel:
tf.saved_model.save(model, "保存的目标文件夹名称")
在需要载入 SavedModel 文件时,使用
model = tf.saved_model.load("保存的目标文件夹名称")
任务:对 MNIST 手写体识别模型进行导出和导入
# Loader for the MNIST handwritten-digit dataset.
class MNISTLoader():
    def __init__(self):
        source = tf.keras.datasets.mnist
        (self.train_data, self.train_label), (self.test_data, self.test_label) = source.load_data()
        # Raw images are uint8 in [0, 255]; rescale to float32 in [0, 1] and
        # append a trailing colour-channel axis.
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)  # [60000, 28, 28, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)    # [10000, 28, 28, 1]
        self.train_label = self.train_label.astype(np.int32)  # [60000]
        self.test_label = self.test_label.astype(np.int32)    # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

    def get_batch(self, batch_size):
        # Draw batch_size training examples uniformly at random (with replacement).
        idx = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
        return self.train_data[idx, :], self.train_label[idx]
1.方式一:使用tf.keras.Model自定义模型
class MLP(tf.keras.Model):
    """Minimal multilayer perceptron: flatten -> Dense(100, ReLU) -> Dense(10) -> softmax."""

    def __init__(self):
        super().__init__()
        # Flatten collapses every axis except the leading batch dimension.
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)

    # @tf.function makes `call` graph-traceable so SavedModel export keeps it.
    @tf.function
    def call(self, inputs):
        # inputs: [batch_size, 28, 28, 1]
        flat = self.flatten(inputs)   # [batch_size, 784]
        hidden = self.dense1(flat)    # [batch_size, 100]
        logits = self.dense2(hidden)  # [batch_size, 10]
        return tf.nn.softmax(logits)  # per-class probabilities
@tf.function
def train_one_step(x, y, model, optimizer):
    """Run one optimisation step of `model` on the batch (x, y).

    x: batch of input images; y: integer class labels.
    Prints the mean batch loss as a side effect.
    """
    with tf.GradientTape() as tape:
        # Invoke the model via __call__ (not model.call) so Keras runs its
        # standard build/cast hooks before the forward pass is traced.
        y_pred = model(x)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(y, y_pred))
    tf.print("loss:", loss)
    # Only trainable variables should receive gradient updates.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))
# Hyperparameters
num_epochs = 5
batch_size = 50

dataset = MNISTLoader()                       # data
model = MLP()                                 # model
optimizer = tf.keras.optimizers.Adam(0.001)   # optimizer

# Total number of mini-batch steps across all epochs.
num_batch = int(dataset.num_train_data // batch_size * num_epochs)
for step in range(num_batch):
    batch_x, batch_y = dataset.get_batch(batch_size)
    train_one_step(batch_x, batch_y, model, optimizer)

# Export the trained model (weights + traced computation graph) as a SavedModel.
tf.saved_model.save(model, "./save/")
loss: 2.34741402
loss: 2.22175741
loss: 2.19209266
loss: 2.16491508
loss: 2.04420757
loss: 1.97402906
loss: 1.97099698
loss: 1.87866795
loss: 1.84972
loss: 1.78676546
loss: 1.76340961
loss: 1.62747312
loss: 1.61507249
loss: 1.58165455
loss: 1.39345312
loss: 1.46365499
loss: 1.48324382
loss: 1.2654649
loss: 1.42977905
loss: 1.34926128
loss: 1.30133331
loss: 1.21467018
loss: 1.19885385
...
loss: 0.0588415
loss: 0.0428937487
loss: 0.0467342623
loss: 0.0205140114
loss: 0.0116279796
loss: 0.101192042
loss: 0.0416857041
loss: 0.0855338871
loss: 0.0464787595
loss: 0.0230195746
loss: 0.0637835711
loss: 0.0763473585
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: ./save/assets
加载模型
# Restore the SavedModel and measure test-set accuracy.
model = tf.saved_model.load(export_dir='./save/')
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
dataset = MNISTLoader()
num_batch = int(dataset.num_test_data // batch_size)
for index in range(num_batch):
    start_index, end_index = index * batch_size, (index + 1) * batch_size
    # The exported @tf.function `call` is exposed on the loaded object.
    y_pred = model.call(dataset.test_data[start_index:end_index])
    accuracy.update_state(dataset.test_label[start_index:end_index], y_pred)
print("accuracy is %f" % accuracy.result())
accurancy is 0.972800:
2.方式二:使用tf.keras高级API定义模型
# The same MLP built with the Keras Sequential API; here Softmax is its own
# layer, so the model outputs probabilities directly.
sequential_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=100, activation=tf.nn.relu),
    tf.keras.layers.Dense(units=10),
    tf.keras.layers.Softmax(),
])

dataset = MNISTLoader()
sequential_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=[tf.keras.metrics.sparse_categorical_accuracy],
)
sequential_model.fit(dataset.train_data, dataset.train_label, epochs=5)
Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 7s 117us/sample - loss: 0.2704 - sparse_categorical_accuracy: 0.9226
Epoch 2/5
60000/60000 [==============================] - 6s 93us/sample - loss: 0.1225 - sparse_categorical_accuracy: 0.9640
Epoch 3/5
60000/60000 [==============================] - 5s 91us/sample - loss: 0.0847 - sparse_categorical_accuracy: 0.9748
Epoch 4/5
60000/60000 [==============================] - 7s 117us/sample - loss: 0.0636 - sparse_categorical_accuracy: 0.9803
Epoch 5/5
60000/60000 [==============================] - 7s 121us/sample - loss: 0.0487 - sparse_categorical_accuracy: 0.9851
<tensorflow.python.keras.callbacks.History at 0x1d5013ec308>
# Save the trained Sequential model (weights + graph) in SavedModel format.
tf.saved_model.save(sequential_model,export_dir='./save/')
WARNING:tensorflow:From D:\Anaconda3\envs\tf2\lib\site-packages\tensorflow_core\python\ops\resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: ./save/assets
# Reload the SavedModel and evaluate it on the MNIST test set.
model = tf.saved_model.load(export_dir='./save/')
batch_size = 50
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batch = int(dataset.num_test_data // batch_size)
for index in range(num_batch):
    start_index, end_index = index * batch_size, (index + 1) * batch_size
    # A Keras model restored via tf.saved_model.load is directly callable.
    y_pred = model(dataset.test_data[start_index:end_index])
    accuracy.update_state(dataset.test_label[start_index:end_index], y_pred)
print("accuracy is %f" % accuracy.result())
accurancy is 0.975800: