使用ONNX Runtime在Java Web应用中部署深度学习模型
背景和目标
当应用场景需要集成深度学习模型进行推理时,直接在Java Web应用中集成深度学习框架可能会面临性能、兼容性等问题。为了将深度学习模型无缝集成到Java Web应用中,模型服务化是一项广受认可且实用的方法。
本篇文档将介绍如何将模型转换为ONNX格式,并通过ONNX Runtime Server进行部署,并通过Java Web应用调用以进行回归或预测任务。本方案的主要目标是实现以下功能:
-
将PyTorch模型转换为ONNX格式。
-
部署ONNX Runtime Server,加载ONNX模型并提供推理服务。
-
在Java Web应用中调用ONNX Runtime Server进行模型推理。
此外,还演示了一种在Java应用内部直接加载ONNX模型并执行推理的方法(使用ONNX Runtime for Java库来实现)。
ONNX Runtime Server与Java客户端调用的优势
ONNX Runtime Server
ONNX(Open Neural Network Exchange)是一种开放格式,允许模型在不同的深度学习框架间自由转换。通过ONNX Runtime Server,您可以部署经ONNX转换的模型,该服务器支持多种硬件后端(CPU、GPU及特定AI芯片加速器),从而实现实时高效的模型推理。ONNX Runtime Server还拥有跨平台支持、多语言客户端、性能优化以及广泛模型兼容性的优点,使得Java Web应用仅需通过标准HTTP接口就能调用模型,而无需关注底层深度学习框架的具体实现。
Java客户端调用
Java客户端可以简单快捷地通过RESTful API或gRPC方式调用远程的ONNX Runtime Server服务。这种设计使Java应用开发者能够集中精力于业务逻辑,无需深入理解模型内部工作原理。客户端只需遵循规定的接口规范发送请求,接收并解析服务器返回的预测结果。
完整部署流程
1. 模型转换
首先,使用ONNX将训练好的PyTorch模型转换为ONNX格式,以实现跨框架互操作性。
2. 部署环境与工具
模型部署
-
本地部署:在本地服务器上安装ONNX Runtime Server,Java Web应用位于同一内网环境,通过内网地址调用模型服务。
本地部署意味着将ONNX Runtime Server部署在企业内部的物理服务器或虚拟机上,模型及其运行环境完全在企业控制之下。部署步骤可能包括:
- 在本地服务器上安装ONNX Runtime Server及相关软件。
- 将转换后的ONNX模型加载到本地部署的ONNX Runtime Server中。
- 配置防火墙规则和网络设置,确保Java Web应用可以通过局域网访问模型服务。
- Java Web应用直接通过内网IP和端口调用模型服务。
优点:
- 数据隐私和安全控制更强,因为数据无需离开企业内部网络。
- 网络延迟较低,尤其对于大规模数据流处理或实时性要求高的应用。
- 可根据实际需求定制硬件资源,如使用专用的GPU服务器进行高速推理。
缺点:
- 需要投入额外的硬件成本和维护精力,包括服务器购置、运维和升级。
- 扩展性和弹性较差,当模型服务需求增长时,增加资源可能较慢且成本较高。
-
云端部署:在云服务商(如AWS、Azure、阿里云等)提供的服务器上部署ONNX Runtime Server,公开HTTP API并实施安全策略。Java Web应用部署在任何可访问互联网的环境中,通过公网地址调用模型服务。
云端部署则指将ONNX Runtime Server部署在云服务提供商(如AWS、Azure或Google Cloud)的云端服务器上。部署步骤大致如下:
- 在云服务提供商平台上创建服务器实例,并安装ONNX Runtime Server。
- 将ONNX模型上传至云端服务器,并在服务器上加载模型到ONNX Runtime Server。
- 配置云服务器的安全组规则,对外开放模型服务的HTTP接口。
- Java Web应用通过公网URL访问云端部署的模型服务。
优点:
- 弹性伸缩性强,可以根据需求动态调整服务器资源,节省成本。
- 自动化运维程度高,云服务提供商通常会提供易于管理和监控的服务。
- 全球访问便捷,适合分布式或跨国应用。
缺点:
- 数据传输可能会增加网络延迟,尤其是对于带宽要求高的应用。
- 对第三方云服务的依赖性增强,涉及数据安全和合规性考量。
- 长期使用可能产生持续的云服务租赁费用,相比一次性硬件投资可能更高。
Java Web应用
Java Web应用需要部署在独立服务器或同一服务器的不同进程中,确保安装Java运行环境(JRE或JDK)、Web服务器(如Tomcat或Jetty)和相应的开发框架(如Spring Boot)。
工具与库
- 模型转换工具:使用Python的ONNX库将PyTorch模型转为ONNX格式。
- 模型部署工具:ONNX Runtime Server负责加载ONNX模型并提供REST API接口。
- Java Web应用工具:
- 客户端HTTP库:Apache HttpClient或OkHttp等,用于Java应用与ONNX Runtime Server之间的HTTP通信。
- JSON处理库:如Jackson或Gson,用于处理请求数据和预测结果的JSON序列化和反序列化。
性能优化与注意事项
- 根据应用需求选择合适的硬件资源(如CPU、GPU)部署ONNX Runtime Server。
- 确保本地部署的网络连接稳定可靠,云端部署时合理规划网络架构以降低延迟。
- 对于大流量或实时性要求高的场景,建议采用负载均衡、缓存策略和异步调用机制。
步骤详解
步骤一:将PyTorch模型转换为ONNX格式
-
安装必要的Python库:
pip install onnx onnxruntime torch torchvision
-
使用ONNX将PyTorch模型导出为ONNX格式:
import torch import onnx from your_model_module import YourModelClass # 加载PyTorch模型 model = YourModelClass() # 创建模型的一个实例,YourModelClass是你定义的PyTorch模型类。 model.eval() # 将模型设置为评估模式,这是在进行推理时常用的做法,因为它会关闭一些特定于训练阶段的行为,比如dropout model.load_state_dict(torch.load('path_to_your_trained_model.pth')) # 加载训练好的模型权重 input_shape = [1, ...] # 替换为模型实际输入维度 # 这里,你需要根据你的模型实际的输入维度来替换input_shape。 # 例如,如果你的模型接受一个形状为(1, 3, 224, 224)的张量作为输入(这是一个常见的输入形状,用于具有三个颜色通道和224x224像素的图像) # 那么你应该将input_shape设置为[1, 3, 224, 224]。torch.randn(input_shape)会生成一个具有随机数的张量,用于模拟一个输入样本。 x_example = torch.randn(input_shape) input_names = ["input"] # 根据模型实际输入名称替换 output_names = ["output"] # 根据模型实际输出名称替换 # 在ONNX模型中,每个输入和输出都有一个名称。这些名称在模型的推理过程中用于标识输入和输出张量。在这里,你需要根据你的模型的实际情况来替换"input"和"output"。如果你不确定你的模型的输入和输出名称,你可以暂时保留它们,然后在模型转换后检查ONNX模型的元数据以获取正确的名称。 torch.onnx.export(model, x_example, "your_model.onnx", export_params=True, opset_version=11, do_constant_folding=True, input_names=input_names, output_names=output_names)
步骤二:部署ONNX模型到ONNX Runtime Server
-
安装ONNX Runtime Server:
pip install onnxruntime-server
-
启动ONNX Runtime Server并注册模型:
onnxruntime-server --start --model-path your_model.onnx --service-address localhost:8001 --model-name your_model
步骤三:在Java Web应用中调用模型服务
-
添加ONNX Runtime for Java的Maven依赖:
<dependency> <groupId>com.microsoft.onnxruntime</groupId> <artifactId>onnxruntime</artifactId> <version>1.9.0</version> <!-- 根据最新版本替换 --> </dependency>
-
使用Java HTTP客户端调用ONNX Runtime Server:
// Java代码片段 import org.apache.http.HttpResponse; import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; import org.apache.http.util.EntityUtils; import org.apache.http.entity.ContentType; import org.apache.http.HttpStatus; public class ModelCaller { public static void main(String[] args) throws Exception { CloseableHttpClient httpClient = HttpClients.createDefault(); HttpPost httpPost = new HttpPost("http://localhost:8001/v1/models/your_model/infer"); // 假设模型接受JSON格式的输入数据 String jsonInputData = "{\"data\": [...]}"; // 替换为实际输入数据 StringEntity inputEntity = new StringEntity(jsonInputData, ContentType.APPLICATION_JSON); httpPost.setHeader("Content-Type", "application/json"); httpPost.setEntity(inputEntity); HttpResponse response = httpClient.execute(httpPost); if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) { String resultJson = EntityUtils.toString(response.getEntity()); // 解析并处理预测结果 processPredictionResult(resultJson); } httpClient.close(); } private static void processPredictionResult(String resultJson) { // 根据实际模型输出格式解析并处理结果 } }
本示例展示了通过HTTP REST API调用ONNX Runtime Server的方法,但对于模型较小且对延迟敏感的应用场景,也可以选择在Java应用内部直接加载ONNX模型并执行推理。
附加:在Java应用内部直接加载ONNX模型并执行推理
在Java应用内部直接加载ONNX模型并执行推理,可以使用ONNX Runtime for Java库来实现。
以下是一个简化的示例,在Java应用中加载ONNX模型并进行推理:
-
添加ONNX Runtime for Java依赖:
在Maven项目中,你需要在pom.xml
文件中添加ONNX Runtime的依赖项:<dependency> <groupId>com.microsoft.onnxruntime</groupId> <artifactId>onnxruntime</artifactId> <!-- 根据最新版本号替换 --> <version>1.12.0</version> </dependency>
-
加载ONNX模型:
下面是一个基本示例,展示了如何加载ONNX模型并执行推理:import com.microsoft.onnxruntime.OnnxRuntime; import com.microsoft.onnxruntime.OrtEnvironment; import com.microsoft.onnxruntime.SessionOptions; import com.microsoft.onnxruntime.TensorInfo; import com.microsoft.onnxruntime.capi.OnnxTensor; import com.microsoft.onnxruntime.exceptions.OnnxRuntimeException; import com.microsoft.onnxruntime.Session; import java.nio.FloatBuffer; import java.nio.IntBuffer; public class OnnxInferenceExample { public static void main(String[] args) { try (OrtEnvironment env = OrtEnvironment.getEnvironment()) { // 初始化SessionOptions SessionOptions sessionOptions = new SessionOptions(); // 加载ONNX模型 String modelPath = "path_to_your_model.onnx"; try (Session session = env.createSession(modelPath, sessionOptions)) { // 获取模型输入和输出的信息 TensorInfo[] inputInfos = session.getInputTypeInfo(); TensorInfo[] outputInfos = session.getOutputTypeInfo(); // 假设模型有一个名为"data"的输入,其类型为float,维度为[1, 224, 224, 3] int[] inputDims = new int[]{1, 224, 224, 3}; FloatBuffer inputData = FloatBuffer.allocate(1 * 224 * 224 * 3); // 填充真实输入数据 // 创建输入Tensor OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData, inputDims); // 执行模型推理 OnnxTensor[] outputs = session.run(new OnnxTensor[]{inputTensor}); // 获取第一个输出结果 float[] predictionArray = outputs[0].getValue().asFloatBuffer().array(); // 进一步处理预测结果 processPredictionResults(predictionArray); // 清理资源 inputTensor.close(); for (OnnxTensor output : outputs) { output.close(); } } } catch (OnnxRuntimeException e) { System.out.println("Error occurred while loading or running the model: " + e.getMessage()); } } private static void processPredictionResults(float[] predictionArray) { // 在这里处理预测结果 } }
在这个示例中,首先初始化ONNX Runtime环境并创建Session
对象来加载ONNX模型。接着,创建一个输入张量并填充数据,然后通过调用session.run()
方法执行推理。推理完成后,从输出张量中提取预测结果并进行处理。
请注意,在实际使用时需要根据模型输入和输出的具体类型和维度调整上述代码。同时,需要确保ONNX模型的路径正确,并根据模型的实际结构填充正确的输入数据。