python 部署模型_三分钟快速上手TensorFlow 2.0 (下)——模型的部署 、大规模训练、加速...

本文详细介绍了如何使用TensorFlow进行模型部署,包括使用SavedModel导出模型,Keras Sequential模型的保存与加载,以及在服务器端使用TensorFlow Serving部署模型。此外,还涉及了TensorFlow Lite在移动端的部署和性能比较,以及分布式训练的策略,如MirroredStrategy和MultiWorkerMirroredStrategy。
摘要由CSDN通过智能技术生成

TensorFlow 模型导出

使用 SavedModel 完整导出模型

不仅包含参数的权值,还包含计算的流程(即计算图)

tf.saved_model.save(model, "保存的目标文件夹名称")

将模型导出为 SavedModel

model = tf.saved_model.load("保存的目标文件夹名称")

载入 SavedModel 文件

因为 SavedModel 基于计算图,所以对于使用继承 tf.keras.Model 类建立的 Keras 模型,其需要导出到 SavedModel 格式的方法(比如 call )都需要使用 @tf.function 修饰

使用继承 tf.keras.Model 类建立的 Keras 模型 model ,使用 SavedModel 载入后将无法使用 model() 直接进行推断,而需要使用 model.call()

importtensorflow as tffrom zh.model.utils importMNISTLoader

num_epochs= 1batch_size= 50learning_rate= 0.001model=tf.keras.models.Sequential([

tf.keras.layers.Flatten(),

tf.keras.layers.Dense(100, activation=tf.nn.relu),

tf.keras.layers.Dense(10),

tf.keras.layers.Softmax()

])

data_loader=MNISTLoader()

model.compile(

optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),

loss=tf.keras.losses.sparse_categorical_crossentropy,

metrics=[tf.keras.metrics.sparse_categorical_accuracy]

)

model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)

tf.saved_model.save(model,"saved/1")

MNIST 手写体识别的模型 进行导出

importtensorflow as tffrom zh.model.utils importMNISTLoader

batch_size= 50model= tf.saved_model.load("saved/1")

data_loader=MNISTLoader()

sparse_categorical_accuracy=tf.keras.metrics.SparseCategoricalAccuracy()

num_batches= int(data_loader.num_test_data //batch_size)for batch_index inrange(num_batches):

start_index, end_index= batch_index * batch_size, (batch_index + 1) *batch_size

y_pred=model(data_loader.test_data[start_index: end_index])

sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)print("test accuracy: %f" % sparse_categorical_accuracy.result())

MNIST 手写体识别的模型 进行导入并测试

classMLP(tf.keras.Model):def __init__(self):

super().__init__()

self.flatten=tf.keras.layers.Flatten()

self.dense1= tf.keras.layers.Dense(units=100, activation=tf.nn.relu)

self.dense2= tf.keras.layers.Dense(units=10)

@tf.functiondef call(self, inputs): #[batch_size, 28, 28, 1]

x = self.flatten(inputs) #[batch_size, 784]

x = self.dense1(x) #[batch_size, 100]

x = self.dense2(x) #[batch_size, 10]

output =tf.nn.softmax(x)returnoutput

model= MLP()

使用继承 tf.keras.Model 类建立的 Keras 模型同样可以以相同方法导出,唯须注意 call 方法需要以 @tf.function 修饰,以转化为 SavedModel 支持的计算图

y_pred = model.call(data_loader.test_data[start_index: end_index])

模型导入并测试性能的过程也相同,唯须注意模型推断时需要显式调用 call 方法

Keras Sequential save 方法

是基于 keras 的 Sequential 构建了多层的卷积神经网络,并进行训练

curl -LO https://raw.githubusercontent.com/keras-team/keras/master/examples/mnist_cnn.py

使用如下命令拷贝到本地:

model.save('mnist_cnn.h5')

对 keras 训练完毕的模型进行保存

python mnist_cnn.py

在终端中执行 mnist_cnn.py 文件

执行过程会比较久,执行结束后,会在当前目录产生 mnist_cnn.h5 文件(HDF5 格式),就是 keras 训练后的模型,其中已经包含了训练后的模型结构和权重等信息。

在服务器端,可以直接通过 keras.models.load_model("mnist_cnn.h5") 加载,然后进行推理;在移动设备需要将 HDF5 模型文件转换为 TensorFlow Lite 的格式,然后通过相应平台的 Interpreter 加载,然后进行推理。

TensorFlow Serving(服务器端部署模型)

安装

#添加Google的TensorFlow Serving源

echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list#添加gpg key

curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -

设置安装源

sudo apt-get update

sudo apt-get install tensorflow-model-server

使用 apt-get 安装 TensorFlow Serving

curl 设置代理的方式为 -x 选项或设置 http_proxy 环境变量,即

export http_proxy=http://代理服务器IP:端口

curl-x http://代理服务器IP:端口 URL

apt-get 设置代理的方式为 -o 选项,即

sudo apt-get -o Acquire::http::proxy="http://代理服务器IP:端口" ...

可能需要设置代理

模型部署

tensorflow_model_server \--rest_api_port=端口号(如8501) \--model_name=模型名 \--model_base_path="SavedModel格式模型的文件夹绝对地址(不含版本号)"

直接读取 SavedModel 格式的模型进行部署

支持热更新模型,其典型的模型文件夹结构如下:

/saved_model_files

/1 # 版本号为1的模型文件

/assets

/variables

saved_model.pb

...

/N # 版本号为N的模型文件

/assets

/variables

saved_model.pb

上面 1~N 的子文件夹代表不同版本号的模型。当指定 --model_base_path 时,只需要指定根目录的 绝对地址 (不是相对地址)即可。例如,如果上述文件夹结构存放在 home/snowkylin 文件夹内,则 --model_base_path 应当设置为 home/snowkylin/saved_model_files (不附带模型版本号)。TensorFlow Serving 会自动选择版本号最大的模型进行载入。

tensorflow_model_server \--rest_api_port=8501\--model_name=MLP \--model_base_path="/home/.../.../saved" #文件夹绝对地址根据自身情况填写,无需加入版本号

Keras Sequential 模式模型的部署

Sequential 模式的输入和输出都很固定,因此这种类型的模型很容易部署,无需其他额外操作。例如,要将 前文使用 SavedModel 导出的 MNIST 手写体识别模型 (使用 Keras Sequential 模式建立)以 MLP 的模型名在 8501 端口进行部署,可以直接使用以上命令

classMLP(tf.keras.Model):

...

@tf.function(input_signature=[tf.TensorSpec([None, 28, 28, 1], tf.float32)])defcall(self, inputs):

...

自定义 Keras 模型的部署-导出到 SavedModel 格式

不仅需要使用 @tf.function 修饰,还要在修饰时指定 input_signature 参数,以显式说明输入的形状。该参数传入一个由 tf.TensorSpec 组成的列表,指定每个输入张量的形状和类型

例如,对于 MNIST 手写体数字识别,我们的输入是一个 [None, 28, 28, 1] 的四维张量( None表示第一维即 Batch Size 的大小不固定),此时我们可以将模型的 call 方法做出上面的修饰

model =MLP()

...

tf.saved_model.save(model,"saved_with_signature/1", signatures={"call": model.call})

自定义 Keras 模型的部署-使用 tf.saved_model.save 导出

将模型使用 tf.saved_model.save 导出时,需要通过 signature 参数提供待导出的函数的签名(Signature)

需要告诉 TensorFlow Serving 每个方法在被客户端调用时分别叫做什么名字。例如,如果我们希望客户端在调用模型时使用 call 这一签名来调用 model.call方法时,我们可以在导出时传入 signature 参数,以 dict 的键值对形式告知导出的方法对应的签名

tensorflow_model_server \--rest_api_port=8501\--model_name=MLP \--model_base_path="/home/.../.../saved_with_signature" #修改为自己模型的绝对地址

两步均完成后,即可使用以下命令部署

在客户端调用以 TensorFlow Serving 部署的模型

支持以 gRPC 和 RESTful API 调用以 TensorFlow Serving 部署的模型。这里主要介绍较为通用的 RESTful API 方法。

RESTful API 以标准的 HTTP POST 方法进行交互,请求和回复均为 JSON 对象。为了调用服务器端的模型,我们在客户端向服务器发送以下格式的请求:

服务器 URI: http://服务器地址:端口号/v1/models/模型名:predict

请求内容:

{

"signature_name": "需要调用的函数签名(Sequential模式不需要)",

"instances": 输入数据

}

回复为:

{

"predictions": 返回值

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值