基于Flask Web框架提供Pytorch 模型在线服务

本章节中,我们将使用Flask 部署一个Pytorch模型,并为模型预测提供一个REST API 接口。下面,我们部署一个预训练好的模型DenseNet 121,该模型用于检测图片

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

app = Flask(__name__)

# 加载lable和类别名称关系
imagenet_class_index = json.load(open('imagenet_class_index.json'))

图片特征提取

def transform_image(image_bytes):
    """
        DenseNet model requires the image to be of 3 channel RGB image of size 224 x 224.
        下面对原始图片的预处理

        1. transforms.Resize 改变原始图片的大小

        2.transforms.CenterCrop 生成一个CenterCrop类的对象,用来将图片从中心裁剪成224*224
        将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,
        在这种情况下,切出来的图片的形状是正方形。

        3. transforms.ToTensor 转为tensor,在GPU上运行

        4. transforms.Normalize 参数处理功能描述:图片标准化处理
        We will also normalise the image tensor with the required mean and standard deviation values.
        You can read more about it here.
    """
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])

    # 返回 PIL.Image.Image 对象
    image = Image.open(io.BytesIO(image_bytes))
    px = my_transforms(image).unsqueeze(0) # 通过unsqueeze(0) 后:torch.Size([3, 224, 224])->torch.Size([1, 3, 224, 224])

    return px

接下来,我们来提取一张图片的特征,观察下返回的tensor 的结果

with open("cat_pic.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor.shape) # torch.Size([1, 3, 224, 224])  表示: 一张图片+3个渠道+长度224+宽度224 的数组
    print(tensor)
torch.Size([1, 3, 224, 224])
tensor([[[[-0.6109, -0.5424, -0.4568,  ..., -1.6727, -1.6898, -1.7240],
          [-0.5596, -0.4397, -0.3883,  ..., -1.7240, -1.7583, -1.7754],
          [-0.5253, -0.3883, -0.3369,  ..., -1.7583, -1.7925, -1.7925],
          ...,
          [ 0.9132,  0.7591,  0.6221,  ...,  1.8722,  1.9235,  1.9749],
          [ 0.8104,  0.7077,  0.3481,  ...,  1.8550,  1.8550,  1.8722],
          [ 0.3481, -0.0116, -0.3883,  ...,  1.8722,  1.8550,  1.8379]],

         [[-0.4951, -0.4426, -0.3725,  ..., -1.2654, -1.3004, -1.3004],
          [-0.4076, -0.3375, -0.2850,  ..., -1.3354, -1.3880, -1.4055],
          [-0.3725, -0.2850, -0.2500,  ..., -1.3880, -1.4580, -1.4755],
          ...,
          [ 0.2227,  0.0126, -0.2150,  ...,  1.7283,  1.7983,  1.8859],
          [-0.0049, -0.1275, -0.4076,  ...,  1.7108,  1.6933,  1.6933],
          [-0.2500, -0.5301, -0.7752,  ...,  1.7108,  1.6933,  1.6758]],

         [[-1.0027, -0.9504, -0.8981,  ..., -1.3861, -1.4036, -1.4210],
          [-0.9156, -0.8807, -0.8284,  ..., -1.4384, -1.4384, -1.4559],
          [-0.9330, -0.8807, -0.8110,  ..., -1.4733, -1.4907, -1.5081],
          ...,
          [-0.3753, -0.5670, -0.7587,  ...,  1.5942,  1.7163,  1.9080],
          [-0.6018, -0.7238, -0.8981,  ...,  1.5420,  1.5071,  1.5245],
          [-0.7413, -0.8633, -1.0201,  ...,  1.5420,  1.4897,  1.4722]]]])

在线服务预测

from torchvision import models

# Make sure to pass `pretrained` as `True` to use the pretrained weights:
model = models.densenet121(pretrained=True)
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()

# 加载lable和类别名称关系
imagenet_class_index = json.load(open('imagenet_class_index.json'))

def get_prediction(image_bytes):
    # 数据预处理-图片特征提取
    tensor = transform_image(image_bytes=image_bytes)
    # 模型预测
    outputs = model.forward(tensor)
    # 输出结果可能性最大的一个数值
    _, y_hat = outputs.max(1)
    # tensor 转为一个数值类型数据
    predicted_idx = str(y_hat.item())
    # 获取名称和index
    return imagenet_class_index[predicted_idx]
with open("cat_pic.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))
['n02127052', 'lynx']

基于Flask 提供 API Server

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

app = Flask(__name__)

# 加载lable和类别名称关系
imagenet_class_index = json.load(open('imagenet_class_index.json'))

#服务器启动加载,
model = models.densenet121(pretrained=True)
#dropout and batch normalization layers to evaluation mode
model.eval()


def transform_image(image_bytes):
    """
        DenseNet model requires the image to be of 3 channel RGB image of size 224 x 224.
        下面对原始图片的预处理

        1. transforms.Resize 改变原始图片的大小

        2.transforms.CenterCrop 生成一个CenterCrop类的对象,用来将图片从中心裁剪成224*224
        将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,
        在这种情况下,切出来的图片的形状是正方形。

        3. transforms.ToTensor 转为tensor,在GPU上运行

        4. transforms.Normalize 参数处理功能描述:图片标准化处理
        We will also normalise the image tensor with the required mean and standard deviation values.
        You can read more about it here.
    """
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])

    # 返回 PIL.Image.Image 对象
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    # 数据预处理-图片特征提取
    tensor = transform_image(image_bytes=image_bytes)
    # 模型预测
    outputs = model.forward(tensor)
    # 输出结果可能性最大的一个数值
    _, y_hat = outputs.max(1)
    # tensor 转为一个数值类型数据
    predicted_idx = str(y_hat.item())
    # 获取名称和index
    return imagenet_class_index[predicted_idx]


@app.route('/')
def hello():
    return "Hello World!"


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':

    app.run()
 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   WARNING: Do not use the development server in a production environment.
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)

上述代码:app.py ,可以命令行执行


1. 启动服务
$ FLASK_ENV=development FLASK_APP=app.py flask run
 * Serving Flask app "app.py" (lazy loading)
 * Environment: development
 * Debug mode: on
 * Restarting with stat
 * Debugger is active!
 * Debugger PIN: 276-234-659
 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)



2. 在线预测
curl -X POST -F file=@cat_pic.jpeg http://127.0.0.1:5000/predict
{
  "class_id": "n02127052",
  "class_name": "lynx"
}

参考资料

[1] 基于Flask 部署Pytorch模型

https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html?highlight=serving

[2] Pytorch 提供基于ImageNet预训练模型

https://pytorch.org/docs/stable/torchvision/models.html

[3] Pytorch 模型保持和加载

https://pytorch.org/tutorials/beginner/saving_loading_models.html

  • 6
    点赞
  • 68
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

艾文教编程

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值