【pytorch】将训练好的模型部署至生产环境:借助Flask框架及gunicorn(含tensor与json的转化)

89 篇文章 11 订阅
65 篇文章 3 订阅

(一)待训练模型采用
CIFAR10,10分类
按上述源码训练后得到模型参数文件:saveTextOnlyParams.pth
下面是model.py的源代码:

import os.path
from typing import Iterator
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset,DataLoader,Subset,random_split
import re
from functools import reduce
from torch.utils.tensorboard import SummaryWriter as Writer
from torchvision import transforms,datasets
import torchvision as tv
from torch import nn
import torch.nn.functional as F
import time

#查看命令:tensorboard --logdir=./myBorderText
#可用pycharm中code中的generater功能实现:
def rgb2NetInput(input):
    '''这里将读入的原始图片width, height,3转为torch.Size([BATCH, 3, 32, 32])的格式'''
    inputX=torch.FloatTensor(input)
    inputX=inputX.permute(2,0,1).contiguous()
    inputX=inputX.unsqueeze(0)
    return inputX
class myCustomerNetWork(nn.Module):
    def __init__(self):
        super().__init__()
        #输入3通道输出6通道:
        self.features=nn.Sequential(nn.Conv2d(3, 64, (3, 3)),nn.ReLU(),nn.Conv2d(64,128,(3,3)),
                                    nn.ReLU(),nn.Conv2d(128,256,(3,3)),nn.ReLU(),nn.AdaptiveAvgPool2d(1))

        self.classfired=nn.Sequential(nn.Flatten(),nn.Linear(256,80),nn.Dropout(),nn.Linear(80,10))

    def forward(self,x):
        return self.classfired(self.features(x))
#网络输入要求为torch.Size([32, 3, 32, 32])格式

myNet=myCustomerNetWork()
pthfile = r'D:\flask_pytorch\saveTextOnlyParams.pth'
myNet.load_state_dict(torch.load(pthfile))
if torch.cuda.is_available():
    myNet=myNet.cuda()
myNet.eval()

def run(input):
    with torch.no_grad():
        return myNet(rgb2NetInput(input).cuda())

#测试用代码:
if __name__ == '__main__':
    imagePath=r"C:\Users\25360\Desktop\monodepth.jpeg"
    img = cv2.imdecode(np.fromfile(imagePath, np.uint8), -1)
    img=cv2.resize(img,(32,32))
    # bgr转rgb
    img = img[:, :, ::-1].copy()
    print(img.shape)
    print(run(img))

测试输出:

(32, 32, 3)
tensor([[  71.6224,   10.6505,  165.0648,  313.5768, -148.1144,  329.7959,
          109.9136, -266.1085, -171.0974, -272.6216]], device='cuda:0')

(二)实现思路为:
使用Flask+Gunicorn ,注意Gunicorn只能在linux下使用。
Flask 内置 WebServer + Flask App = 弱鸡版本的 Server, 单进程(单 worker) / 失败挂掉 / 不易 Scale
Gunicorn + Flask App = 多进程(多 worker) / 多线程 / 失败自动帮你重启 Worker / 可简单Scale
flask自带的服务器,不是生产级别的服务器。
如有需要,还可扩展Nginx:
多 Nginx + 多 Gunicorn + Flask App = 小型多实例 Web 应用在这里插入图片描述
app.py中的内容为:

from flask import Flask
from flask import request,jsonify,make_response
import numpy as np
from json import dumps
app = Flask(__name__)
import model
@app.route('/pic',methods={'POST'})
def predictPic():
    json = request.get_json()
    input=json.get('content')
    #list重新转为numpy的arrary:
    img = np.asarray(input)
    resultTensor=model.run(input)
    #下面将结果的tensor以json格式返回;
    result = resultTensor.tolist()
    # 字典形式保存数组
    result_dict = {}
    result_dict['result'] = result
    # 保存为json格式:
    json_data = jsonify(result_dict)
    return make_response(result_dict, 200)
if __name__ == '__main__':
    app.run()

编写程序读取程序,模拟发送post请求:

import os.path
from typing import Iterator
import numpy as np
import torch
import cv2
from PIL import Image
from torch.utils.data import Dataset,DataLoader,Subset,random_split
import re
from functools import reduce
from torch.utils.tensorboard import SummaryWriter as Writer
from torchvision import transforms,datasets
import torchvision as tv
from torch import nn
import torch.nn.functional as F
import time
from json import dumps
import requests

# 通过opencv读取图片
imagePath = r"C:\Users\25360\Desktop\monodepth.jpeg"
img = cv2.imdecode(np.fromfile(imagePath, np.uint8), -1)
img = cv2.resize(img, (32, 32))
# bgr转rgb
img = img[:, :, ::-1].copy()
# numpy中ndarray文件转为list
assert img is not None, "image did not read success!"
#将size(32,32,3)转换为同形list
img_list = img.tolist()
# 字典形式保存数组
img_dict = {}
img_dict['content'] = img_list
# 保存为json格式
json_data = dumps(img_dict, indent=2)
#使用post发生json:
url = r"http://127.0.0.1:5000/pic"
aheaders = {'Content-Type': 'application/json'}
res = requests.post(url=url,headers=aheaders,data=json_data)
print(res.json())

返回结果为:

{'result': [[71.62239837646484, 10.650474548339844, 165.06483459472656, 313.5767517089844, -148.11439514160156, 329.7958984375, 109.9135513305664, -266.1084899902344, -171.09744262695312, -272.6215515136719]]}

截止到现在,是使用flask自带web服务器完成部署,下面引入gunicorn:

gunicorn可以通过gunicorn -w 4 -b 127.0.0.1:5000 app:app启动一个Flask应用。其中,

-w 4是指预定义的工作进程数为4-b 127.0.0.1:5000指绑定地址和端口
run是flask的启动python文件,app则是flask应用程序实例

注:第一个app为flask项目实例所在的包(app.py),第二个app为生成的flask项目实例(app = Flask(name))

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
训练好的PyTorch模型部署到Django应用中需要以下步骤: 1. 在Django应用中创建一个view(视图),该视图将接收请求并返回模型的预测结果。 2. 加载训练好的PyTorch模型。在这个过程中,需要确保模型的权重文件和模型文件都被正确加载。 3. 处理请求数据。在这个过程中,需要将请求数据与模型期望的数据格式进行匹配。可以使用PyTorch的Transforms和Datasets功能来实现这个过程。 4. 运行模型并获取预测结果。在这个过程中,需要将请求数据传递给模型并获取预测结果。可以使用PyTorch的forward方法来实现这个过程。 5. 返回预测结果。在这个过程中,需要将预测结果格式化为JSON响应,并将其返回给请求方。 以下是一个简单的Django视图,用于加载并使用PyTorch模型进行预测: ```python import torch import torchvision.transforms as transforms from django.http import JsonResponse from django.views.decorators.csrf import csrf_exempt @csrf_exempt def predict(request): if request.method == 'POST': # 加载模型 model = torch.load('model.pth') model.eval() # 处理请求数据 image = request.FILES.get('image') image_tensor = transforms.ToTensor()(image).unsqueeze_(0) # 运行模型并获取预测结果 output = model(image_tensor) _, predicted = torch.max(output.data, 1) prediction = predicted.item() # 返回预测结果 return JsonResponse({'prediction': prediction}) ``` 在这个例子中,我们假设模型文件为'model.pth',请求数据包一个名为'image'的文件。我们首先加载模型,然后使用PyTorch的transforms将请求数据转换为模型期望的格式。接下来,我们将数据传递给模型并获取预测结果,最后将结果格式化为JSON响应并返回。需要注意的是,我们使用了Django的csrf_exempt装饰器来禁用CSRF保护,以便我们可以在没有CSRF令牌的情况下测试视图。在生产环境中,应该启用CSRF保护来确保应用程序的安全性。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

颢师傅

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值