在本教程中,我们将使用Flask来部署PyTorch模型,并用讲解用于模型推断的 REST API。特别是,我们将部署一个预训练的DenseNet 121模
型来检测图像。
备注:
可在GitHub上获取本文用到的完整代码
这是在生产中部署PyTorch模型的系列教程中的第一篇。到目前为止,以这种方式使用Flask是开始为PyTorch模型提供服务的最简单方法,
但不适用于具有高性能要求的用例。因此:
- 如果您已经熟悉TorchScript,则可以直接进入我们的Loading a TorchScript Model in C 教程。
- 如果您首先需要复习TorchScript,请查看我们的Intro a TorchScript教程。
1.定义API
我们将首先定义API端点、请求和响应类型。我们的API端点将位于/ predict
,它接受带有包含图像的file
参数的HTTP POST请求。响应
将是包含预测的JSON响应:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
2.依赖(包)
运行下面的命令来下载我们需要的依赖:
$ pip install Flask==1.0.3 torchvision-0.3.0
3.简单的Web服务器
以下是一个简单的Web服务器,摘自Flask文档
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello World!'
将以上代码段保存在名为app.py
的文件中,您现在可以通过输入以下内容来运行Flask开发服务器:
$ FLASK_ENV=development FLASK_APP=app.py flask run
当您在web浏览器中访问http://localhost:5000/
时,您会收到文本Hello World
的问候!
我们将对以上代码片段进行一些更改,以使其适合我们的API定义。首先,我们将重命名predict
方法。我们将端点路径更新为/predict
。
由于图像文件将通过HTTP POST请求发送,因此我们将对其进行更新,使其也仅接受POST请求:
@app.route('/predict', methods=['POST'])
def predict():
return 'Hello World!'
我们还将更改响应类型,以使其返回包含ImageNet类的id和name的JSON响应。更新后的app.py
文件现在为:
from flask import Flask, jsonify
app = Flask(__name__)
@app.route('/predict', methods=['POST'])
def predict():
return jsonify({
'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
4.推理
在下一部分中,我们将重点介绍编写推理代码。这将涉及两部分,第一部分是准备图像,以便可以将其馈送到DenseNet;第二部分,我们将编
写代码以从模型中获取实际的预测。
4.1 准备图像
DenseNet模型要求图像为尺寸为224 x 224的 3 通道RGB图像。我们还将使用所需的均值和标准偏差值对图像张量进行归一化。你可以点击
这里来了解更多关于它的内容。
我们将使用来自torchvision
库的transforms
来建立转换管道,该转换管道可根据需要转换图像。您可以在此处
阅读有关转换的更多信息。
import io
import torchvision.transforms as transforms
from PIL import Image
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
上面的方法以字节为单位获取图像数据,应用一系列变换并返回张量。要测试上述方法,请以字节模式读取图像文件(首先将…/_static/img/
sample_file.jpeg替换为计算机上文件的实际路径),然后查看是否获得了张量:
with open("../_static/img/sample_file.jpeg", 'rb'