23、基于Linux的TensorFlow模型部署与多平台访问

基于Linux的TensorFlow模型部署与多平台访问

1. 在Linux上直接安装TensorFlow模型服务器

在Linux系统上安装TensorFlow模型服务器时,无论使用 tensorflow-model-server 还是 tensorflow-model-server-universal ,包名都是相同的。为确保安装正确的版本,建议在开始安装前先移除旧版本的 tensorflow-model-server 。可以使用以下命令进行移除:

apt-get remove tensorflow-model-server

接着,需要将TensorFlow包源添加到系统中。可以使用以下命令:

echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -

如果在本地系统中需要使用 sudo 权限,可以这样操作:

sudo echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -

之后,更新 apt-get

apt-get update

完成上述步骤后,就可以使用 apt 安装模型服务器了:

apt-get install tensorflow-model-server

为确保使用的是最新版本,可以使用以下命令进行升级:

apt-get upgrade tensorflow-model-server

安装步骤总结

步骤 操作 命令
1 移除旧版本 apt-get remove tensorflow-model-server
2 添加包源 echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list && \ curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -
3 更新 apt-get apt-get update
4 安装模型服务器 apt-get install tensorflow-model-server
5 升级模型服务器 apt-get upgrade tensorflow-model-server

2. 构建和部署模型

2.1 创建并训练模型

使用简单的“Hello World”模型进行示例。以下是创建和训练模型的代码:

import numpy as np
import tensorflow as tf

xs = np.array([-1.0,  0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)

model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
history = model.fit(xs, ys, epochs=500, verbose=0)

print("Finished training the model")
print(model.predict([10.0]))

该模型训练速度很快,当输入 x 为10.0时,预测结果约为18.98。

2.2 保存模型

训练完成后,需要将模型保存到临时文件夹中:

export_path = "/tmp/serving_model/1/"
model.save(export_path, save_format="tf")
print('\nexport_path = {}'.format(export_path))

需要注意的是,TensorFlow Serving会根据数字查找模型版本,默认查找版本1。因此,虽然保存模型时路径为 /tmp/serving_model/1/ ,但在部署时使用 /tmp/serving_model/ 。如果保存模型的目录中有其他内容,建议在保存前删除。

2.3 检查模型

可以使用 (saved_model_cli) 工具检查模型的元数据:

saved_model_cli show --dir {export_path} --all

该命令的输出会很长,包含以下关键信息:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['dense_input'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: serving_default_dense_input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['dense'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall:0

需要注意 signature_def 的内容,这里是 serving_default ,后续会用到。同时,输入和输出都有定义的形状和类型,这里都是浮点型,形状为 (-1, 1) ,可以忽略 -1 ,记住输入和输出都是浮点型即可。

2.4 运行TensorFlow模型服务器

使用以下命令在命令行中运行TensorFlow模型服务器:

tensorflow_model_server --rest_api_port=8501 --model_name="helloworld" --model_base_path="/tmp/serving_model/" > server.log 2>&1

命令解释:
- rest_api_port :指定服务器运行的端口号,这里设置为8501。
- model_name :为模型指定名称,这里为 helloworld
- model_base_path :指定模型保存的路径。

打开 server.log 文件,应该能看到服务器成功启动的信息,显示正在 localhost:8501 导出 HTTP/REST API 。如果启动失败,可能需要重启系统。

2.5 测试服务器

可以使用Python代码测试服务器:

import json
import requests

xs = np.array([[9.0], [10.0]])
data = json.dumps({"signature_name": "serving_default",
                   "instances": xs.tolist()})
print(data)

headers = {"content-type": "application/json"}
json_response = requests.post(
    'http://localhost:8501/v1/models/helloworld:predict',
    data=data, headers=headers)
print(json_response.text)

发送数据时,需要将数据转换为JSON格式。输入数据应该是列表的列表,即使只有一个值也需要这样处理。响应将是一个包含预测结果的JSON负载:

{
    "predictions": [[16.9834747], [18.9806728]]
}

requests 库还提供了 json 属性,可以将响应自动解码为JSON字典。

构建和部署模型流程

graph TD;
    A[创建并训练模型] --> B[保存模型];
    B --> C[检查模型];
    C --> D[运行模型服务器];
    D --> E[测试服务器];

3. 从Android访问服务器模型

3.1 创建Android应用界面

创建一个简单的Android应用,包含一个 EditText 用于输入数字,一个 TextView 用于显示结果,一个 Button 用于触发推理:

<ScrollView
    android:id="@+id/scroll_view"
    android:layout_width="match_parent"
    android:layout_height="0dp"
    app:layout_constraintTop_toTopOf="parent"
    app:layout_constraintBottom_toTopOf="@+id/input_text">
    <TextView
        android:id="@+id/result_text_view"
        android:layout_width="match_parent"
        android:layout_height="wrap_content" />
</ScrollView>
<EditText
    android:id="@+id/input_text"
    android:layout_width="0dp"
    android:layout_height="wrap_content"
    android:hint="Enter Text Here"
    android:inputType="number"
    app:layout_constraintBaseline_toBaselineOf="@+id/ok_button"
    app:layout_constraintEnd_toStartOf="@+id/ok_button"
    app:layout_constraintStart_toStartOf="parent"
    app:layout_constraintBottom_toBottomOf="parent" />
<Button
    android:id="@+id/ok_button"
    android:layout_width="wrap_content"
    android:layout_height="wrap_content"
    android:text="OK"
    app:layout_constraintBottom_toBottomOf="parent"
    app:layout_constraintEnd_toEndOf="parent"
    app:layout_constraintStart_toEndOf="@+id/input_text"
    />

3.2 添加依赖

app build.gradle 文件中添加 Volley 库的依赖:

implementation 'com.android.volley:volley:1.1.1'

3.3 编写Activity代码

class MainActivity : AppCompatActivity() {
    lateinit var outputText: TextView
    lateinit var inputText: EditText
    lateinit var btnOK: Button

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        outputText = findViewById(R.id.result_text_view)
        inputText = findViewById(R.id.input_text)
        btnOK = findViewById(R.id.ok_button)
        btnOK.setOnClickListener {
            val inputValue: String = inputText.text.toString()
            val nInput = inputValue.toInt()
            doPost(nInput)
        }
    }

    private fun doPost(inputValue: Int) {
        val requestQueue: RequestQueue = Volley.newRequestQueue(this)
        val URL = "http://10.0.2.2:8501/v1/models/helloworld:predict"

        val jsonBody = JSONObject()
        jsonBody.put("signature_name", "serving_default")
        val innerarray = JSONArray()
        val outerarray = JSONArray()
        innerarray.put(inputValue)
        outerarray.put(innerarray)
        jsonBody.put("instances", outerarray)
        val requestBody = jsonBody.toString()

        val stringRequest: StringRequest =
            object : StringRequest(Method.POST, URL,
                Response.Listener { response ->
                    val str = response.toString()
                    val predictions = JSONObject(str).getJSONArray("predictions")
                        .getJSONArray(0)
                    val prediction = predictions.getDouble(0)
                    outputText.text = prediction.toString()
                },
                Response.ErrorListener { error ->
                    Log.d("API", "error => $error")
                }) {
                override fun getBody(): ByteArray {
                    return requestBody.toByteArray(Charset.defaultCharset())
                }
            }
        requestQueue.add(stringRequest)
    }
}

Android应用访问服务器模型步骤

步骤 操作 代码
1 创建界面 XML布局文件
2 添加依赖 implementation 'com.android.volley:volley:1.1.1'
3 编写Activity代码 Kotlin代码
4 处理输入和请求 doPost 函数
5 处理响应 Response.Listener

4. 从iOS访问服务器模型

4.1 定义数据结构

在Swift中,为了方便解码JSON值,需要创建与JSON结构对应的结构体。对于预测结果,可以创建如下结构体:

struct Results: Decodable {
    let predictions: [[Double]]
}

4.2 创建JSON负载

如果有一个 Double 类型的值 value ,可以按照以下方式创建上传到服务器的JSON负载:

let json: [String: Any] =
    ["signature_name" : "serving_default", "instances" : [[value]]]
let jsonData = try? JSONSerialization.data(withJSONObject: json)

4.3 发送POST请求

创建一个POST请求并添加JSON负载到请求体中:

let url = URL(string: "http://192.168.86.26:8501/v1/models/helloworld:predict")!
var request = URLRequest(url: url)
request.httpMethod = "POST"
request.httpBody = jsonData

4.4 异步处理请求和响应

使用 URLSession dataTask 方法异步处理请求和响应:

let task = URLSession.shared.dataTask(with: request) { data, response, error in
    if let data = data {
        do {
            let results: Results = try JSONDecoder().decode(Results.self, from: data)
            DispatchQueue.main.async {
                // 更新UI
                self.txtOutput.text = String(results.predictions[0][0])
            }
        } catch {
            print("Error decoding JSON: \(error)")
        }
    }
}
task.resume()

iOS应用访问服务器模型步骤

步骤 操作 代码
1 定义数据结构 struct Results: Decodable
2 创建JSON负载 let json: [String: Any]
3 发送POST请求 URLRequest httpMethod = "POST"
4 异步处理请求和响应 URLSession.shared.dataTask
5 解码响应 JSONDecoder().decode

iOS访问服务器模型流程

graph TD;
    A[定义数据结构] --> B[创建JSON负载];
    B --> C[发送POST请求];
    C --> D[异步处理请求和响应];
    D --> E[解码响应];

5. 注意事项

5.1 JSON负载格式

在向服务器传递数据时,输入数据应始终以列表的列表形式呈现,即使只有一个值。例如,要对值9.0进行推理,应使用 [[9.0]] 而不是 [9.0] 。同样,对于多个值,如9.0和10.0,应使用 [[9.0], [10.0]]

5.2 服务器地址

在不同的环境中,需要注意服务器地址的设置。在Android模拟器中运行代码时,可以使用 10.0.2.2 代替 localhost 。在iOS示例中,使用了具体的IP地址 192.168.86.26 ,实际使用时需要根据服务器的实际地址进行修改。

5.3 错误处理

在处理请求和响应时,要考虑到可能出现的错误情况。在Android代码中,使用 Response.ErrorListener 捕获错误信息;在iOS代码中,使用 catch 块处理JSON解码错误。

注意事项总结

注意事项 说明
JSON负载格式 输入数据必须是列表的列表
服务器地址 根据不同环境设置正确的地址
错误处理 捕获并处理请求和响应中的错误

通过以上步骤,我们可以在Linux上安装TensorFlow模型服务器,构建和部署模型,并从Android和iOS应用中访问服务器模型进行推理。整个过程涉及到模型的创建、保存、部署以及多平台的访问,需要注意JSON负载格式、服务器地址和错误处理等关键问题。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值