PyTorch翻译官网教程-DEPLOYING PYTORCH IN PYTHON VIA A REST API WITH FLASK

官网链接

Deploying PyTorch in Python via a REST API with Flask — PyTorch Tutorials 2.0.1+cu117 documentation

通过flask的rest API在python中部署pytorch

在本教程中,我们将使用Flask部署PyTorch模型,并开放用于模型推断的REST API。特别是,我们将部署一个预训练的DenseNet 121模型来检测图像。

这是关于在生产环境中部署PyTorch模型的系列教程中的第一篇。使用Flask这种方式是迄今为止部署PyTorch模型的最简单方法,但它不适用于具有高性能要求的用例。

API定义

我们将首先定义API 路径、请求和响应类型。我们的API路径是 /predict它接受带有包含图像的文件参数的HTTP POST请求。响应将是JSON响应,其中包含预测结果。

{"class_id": "n02124075", "class_name": "Egyptian_cat"}{"class_id": "n02124075", "class_name": "Egyptian_cat"}


依赖项

运行以下命令安装所需的依赖项:

$ pip install Flask==2.0.1 torchvision==0.10.0

简单Web服务器

下面是一个简单的web服务器,摘自Flask的文档

from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

将上面的代码片段保存在一个名为app.py的文件中,现在你可以通过输入以下命令来运行Flask开发服务器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

当您在web浏览器中访问http://localhost:5000/时,您将看到Hello World!文本

我们将对上面的代码片段做一些修改,使其适合我们的API定义。首先,我们将把方法重命名为predict。我们将把请求路径更新为/predict。由于图像文件将通过HTTP POST请求发送,我们将更新它,使其也只接受POST请求。

@app.route('/predict', methods=['POST'])
def predict():
    return 'Hello World!'

我们还将更改响应类型,以便它返回一个包含ImageNet类id和名称的JSON响应。更新后的app.py文件现在将是:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理

在下一节中,我们将重点讨论如何编写推理代码。这将涉及两个部分,一个是我们准备图像,以便它可以馈送到DenseNet,接下来,我们将编写代码以从模型中获得实际预测。

准备图像

DenseNet模型要求图像为3通道RGB图像,大小为224 x 224。我们还将用所需的均值和标准差值对图像张量进行归一化。你可以在这里读到更多(here)。

我们将使用torchvision 库中的 transforms ,并构建一个变换管道,它可以根据要求变换我们的图像。你可以在这里阅读更多关于变换的内容(here)。

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上述方法接受字节数据的图像,应用一些列的transforms 并返回一个张量。为了测试上述方法,以字节模式读取图像文件(首先将../_static/img/sample_file.jpeg替换为计算机上文件的实际路径)并查看是否返回一个张量:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)


预测

现在将使用预训练的DenseNet 121模型来预测图像类别。我们将使用torchvision库,加载模型并获得推理结果。虽然我们将在本例中使用预训练模型,但您可以对自己的模型使用相同的方法。了解更多关于加载模型的信息(tutorial)。

from torchvision import models

# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat


张量y_hat 将包含预测类别的索引id, 然而,我们需要一个人类可读的类名。为此,我们需要一个类别id和命名的映射。下载 imagenet_class_index.json 这个文件( this file),并记住保存它的位置(或者,如果您遵循本教程中的确切步骤,将其保存在教程/_static中)。这个文件包含ImageNet类别id到ImageNet类名的映射。我们将加载这个JSON文件并获取预测类别索引的类名。

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


在使用imagenet_class_index字典之前,首先我们将把张量值转换为字符串值,因为imagenet_class_index字典中的键是字符串。我们将测试上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

你应该得到这样的返回:

['n02124075', 'Egyptian_cat']

数组中的第一项是ImageNet类别id,第二项是人类可读的名称。

注意

您是否注意到model变量不是get_prediction方法的局部变量,或者说为什么model是一个全局变量?就内存和计算而言,加载模型可能是一项昂贵的操作。如果我们在get_prediction方法中加载模型,那么每次调用该方法时都会不必要地加载模型。因为我们正在构建一个web服务器,每秒可能有数千个请求,我们不应该浪费时间为每个推理加载模型。因此,我们只将模型加载到内存中一次。在生产系统中,为了能够大规模地处理请求,必须高效地使用计算,因此通常应该在处理请求之前加载模型。


在我们的API服务器中集成模型

在最后一部分中,我们将把模型添加到Flask API服务器中。由于我们的API服务器应该接受一个图像文件,我们将更新我们的预测方法来从请求中读取文件:

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

app.py文件现在已经完成。以下是完整版本;将路径替换为您保存文件的路径,它应该运行:

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

让我们测试一下我们的web服务器!运行:

$ FLASK_ENV=development FLASK_APP=app.py flask run


我们可以使用requests库向我们的应用发送POST请求:

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

打印rep .json()将显示以下内容:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

下一个步骤

我们编写的服务器非常简单,可能无法完成生产应用程序所需的所有功能。所以,这里有一些你可以做的事情来让它变得更好:

  • 请求路径 /predict假定请求中总是有一个图像文件。这可能并不适用于所有请求。我们的用户可以发送带有不同参数的图像或根本不发送图像。
  • 用户也可以发送非图像类型的文件。由于我们不处理错误,这将破坏我们的服务器。显式添加异常的错误处理路径,将使我们能够更好地处理错误输入
  • 尽管该模型可以识别大量的图像类别,但它可能无法识别所有的图像。优化实现以处理模型无法识别图像中的任何内容的情况。
  • 我们以开发模式运行Flask服务器,这种模式不适合部署到生产环境中。您可以查看本教程,了解如何在生产环境中部署Flask服务器。(this tutorial
  • 您还可以通过创建带有表单的页面来添加UI,该表单接受图像并显示预测结果。请查看类似项目的演示及其源代码。(source code.
  • 在本教程中,我们只展示了如何构建一个每次可以返回单个图像预测的服务。我们可以修改我们的服务,使其能够一次返回多个图像的预测结果。此外,service-streamer库会自动将请求排队到您的服务中,并将它们采样到可以馈送到模型中的小批量中。您可以查看本教程(this tutorial.)。
  • 最后,我们鼓励您查看页面顶部链接的关于部署PyTorch模型的其他教程.

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值