想必大家都训练出过比较好玩的模型,但是是不是想要向别人提供下接口或者自己试着玩下,这时候就需要涉及到部署模型了,这里,我们将使用 Flask 部署 PyTorch 模型,并构建用于模型推理的REST API。
要注意的是:使用 Flask 是为 PyTorch 模型提供服务的最简单方法,但不适用于具有高性能要求的场景。
对高性能有要求的场景,可以使用 TorchScript,下次再说。
环境安装:
pip install Flask==1.0.3 torch==1.2.0 torchvision-0.3.0
```
假设我们的场景是上传图片进行返回图片的分类结果,那么我们定义下 API 形式,请求和响应类型。
将 API endpoint 将位于 /predict,接受带有包含图像的文件参数的 HTTP POST 请求。响应将是包含预测结果的 JSON 响应:
{"class_id": "xx", "class_name": "yy"}
首先先复习下,构建一个简单的 Web 服务器
```
from flask import Flaskapp = Flask(__name__)@app.route('/')def hello(): return 'welcome to http://towardsdeeplearning.com !'
```
运行
```
FLASK_ENV=development FLASK_APP=app.py flask run
访问 http://localhost:5000/ 可以看到 welcome to http://towardsdeeplearning.com !
可以查看 flask 文档,熟悉下 post。为了符合上边 api 的定义,我们需要修改下代码:
from flask import Flask, jsonifyapp = Flask(__name__)@app.route('/predict', methods=['POST'])def predict(): return jsonify({'class_id':'IMAGE_NET_XXX','class_name':'Cat'})
`
到此,骨干网络已经搭建完毕。
还缺少什么呢?上边这个是返回的json是写死的,但是实际上要根据 post 的图片进行预测。
图片通过 HTTP POST 请求传递过来, 可以通过下面这个方式获取
@app.route('/predict', methods=['POST'])defpredict(): if request.method =='POST': # we will get the file from the request file = request.files['file']
搭建下预测的代码,这里使用了 mnasnet ,可以在 torchvision 导入预训模型。mnasnet 的输入图片是 3 通道的 RGB 模型,大小为 224 x 224。
其实熟悉 pytorch 的同学应该很容易写出前向预测的代码的。
import ioimport torchvision.transforms as transformsfromPILimport Imagedef transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485,0.456,0.406], [0.229,0.224,0.225])]) # 接收的图片是 bytes 转成图片格式,再进行转换 image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0)from torchvision import modelsmodel = models.mnasnet1_0(pretrained=True)model.eval()def predict(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, pred = outputs.max(1) return pred
predict 的结果是类别的id,为了方便显示,我们需要进行转成文字, 就是具体的类别,狗狗啊这样人类可读性好的。
import jsonimagenet_class_index = json.load(open('imagenet_class_index.json'))def predict(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx =str(y_hat.item()) return imagenet_class_index[predicted_idx]
最后,整理的代码如下
import ioimport jsonimport torchvision.transforms as transformsfrom PIL import Imagefrom flask import Flask, jsonify, requestfrom torchvision import modelsapp = Flask(__name__)imagenet_class_index = json.load(open('./imagenet_class_index.json',"r"))model = models.mnasnet1_0(pretrained=True)model.eval()def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485,0.456,0.406], [0.229,0.224,0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx =str(y_hat.item()) return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])def predict(): if request.method =='POST': file = request.files['file'] img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id,'class_name': class_name})if __name__ =='__main__': app.run()
使用下面的命令运行
FLASK_ENV=development FLASK_APP=app.py flask run
使用下面的测试代码,进行测试。
import requestsresp = requests.post("http://localhost:5000/predict", files={"file":open('dog.jpg','rb')})print( resp.json()# {"class_id": "xx", "class_name": "xx"}
完。