通过带Flask的REST API在Python中部署PyTorch

在本教程中,我们将使用Flask来部署PyTorch模型,并用讲解用于模型推断的 REST API。特别是,我们将部署一个预训练的DenseNet 121模
型来检测图像。

备注:
可在GitHub上获取本文用到的完整代码

这是在生产中部署PyTorch模型的系列教程中的第一篇。到目前为止,以这种方式使用Flask是开始为PyTorch模型提供服务的最简单方法,
但不适用于具有高性能要求的用例。因此:

1.定义API

我们将首先定义API端点、请求和响应类型。我们的API端点将位于/ predict,它接受带有包含图像的file参数的HTTP POST请求。响应
将是包含预测的JSON响应:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

2.依赖(包)

运行下面的命令来下载我们需要的依赖:

$ pip install Flask==1.0.3 torchvision-0.3.0

3.简单的Web服务器

以下是一个简单的Web服务器,摘自Flask文档

from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

将以上代码段保存在名为app.py的文件中,您现在可以通过输入以下内容来运行Flask开发服务器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

当您在web浏览器中访问http://localhost:5000/时,您会收到文本Hello World的问候!

我们将对以上代码片段进行一些更改,以使其适合我们的API定义。首先,我们将重命名predict方法。我们将端点路径更新为/predict
由于图像文件将通过HTTP POST请求发送,因此我们将对其进行更新,使其也仅接受POST请求:

@app.route('/predict', methods=['POST'])
def predict():
    return 'Hello World!'

我们还将更改响应类型,以使其返回包含ImageNet类的id和name的JSON响应。更新后的app.py文件现在为:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({
   'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

4.推理

在下一部分中,我们将重点介绍编写推理代码。这将涉及两部分,第一部分是准备图像,以便可以将其馈送到DenseNet;第二部分,我们将编
写代码以从模型中获取实际的预测。

4.1 准备图像

DenseNet模型要求图像为尺寸为224 x 224的 3 通道RGB图像。我们还将使用所需的均值和标准偏差值对图像张量进行归一化。你可以点击
这里来了解更多关于它的内容。

我们将使用来自torchvision库的transforms来建立转换管道,该转换管道可根据需要转换图像。您可以在此处
阅读有关转换的更多信息。

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上面的方法以字节为单位获取图像数据,应用一系列变换并返回张量。要测试上述方法,请以字节模式读取图像文件(首先将…/_static/img/
sample_file.jpeg替换为计算机上文件的实际路径),然后查看是否获得了张量:

with open("../_static/img/sample_file.jpeg", 'rb'
  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值