# -*- encoding: utf-8 -*-
"""
@File : flask_torch.py
@Time : 2020/07/12 11:59
@Author : Johnson
@Email : 593956670@qq.com
"""
import ast
import io
import json

import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms as T
from torchvision.models import resnet50
# Initialize the Flask application.
app = flask.Flask(__name__)
# Populated by load_model() before the server starts.
model = None
# When true, the model and every input tensor are moved to CUDA.
use_gpu = True
# Map from class index to human-readable label. The file is parsed with
# ast.literal_eval rather than eval so that a tampered class.txt cannot execute
# code; it must therefore contain a plain Python literal (presumably a
# {index: "label"} dict — confirm against the shipped class.txt).
with open("class.txt", "r", encoding="utf-8") as f:
    idx2label = ast.literal_eval(f.read())
def load_model():
    """Build the classifier and publish it through the module-level ``model``.

    Uses torchvision's ResNet-50 with pretrained weights, switches it to
    evaluation mode, and moves it to CUDA when ``use_gpu`` is true.
    Replace the constructor here to serve a different network.
    """
    global model
    # ``Module.eval()`` returns the module itself, so construction and the
    # switch to inference mode can be chained into one assignment.
    model = resnet50(pretrained=True).eval()
    if use_gpu:
        # ``Module.cuda()`` moves the parameters in place.
        model.cuda()
def prepare_image(image, target_size):
    """Preprocess a PIL image into a normalized, batched tensor for the model.

    Args:
        image: a PIL ``Image``; converted to RGB first if it is in any other mode.
        target_size: size passed to ``torchvision.transforms.Resize``; a
            ``(height, width)`` tuple such as ``(224, 224)``.

    Returns:
        A float tensor with a leading batch axis of size 1, i.e. ``(1, 3, H, W)``,
        placed on the GPU when ``use_gpu`` is set.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Resize (transform classes are capitalized: Resize / ToTensor).
    image = T.Resize(target_size)(image)
    # Convert to a (C, H, W) float tensor with values in [0, 1].
    image = T.ToTensor()(image)
    # Normalize with the ImageNet per-channel mean/std used by the pretrained
    # torchvision weights.
    image = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
    # Add the batch axis: (C, H, W) -> (1, C, H, W).
    image = image[None]
    if use_gpu:
        image = image.cuda()
    # Variable(..., volatile=True) was removed in PyTorch 0.4 and no longer
    # disables autograd; callers should run inference under torch.no_grad().
    return image
@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return the top-1 prediction as JSON.

    Expects a multipart POST with the image file under the form field
    ``"image"``. The JSON body is
    ``{"success": bool, "predictions": [{"label": ..., "probability": float}]}``;
    ``"predictions"`` is present only when classification ran.
    """
    # Response payload; must be a dict (a ``{"success", False}`` literal would
    # be a set and the key assignments below would raise TypeError).
    data = {"success": False}
    # The route only accepts POST, so this guard is defensive.
    if flask.request.method == "POST":
        # Decode the uploaded bytes into a PIL image.
        raw = flask.request.files["image"].read()
        image = Image.open(io.BytesIO(raw))
        # Preprocess into a batched, normalized tensor.
        image = prepare_image(image, target_size=(224, 224))
        # Inference only: disable autograd so no graph is built per request.
        with torch.no_grad():
            preds = F.softmax(model(image), dim=1)
        # topk returns (values, indices), each shaped (batch, k).
        probs, labels = torch.topk(preds.cpu(), k=1, dim=1)
        data["predictions"] = []
        # Only the first (and only) batch row is reported.
        for prob, label in zip(probs[0], labels[0]):
            data["predictions"].append(
                {
                    "label": idx2label[int(label.item())],
                    "probability": float(prob.item()),
                }
            )
        data["success"] = True
    return flask.jsonify(data)
if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    load_model()
    # Debug mode normally also enables the Werkzeug reloader, which re-runs this
    # script in a child process and would load the model a second time
    # (doubling memory/GPU use). Keep the debugger but disable the reloader.
    app.run(debug=True, use_reloader=False)
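# 0009-flask调用pytorch模型 (0009 - Calling a PyTorch model from Flask)
# 最新推荐文章于 2023-06-18 08:30:00 发布 (page footer: "latest recommended article published 2023-06-18 08:30:00")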