基于Linux的TensorFlow模型部署与多平台访问
1. 在Linux上直接安装TensorFlow模型服务器
在Linux系统上安装TensorFlow模型服务器时,无论使用
tensorflow-model-server
还是
tensorflow-model-server-universal
,包名都是相同的。为确保安装正确的版本,建议在开始安装前先移除旧版本的
tensorflow-model-server
。可以使用以下命令进行移除:
apt-get remove tensorflow-model-server
接着,需要将TensorFlow包源添加到系统中。可以使用以下命令:
echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -
如果在本地系统中需要使用
sudo
权限,可以这样操作:
sudo echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
之后,更新
apt-get
:
apt-get update
完成上述步骤后,就可以使用
apt
安装模型服务器了:
apt-get install tensorflow-model-server
为确保使用的是最新版本,可以使用以下命令进行升级:
apt-get upgrade tensorflow-model-server
安装步骤总结
步骤 | 操作 | 命令 |
---|---|---|
1 | 移除旧版本 |
apt-get remove tensorflow-model-server
|
2 | 添加包源 |
echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list && \ curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -
|
3 |
更新
apt-get
|
apt-get update
|
4 | 安装模型服务器 |
apt-get install tensorflow-model-server
|
5 | 升级模型服务器 |
apt-get upgrade tensorflow-model-server
|
2. 构建和部署模型
2.1 创建并训练模型
使用简单的“Hello World”模型进行示例。以下是创建和训练模型的代码:
import numpy as np
import tensorflow as tf
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)
model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
history = model.fit(xs, ys, epochs=500, verbose=0)
print("Finished training the model")
print(model.predict([10.0]))
该模型训练速度很快,当输入
x
为10.0时,预测结果约为18.98。
2.2 保存模型
训练完成后,需要将模型保存到临时文件夹中:
export_path = "/tmp/serving_model/1/"
model.save(export_path, save_format="tf")
print('\nexport_path = {}'.format(export_path))
需要注意的是,TensorFlow Serving会根据数字查找模型版本,默认查找版本1。因此,虽然保存模型时路径为
/tmp/serving_model/1/
,但在部署时使用
/tmp/serving_model/
。如果保存模型的目录中有其他内容,建议在保存前删除。
2.3 检查模型
可以使用
(saved_model_cli)
工具检查模型的元数据:
saved_model_cli show --dir {export_path} --all
该命令的输出会很长,包含以下关键信息:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['dense_input'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: serving_default_dense_input:0
The given SavedModel SignatureDef contains the following output(s):
outputs['dense'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: StatefulPartitionedCall:0
需要注意
signature_def
的内容,这里是
serving_default
,后续会用到。同时,输入和输出都有定义的形状和类型,这里都是浮点型,形状为
(-1, 1)
,可以忽略
-1
,记住输入和输出都是浮点型即可。
2.4 运行TensorFlow模型服务器
使用以下命令在命令行中运行TensorFlow模型服务器:
tensorflow_model_server --rest_api_port=8501 --model_name="helloworld" --model_base_path="/tmp/serving_model/" > server.log 2>&1
命令解释:
-
rest_api_port
:指定服务器运行的端口号,这里设置为8501。
-
model_name
:为模型指定名称,这里为
helloworld
。
-
model_base_path
:指定模型保存的路径。
打开
server.log
文件,应该能看到服务器成功启动的信息,显示正在
localhost:8501
导出
HTTP/REST API
。如果启动失败,可能需要重启系统。
2.5 测试服务器
可以使用Python代码测试服务器:
import json
import requests
xs = np.array([[9.0], [10.0]])
data = json.dumps({"signature_name": "serving_default",
"instances": xs.tolist()})
print(data)
headers = {"content-type": "application/json"}
json_response = requests.post(
'http://localhost:8501/v1/models/helloworld:predict',
data=data, headers=headers)
print(json_response.text)
发送数据时,需要将数据转换为JSON格式。输入数据应该是列表的列表,即使只有一个值也需要这样处理。响应将是一个包含预测结果的JSON负载:
{
"predictions": [[16.9834747], [18.9806728]]
}
requests
库还提供了
json
属性,可以将响应自动解码为JSON字典。
构建和部署模型流程
graph TD;
A[创建并训练模型] --> B[保存模型];
B --> C[检查模型];
C --> D[运行模型服务器];
D --> E[测试服务器];
3. 从Android访问服务器模型
3.1 创建Android应用界面
创建一个简单的Android应用,包含一个
EditText
用于输入数字,一个
TextView
用于显示结果,一个
Button
用于触发推理:
<ScrollView
android:id="@+id/scroll_view"
android:layout_width="match_parent"
android:layout_height="0dp"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintBottom_toTopOf="@+id/input_text">
<TextView
android:id="@+id/result_text_view"
android:layout_width="match_parent"
android:layout_height="wrap_content" />
</ScrollView>
<EditText
android:id="@+id/input_text"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:hint="Enter Text Here"
android:inputType="number"
app:layout_constraintBaseline_toBaselineOf="@+id/ok_button"
app:layout_constraintEnd_toStartOf="@+id/ok_button"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintBottom_toBottomOf="parent" />
<Button
android:id="@+id/ok_button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="OK"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toEndOf="@+id/input_text"
/>
3.2 添加依赖
在
app
的
build.gradle
文件中添加
Volley
库的依赖:
implementation 'com.android.volley:volley:1.1.1'
3.3 编写Activity代码
class MainActivity : AppCompatActivity() {
lateinit var outputText: TextView
lateinit var inputText: EditText
lateinit var btnOK: Button
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
outputText = findViewById(R.id.result_text_view)
inputText = findViewById(R.id.input_text)
btnOK = findViewById(R.id.ok_button)
btnOK.setOnClickListener {
val inputValue: String = inputText.text.toString()
val nInput = inputValue.toInt()
doPost(nInput)
}
}
private fun doPost(inputValue: Int) {
val requestQueue: RequestQueue = Volley.newRequestQueue(this)
val URL = "http://10.0.2.2:8501/v1/models/helloworld:predict"
val jsonBody = JSONObject()
jsonBody.put("signature_name", "serving_default")
val innerarray = JSONArray()
val outerarray = JSONArray()
innerarray.put(inputValue)
outerarray.put(innerarray)
jsonBody.put("instances", outerarray)
val requestBody = jsonBody.toString()
val stringRequest: StringRequest =
object : StringRequest(Method.POST, URL,
Response.Listener { response ->
val str = response.toString()
val predictions = JSONObject(str).getJSONArray("predictions")
.getJSONArray(0)
val prediction = predictions.getDouble(0)
outputText.text = prediction.toString()
},
Response.ErrorListener { error ->
Log.d("API", "error => $error")
}) {
override fun getBody(): ByteArray {
return requestBody.toByteArray(Charset.defaultCharset())
}
}
requestQueue.add(stringRequest)
}
}
Android应用访问服务器模型步骤
步骤 | 操作 | 代码 |
---|---|---|
1 | 创建界面 | XML布局文件 |
2 | 添加依赖 |
implementation 'com.android.volley:volley:1.1.1'
|
3 | 编写Activity代码 | Kotlin代码 |
4 | 处理输入和请求 |
doPost
函数
|
5 | 处理响应 |
Response.Listener
|
4. 从iOS访问服务器模型
4.1 定义数据结构
在Swift中,为了方便解码JSON值,需要创建与JSON结构对应的结构体。对于预测结果,可以创建如下结构体:
struct Results: Decodable {
let predictions: [[Double]]
}
4.2 创建JSON负载
如果有一个
Double
类型的值
value
,可以按照以下方式创建上传到服务器的JSON负载:
let json: [String: Any] =
["signature_name" : "serving_default", "instances" : [[value]]]
let jsonData = try? JSONSerialization.data(withJSONObject: json)
4.3 发送POST请求
创建一个POST请求并添加JSON负载到请求体中:
let url = URL(string: "http://192.168.86.26:8501/v1/models/helloworld:predict")!
var request = URLRequest(url: url)
request.httpMethod = "POST"
request.httpBody = jsonData
4.4 异步处理请求和响应
使用
URLSession
的
dataTask
方法异步处理请求和响应:
let task = URLSession.shared.dataTask(with: request) { data, response, error in
if let data = data {
do {
let results: Results = try JSONDecoder().decode(Results.self, from: data)
DispatchQueue.main.async {
// 更新UI
self.txtOutput.text = String(results.predictions[0][0])
}
} catch {
print("Error decoding JSON: \(error)")
}
}
}
task.resume()
iOS应用访问服务器模型步骤
步骤 | 操作 | 代码 |
---|---|---|
1 | 定义数据结构 |
struct Results: Decodable
|
2 | 创建JSON负载 |
let json: [String: Any]
|
3 | 发送POST请求 |
URLRequest
和
httpMethod = "POST"
|
4 | 异步处理请求和响应 |
URLSession.shared.dataTask
|
5 | 解码响应 |
JSONDecoder().decode
|
iOS访问服务器模型流程
graph TD;
A[定义数据结构] --> B[创建JSON负载];
B --> C[发送POST请求];
C --> D[异步处理请求和响应];
D --> E[解码响应];
5. 注意事项
5.1 JSON负载格式
在向服务器传递数据时,输入数据应始终以列表的列表形式呈现,即使只有一个值。例如,要对值9.0进行推理,应使用
[[9.0]]
而不是
[9.0]
。同样,对于多个值,如9.0和10.0,应使用
[[9.0], [10.0]]
。
5.2 服务器地址
在不同的环境中,需要注意服务器地址的设置。在Android模拟器中运行代码时,可以使用
10.0.2.2
代替
localhost
。在iOS示例中,使用了具体的IP地址
192.168.86.26
,实际使用时需要根据服务器的实际地址进行修改。
5.3 错误处理
在处理请求和响应时,要考虑到可能出现的错误情况。在Android代码中,使用
Response.ErrorListener
捕获错误信息;在iOS代码中,使用
catch
块处理JSON解码错误。
注意事项总结
注意事项 | 说明 |
---|---|
JSON负载格式 | 输入数据必须是列表的列表 |
服务器地址 | 根据不同环境设置正确的地址 |
错误处理 | 捕获并处理请求和响应中的错误 |
通过以上步骤,我们可以在Linux上安装TensorFlow模型服务器,构建和部署模型,并从Android和iOS应用中访问服务器模型进行推理。整个过程涉及到模型的创建、保存、部署以及多平台的访问,需要注意JSON负载格式、服务器地址和错误处理等关键问题。