flask post json_想要部署深度学习模型?试试 FLASK 构建 REST API 部署

想必大家都训练出过比较好玩的模型,但是是不是想要向别人提供下接口或者自己试着玩下,这时候就需要涉及到部署模型了,这里,我们将使用 Flask 部署 PyTorch 模型,并构建用于模型推理的REST API。

89592be9792514ea67ae3ba5fdc5e5d4.png

要注意的是:使用 Flask 是为 PyTorch 模型提供服务的最简单方法,但不适用于具有高性能要求的场景。

对高性能有要求的场景,可以使用 TorchScript,下次再说。

环境安装:

pip install Flask==1.0.3 torch==1.2.0 torchvision-0.3.0

```

假设我们的场景是上传图片进行返回图片的分类结果,那么我们定义下 API 形式,请求和响应类型。

将 API endpoint 将位于 /predict,接受带有包含图像的文件参数的 HTTP POST 请求。响应将是包含预测结果的 JSON 响应:

{"class_id": "xx", "class_name": "yy"}

首先先复习下,构建一个简单的 Web 服务器

```

from flask import Flaskapp = Flask(__name__)@app.route('/')def hello():   return 'welcome to  http://towardsdeeplearning.com !'

```

运行

```

FLASK_ENV=development FLASK_APP=app.py flask run

访问 http://localhost:5000/ 可以看到 welcome to  http://towardsdeeplearning.com !

可以查看 flask 文档,熟悉下 post。为了符合上边 api 的定义,我们需要修改下代码:

from flask import Flask, jsonifyapp = Flask(__name__)@app.route('/predict', methods=['POST'])def predict():   return jsonify({'class_id':'IMAGE_NET_XXX','class_name':'Cat'})

`

到此,骨干网络已经搭建完毕。

还缺少什么呢?上边这个是返回的json是写死的,但是实际上要根据 post 的图片进行预测。

图片通过 HTTP POST 请求传递过来, 可以通过下面这个方式获取

@app.route('/predict', methods=['POST'])defpredict():   if request.method =='POST':       # we will get the file from the request       file = request.files['file']

搭建下预测的代码,这里使用了 mnasnet ,可以在 torchvision 导入预训模型。mnasnet 的输入图片是 3 通道的 RGB 模型,大小为 224 x 224。

其实熟悉 pytorch 的同学应该很容易写出前向预测的代码的。

import ioimport torchvision.transforms as transformsfromPILimport Imagedef transform_image(image_bytes):   my_transforms = transforms.Compose([transforms.Resize(255),                                       transforms.CenterCrop(224),                                       transforms.ToTensor(),                                       transforms.Normalize(                                           [0.485,0.456,0.406],                                           [0.229,0.224,0.225])])   # 接收的图片是 bytes 转成图片格式,再进行转换    image = Image.open(io.BytesIO(image_bytes))    return my_transforms(image).unsqueeze(0)from torchvision import modelsmodel = models.mnasnet1_0(pretrained=True)model.eval()def predict(image_bytes):   tensor = transform_image(image_bytes=image_bytes)   outputs = model.forward(tensor)   _, pred = outputs.max(1)   return pred

predict 的结果是类别的id,为了方便显示,我们需要进行转成文字, 就是具体的类别,狗狗啊这样人类可读性好的。

import jsonimagenet_class_index = json.load(open('imagenet_class_index.json'))def predict(image_bytes):   tensor = transform_image(image_bytes=image_bytes)   outputs = model.forward(tensor)   _, y_hat = outputs.max(1)   predicted_idx =str(y_hat.item())   return imagenet_class_index[predicted_idx]

最后,整理的代码如下

import ioimport jsonimport torchvision.transforms as transformsfrom PIL import Imagefrom flask import Flask, jsonify, requestfrom torchvision import modelsapp = Flask(__name__)imagenet_class_index = json.load(open('./imagenet_class_index.json',"r"))model = models.mnasnet1_0(pretrained=True)model.eval()def transform_image(image_bytes):   my_transforms = transforms.Compose([transforms.Resize(255),                                       transforms.CenterCrop(224),                                       transforms.ToTensor(),                                       transforms.Normalize(                                           [0.485,0.456,0.406],                                           [0.229,0.224,0.225])])   image = Image.open(io.BytesIO(image_bytes))   return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes):   tensor = transform_image(image_bytes=image_bytes)   outputs = model.forward(tensor)   _, y_hat = outputs.max(1)   predicted_idx =str(y_hat.item())   return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])def predict():   if request.method =='POST':       file = request.files['file']       img_bytes = file.read()       class_id, class_name = get_prediction(image_bytes=img_bytes)       return jsonify({'class_id': class_id,'class_name': class_name})if __name__ =='__main__':   app.run()

使用下面的命令运行

FLASK_ENV=development FLASK_APP=app.py flask run

使用下面的测试代码,进行测试。

import requestsresp = requests.post("http://localhost:5000/predict",                    files={"file":open('dog.jpg','rb')})print( resp.json()# {"class_id": "xx", "class_name": "xx"}

完。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值