使用Flask通过REST API部署PyTorch模型并实现Java Web程序调用


使用Flask通过REST API部署预训练的PyTorch模型并实现Java Web程序调用


一、准备工作

在Python环境中,我们首先需要安装必要的库Flask和torchvision,用于部署预训练的PyTorch模型:

pip install Flask==2.0.1 torchvision==0.10.0

二、模型加载与部署

  1. 加载自定义预训练模型

假设您有一个自定义训练好的PyTorch模型,保存在.pth文件中。首先加载模型并将其置于评估模式:

import torch
from your_model_module import CustomModel  # 自定义模型类

# 加载模型权重
model_path = '<PATH_TO_YOUR_PRETRAINED_MODEL.pth>'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

custom_model = CustomModel()
custom_model.load_state_dict(torch.load(model_path, map_location=device))
custom_model.to(device)
custom_model.eval()
  1. 准备类别ID到名称的映射

假设您有一个单独的映射文件,用来存储模型预测结果的类别ID与对应名称之间的关系:

def load_class_mapping(mapping_filepath):
    # 根据实际映射文件格式实现加载函数
    pass

class_id_to_name_mapping = load_your_class_mapping('<PATH_TO_CLASS_MAPPING>')

load_class_mapping 函数的实现取决于映射文件的具体格式。
这里给出两种常见的映射文件格式(JSON和CSV)的加载示例:

  • JSON格式的映射文件加载示例:
import json

def load_class_mapping(json_mapping_filepath):
    with open(json_mapping_filepath, 'r') as f:
        class_id_to_name_mapping = json.load(f)
    # 假设JSON文件是如下格式:
    # {"0": "类别A", "1": "类别B", "2": "类别C"}
    # 此处无需进一步处理,因为json.load已经将其转换为字典
    return class_id_to_name_mapping

  • CSV格式的映射文件加载示例:
import csv
def load_class_mapping(csv_mapping_filepath):
    class_id_to_name_mapping = {}
    with open(csv_mapping_filepath, 'r') as f:
        reader = csv.reader(f)
        next(reader)  # 如果存在表头,则跳过
        for row in reader:
            class_id_to_name_mapping[row[0]] = row[1]  # 假设第一列是ID,第二列是名称
    return class_id_to_name_mapping

在这两个示例中,load_class_mapping 函数都接收一个映射文件路径作为参数,并返回一个字典,其中键是类别ID,值是类别名称。需要根据实际情况调整这些函数以适应您的映射文件结构。例如,如果CSV文件具有不同的列顺序或命名,或者JSON文件的键值对结构不同,则需要相应地调整代码。

  1. 定义图像预处理与预测函数
import io
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

def preprocess_image(image_bytes):
    # 根据自定义模型的输入要求定义预处理管道
    preprocessing_pipeline = Compose([...])
    pil_image = Image.open(io.BytesIO(image_bytes))
    preprocessed_image = preprocessing_pipeline(pil_image)
    return preprocessed_image.unsqueeze(0).to(device)

def get_custom_prediction(image_bytes):
    input_tensor = preprocess_image(image_bytes)
    with torch.no_grad():
        outputs = custom_model(input_tensor)
        # 根据模型输出特性获取预测类别ID
        _, predicted_class_id = torch.topk(outputs, k=1)  # 或者使用其他方法提取预测ID
        class_name = class_id_to_name_mapping[str(predicted_class_id.item())]
    return predicted_class_id.item(), class_name
  1. 在Flask中实现API端点
from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_custom_prediction(img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
Python部分代码汇总

以下是上述Python部分的关键代码片段的汇总,它们涵盖了从加载预训练模型、部署REST API到加载类别映射表:

# 引入必要库
from flask import Flask, jsonify, request
import torch
from torchvision import models
import io
import PIL.Image
import json
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

# 加载预训练模型
model_path = '<PATH_TO_YOUR_PRETRAINED_MODEL.pth>'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = YourCustomModel()  # 替换成你的自定义模型类
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# 加载类别ID到名称的映射表
def load_class_mapping(mapping_filepath):
    with open(mapping_filepath, 'r') as f:
        class_id_to_name = json.load(f)
    return class_id_to_name

class_id_to_name_mapping = load_class_mapping('<PATH_TO_CLASS_MAPPING.json>')

# 图像预处理
transform = Compose([
    Resize((224, 224)),  # 根据你的模型输入调整大小
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 参考ImageNet的标准化参数
])

# 预测函数
def predict(image_bytes):
    img = PIL.Image.open(io.BytesIO(image_bytes)).convert('RGB')
    input_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        _, prediction = torch.max(output, dim=1)
        class_id = prediction.item()
        class_name = class_id_to_name_mapping[str(class_id)]
    return class_id, class_name

# 创建Flask应用并定义路由
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def classify_image():
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    if not file:
        return jsonify({"error": "No selected file"}), 400
    img_bytes = file.read()
    class_id, class_name = predict(img_bytes)
    return jsonify({"class_id": class_id, "class_name": class_name})

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0')

三、Java Web程序调用API

在Java Web应用中,使用java.net.HttpURLConnection或其他HTTP客户端库(如Apache HttpClient)发起POST请求,将图像文件发送至Flask API并获取预测结果:

import java.io.*;
import java.net.HttpURLConnection;
import java.net.URL;
import org.json.JSONObject;

public class ImagePredictor {
    public static void main(String[] args) throws Exception {
        String imageUrl = "<PATH_TO_LOCAL_IMAGE>";
        String apiUrl = "http://localhost:5000/predict";

        try (FileInputStream fis = new FileInputStream(new File(imageUrl))) {
            HttpURLConnection connection = (HttpURLConnection) new URL(apiUrl).openConnection();
            connection.setRequestMethod("POST");
            connection.setRequestProperty("Content-Type", "multipart/form-data");

            // 将图像文件添加到请求体
            connection.setDoOutput(true);
            OutputStream os = connection.getOutputStream();
            byte[] imageData = new byte[(int) new File(imageUrl).length()];
            fis.read(imageData);
            os.write(imageData);
            os.flush();
            os.close();

            // 解析API响应
            BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream()));
            String inputLine;
            StringBuilder content = new StringBuilder();
            while ((inputLine = in.readLine()) != null) {
                content.append(inputLine);
            }
            in.close();

            JSONObject responseJson = new JSONObject(content.toString());
            int classId = responseJson.getInt("class_id");
            String className = responseJson.getString("class_name");

            System.out.println("Class ID: " + classId);
            System.out.println("Class Name: " + className);
        }
    }
}

通过以上步骤部署了一个自定义预训练的PyTorch模型,并且能够通过Java Web程序调用此模型提供的REST API来获取预测结果。请注意根据不同模型的输入要求调整预处理函数,并确保Java Web应用中的请求发送与接收与Flask API规范保持一致。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值