0009-flask调用pytorch模型

# -*- encoding: utf-8 -*-
"""
@File    : flask_torch.py
@Time    : 2020/07/12 11:59
@Author  : Johnson
@Email   : 593956670@qq.com
"""
import io
import json
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms as T
from torchvision.models import resnet50

#初始化flask应用
app = flask.Flask(__name__)
model = None
use_gpu = True

with open("class.txt",'r') as f:
    idx2label = eval(f.read())


def load_model():
    """load the pre-trained model,you can used your model just as easy"""
    global model
    model = resnet50(pretrained=True)
    model.eval()
    if use_gpu:
        model.cuda()


def prepare_image(image,target_size):
    '''
    对图片进行预处理
    '''
    if image.mode!="RGB":
        image = image.convert("RGB")

    #resize the image
    image = T.resize(target_size)(image)
    image = T.toTensor()(image)

    #转化为Tensor格式和归一化处理
    image = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)

    #add batch_size axis
    image = image[None]
    if use_gpu:
        image = image.cuda()

    return torch.autograd.Variable(image,volatile=True)

@app.route("/predict",methods=["POST"])
def predict():
    #initialize the data dic. that will be retured from the view
    data = {"success",False}

    #ensure the image was properly uploaded to out endpoint
    if flask.request.method=="POST":
        #read the image in PIL Image
        image = flask.request.files["image"].read()
        image = Image.open(io.BytesIO(image))

        #process the image and prepare it for classification
        image = prepare_image(image,target_size=(224,224))

        #预测
        preds = F.softmax(model(image),dim=1)
        results = torch.topk(preds.cpu().data,k=1,dim=1)

        #
        data["predictions"] = list()

        # Loop over the results and add them to the list of returned predictions
        for prob, label in zip(results[0][0], results[1][0]):
            prob = float(prob.item())
            label = int(label.item())
            label_name = idx2label[label]
            r = {"label": label_name, "probability": float(prob)}
            data['predictions'].append(r)

        # Indicate that the request was a success.
        data["success"] = True

        # Return the data dictionary as a JSON response.
    return flask.jsonify(data)

if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    load_model()
    app.run(debug=True)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值