使用 Flask 快速部署 PyTorch 模型

对于数据科学项目来说,我们一直都很关注模型的训练和表现,但是在实际工作中如何启动和运行我们的模型是模型上线的最后一步也是最重要的工作。

今天我将通过一个简单的案例:部署一个PyTorch图像分类模型,介绍这个最重要的步骤。

在这里插入图片描述

我们这里使用PyTorch和Flask。可以使用pip install torch和pip install flask安装这些包。

web应用

为Flask创建一个文件app.py和一个路由:

from flask import Flask
import torch


app = Flask(__name__)

@app.route('/')
def home():
    return 'Welcome to the PyTorch Flask app!'

现在我们可以运行python app.py,如果没有问题,你可以访问http://localhost:5000/,应该会看到一条简单的消息——“Welcome to the PyTorch Flask app!”

这就说明我们flask的web服务已经可以工作了,现在让我们添加一些代码,将数据传递给我们的模型!

源码分享&技术交流

技术要学会分享、交流,不建议闭门造车。 本文技术由粉丝群小伙伴分享汇总。源码、数据、技术交流提升,均可加交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、添加微信号:dkl88191,备注:来自CSDN +学管系统
方式②、微信搜索公众号:Python学习与数据挖掘,后台回复:学管系统

添加更多的导入

 from flask import Flask, request, render_template  
 from PIL import Image  
 import torch  
 import torchvision.transforms as transforms

然后再将主页的内容换成一个HTML页面

 @app.route('/')  
 def home():  
     return render\_template('home.html')

创建一个templates文件夹,然后创建home.html。

 <html>
  <head>
    <title>PyTorch Image Classification</title>
  </head>
  <body>
    <h1>PyTorch Image Classification</h1>
    <form method="POST" enctype="multipart/form-data" action="/predict">
      <input type="file" name="image">
      <input type="submit" value="Predict">
    </form>
  </body>
</html>

HTML非常简单——有一个上传按钮,可以上传我们想要运行模型的任何数据(在我们的例子中是图像)。

以上都是基本的web应用的内容,下面就是要将这个web应用和我们的pytorch模型的推理结合。

加载模型

在home route上面,加载我们的模型。

 model = torch.jit.load('path/to/model.pth')

我们都知道,模型的输入是张量,所以对于图片来说,我们需要将其转换为张量、还要进行例如调整大小或其他形式的预处理(这与训练时的处理一样)。

我们处理的是图像,所以预处理很简单

def process_image(image):
    # Preprocess image for model
    transformation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = transformation(image).unsqueeze(0)
    
    return image_tensor

我们还需要一个数组来表示类,本文只有2类

class_names = ['apple', 'banana'] #REPLACE THIS WITH YOUR CLASSES

预测

下一步就是创建一个路由,接收上传的图像,处理并使用模型进行预测,并返回每个类的概率。

 @app.route('/predict', methods=['POST'])
def predict():
    # Get uploaded image file
    image = request.files['image']

    # Process image and make prediction
    image_tensor = process_image(Image.open(image))
    output = model(image_tensor)

    # Get class probabilities
    probabilities = torch.nn.functional.softmax(output, dim=1)
    probabilities = probabilities.detach().numpy()[0]

    # Get the index of the highest probability
    class_index = probabilities.argmax()

    # Get the predicted class and probability
    predicted_class = class_names[class_index]
    probability = probabilities[class_index]

    # Sort class probabilities in descending order
    class_probs = list(zip(class_names, probabilities))
    class_probs.sort(key=lambda x: x[1], reverse=True)

    # Render HTML page with prediction results
    return render_template('predict.html', class_probs=class_probs,
                           predicted_class=predicted_class, probability=probability)

我们的/predict路由首先使用softmax函数获得类概率,然后获得最高概率的索引。它使用这个索引在类名列表中查找预测的类,并获得该类的概率。然后按降序对类别概率进行排序,并返回预测结果。

最后,我们的app.py文件应该是这样的:

from flask import Flask, request, render_template
from PIL import Image
import torch
import torchvision.transforms as transforms


model = torch.jit.load('path/to/model.pth')

@app.route('/')
def home():
    return render_template('home.html')

def process_image(image):
    # Preprocess image for model
    transformation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = transformation(image).unsqueeze(0)
    
    return image_tensor


class_names = ['apple', 'banana'] #REPLACE THIS WITH YOUR CLASSES

@app.route('/predict', methods=['POST'])
def predict():
    # Get uploaded image file
    image = request.files['image']

    # Process image and make prediction
    image_tensor = process_image(Image.open(image))
    output = model(image_tensor)

    # Get class probabilities
    probabilities = torch.nn.functional.softmax(output, dim=1)
    probabilities = probabilities.detach().numpy()[0]

    # Get the index of the highest probability
    class_index = probabilities.argmax()

    # Get the predicted class and probability
    predicted_class = class_names[class_index]
    probability = probabilities[class_index]

    # Sort class probabilities in descending order
    class_probs = list(zip(class_names, probabilities))
    class_probs.sort(key=lambda x: x[1], reverse=True)

    # Render HTML page with prediction results
    return render_template('predict.html', class_probs=class_probs,
                           predicted_class=predicted_class, probability=probability)

最后一个部分是实现predict.html模板,在templates目录创建一个名为predict.html的文件:

 <html>
  <head>
    <title>Prediction Results</title>
  </head>
  <body>
    <h1>Prediction Results</h1>
    <p>Predicted Class: {{ predicted_class }}</p>
    <p>Probability: {{ probability }}</p>
    <h2>Other Classes</h2>
    <ul>
      {% for class_name, prob in class_probs %}
        <li>{{ class_name }}: {{ prob }}</li>
      {% endfor %}
    </ul>
  </body>
</html>

这个HTML页面显示了预测的类别和概率,以及按概率降序排列的其他类别列表。

测试

使用python app.py运行服务,然后首页会显示我们创建的上传图片的按钮,可以通过按钮上传图片进行测试,这里我们还可以通过编程方式发送POST请求来测试您的模型。

下面就是发送POST请求的Python代码

import requests

# Set URL for Flask app
url = 'http://localhost:5000/predict'

# Set image file path
image_path = 'path/to/image.jpg'

# Read image file and set as payload
image = open(image_path, 'rb')
payload = {'image': image}

# Send POST request with image and get response
response = requests.post(url, headers=headers, data=payload)

print(response.text)

这段代码将向Flask应用程序发送一个POST请求,上传指定的图像文件。我们创建的Flask应用程会处理图像,做出预测并返回响应,最后响应将打印到控制台。

就是这样只要5分钟,我们就可以成功地部署一个ML模型。

  • 1
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
要将FlaskPyTorch目标检测模型部署为Web服务,可以按照以下步骤进行操作: 1. 准备环境:确保已安装FlaskPyTorch和其他必要的依赖库。 2. 构建Web应用:创建Python脚本或包含Flask应用的目录结构。在Flask应用中,定义一个路由(route)用于接收图像文件或URL,并将其传递给目标检测模型。 3. 加载模型和预处理:使用PyTorch加载预训练的目标检测模型,并进行必要的预处理操作,例如图像缩放和归一化。 4. 目标检测推理:将输入图像传递给目标检测模型进行推理。根据模型输出的结果,提取目标框的位置、类别和置信度等信息。 5. 可视化结果:根据推理结果,在原始图像上绘制检测到的目标框和类别,并将结果返回给用户。 6. 部署与测试:在本地环境中运行Flask应用,并通过浏览器或其他HTTP工具发送图像或URL请求进行测试。可以使用前端技术(如HTML、CSS和JavaScript)美化界面和实现用户交互。 7. 部署到服务器:将Flask应用部署到云服务器或虚拟机中,确保服务器具有足够的计算资源和网络带宽来支持多个并发请求。 8. 性能优化:根据实际需求,可以优化目标检测模型的推理速度,例如使用FP16精度、模型剪枝或量化等技术。 9. 安全性考虑:在处理用户上传的图像或URL时,确保实施适当的安全性措施,例如输入验证和图像过滤,以防止恶意程序或内容的传输。 通过以上步骤,就可以成功将PyTorch目标检测模型部署为一个可访问的Web服务。用户可以使用该服务上传图像或提供URL,查看模型对该图像中目标的检测结果。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值