TorchServe: https://github.com/pytorch/serve
Tutorial: https://towardsdatascience.com/deploy-models-and-create-custom-handlers-in-torchserve-fc2d048fbe91
Summary:
1. Install Java 11, torchserve, and torch-model-archiver
2. Default and custom handlers
3. Package a model into a .mar archive
4. Serve the model with Docker
densenet161 official tutorial
1. Install Java 11
sudo apt-get install openjdk-11-jdk
2. Install torchserve and torch-model-archiver
pip install torchserve torch-model-archiver
3. Create a directory for the model store:
mkdir model_store
4. Download the pretrained model:
wget https://download.pytorch.org/models/densenet161-8d451a50.pth
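Before packaging, it can be worth checking that the download is intact. The sketch below assumes the PyTorch convention that the hex suffix in a checkpoint filename (here 8d451a50) is the leading part of the file's SHA-256 digest; the helper names are illustrative and not part of any library.

```python
import hashlib
import re
from pathlib import Path


def sha256_prefix_from_name(filename: str):
    """Extract the hex suffix from names like 'densenet161-8d451a50.pth'."""
    match = re.search(r"-([0-9a-f]{8,})\.[^.]+$", filename)
    return match.group(1) if match else None


def matches_embedded_hash(path: str) -> bool:
    """Compare the file's SHA-256 digest with the prefix in its filename."""
    expected = sha256_prefix_from_name(Path(path).name)
    if expected is None:
        raise ValueError(f"no hash suffix found in {path!r}")
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read in 1 MiB chunks so large checkpoints are not loaded at once.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest().startswith(expected)
```

Calling matches_embedded_hash("densenet161-8d451a50.pth") after the download should return True if the file arrived unmodified.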
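5. Convert it to a .mar file (the paths below assume the pytorch/serve repository has been cloned into ./serve):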
torch-model-archiver --model-name densenet161 --version 1.0 --model-file ./serve/examples/image_classifier/densenet_161/model.py --serialized-file densenet161-8d451a50.pth --export-path model_store --extra-files ./serve/examples/image_classifier/index_to_name.json --handler image_classifier
6. Start the service
torchserve --start --ncs --model-store model_store --models densenet161.mar
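Once the server is started, it can be useful to confirm that it accepts requests before sending predictions. A minimal sketch, assuming the default inference address 127.0.0.1:8080 and the /ping health endpoint of the inference API, which returns a small JSON body with a status field; the helper names is_healthy and ping are illustrative, not part of TorchServe.

```python
import json
import urllib.request


def is_healthy(body: str) -> bool:
    """Return True if a /ping response body reports a healthy server."""
    try:
        payload = json.loads(body)
    except json.JSONDecodeError:
        return False
    return isinstance(payload, dict) and payload.get("status") == "Healthy"


def ping(base_url: str = "http://127.0.0.1:8080", timeout: float = 5.0) -> bool:
    """Query the /ping endpoint (requires a running server)."""
    try:
        with urllib.request.urlopen(f"{base_url}/ping", timeout=timeout) as resp:
            return is_healthy(resp.read().decode("utf-8"))
    except OSError:
        # Connection refused, timeout, or an HTTP error status.
        return False
```

A script could call ping() in a short retry loop after torchserve --start and only proceed once it returns True.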
7. Call the model
curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg
curl http://127.0.0.1:8080/predictions/densenet161 -T kitten.jpg
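The curl -T call above uploads the raw image bytes as the request body. A rough Python equivalent using only the standard library might look like the following; the function names are illustrative, and the default address is assumed to be the same 127.0.0.1:8080 used above.

```python
import urllib.request


def build_prediction_request(image_bytes: bytes, model_name: str,
                             base_url: str = "http://127.0.0.1:8080"):
    """Build (but do not send) a request equivalent to `curl -T file URL`."""
    return urllib.request.Request(
        url=f"{base_url}/predictions/{model_name}",
        data=image_bytes,  # raw image bytes as the request body
        method="PUT",      # curl -T uploads with PUT over HTTP
    )


def predict(image_path: str, model_name: str) -> str:
    """Send the image to a running server and return the response text."""
    with open(image_path, "rb") as f:
        request = build_prediction_request(f.read(), model_name)
    with urllib.request.urlopen(request, timeout=30) as resp:
        return resp.read().decode("utf-8")
```

With the server from step 6 running, predict("kitten.jpg", "densenet161") would return the JSON text that curl prints.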
Docker deployment: resnet34 tutorial
Custom handlers
MyHandler.py
import io

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from ts.torch_handler.base_handler import BaseHandler


class MyHandler(BaseHandler):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Standard ImageNet preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    # Preprocessing
    def preprocess_one_image(self, req):
        # The raw image bytes arrive under "data" or "body".
        image = req.get("data")
        if image is None:
            image = req.get("body")
        # Convert to RGB so grayscale or RGBA inputs still yield 3 channels.
        image = Image.open(io.BytesIO(image)).convert("RGB")
        image = self.transform(image)
        return image.unsqueeze(0)  # add the batch dimension

    def preprocess(self, requests):
        images = [self.preprocess_one_image(req) for req in requests]
        return torch.cat(images)

    # Inference
    def inference(self, x):
        with torch.no_grad():
            outs = self.model(x.to(self.device))
        probs = F.softmax(outs, dim=1)
        return torch.argmax(probs, dim=1)

    # Postprocessing
    def postprocess(self, preds):
        res = []
        for pred in preds.cpu().tolist():
            # self.mapping is loaded from index_to_name.json.
            label = self.mapping[str(pred)][1]
            res.append({"label": label, "index": pred})
        return res
my_handler.py
The handle function in this file calls the methods of the custom handler.
from MyHandler import MyHandler

_service = MyHandler()


def handle(data, context):
    # Load the model on the first request.
    if not _service.initialized:
        _service.initialize(context)
    if data is None:
        return None
    data = _service.preprocess(data)
    data = _service.inference(data)
    data = _service.postprocess(data)
    return data
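The postprocessing step of the handler can be tried in isolation, without TorchServe or a model. The sketch below assumes index_to_name.json maps a class index (as a string) to a two-element list whose second item is the readable label, which is what the [1] lookup in the handler relies on; the two entries shown are an illustrative excerpt, not the full file.

```python
import json

# Illustrative excerpt in the shape of index_to_name.json:
# class index (as a string) -> [WordNet id, human-readable label].
INDEX_TO_NAME = json.loads("""
{
  "281": ["n02123045", "tabby"],
  "282": ["n02123159", "tiger_cat"]
}
""")


def postprocess(preds, mapping):
    """Turn predicted class indices into one result dict per request."""
    results = []
    for pred in preds:
        # JSON object keys are strings, so the integer index is converted;
        # element [1] of each entry is the readable label.
        label = mapping[str(pred)][1]
        results.append({"label": label, "index": pred})
    return results


print(postprocess([282], INDEX_TO_NAME))
# prints [{'label': 'tiger_cat', 'index': 282}]
```

Returning one dict per input matters because TorchServe expects the handler's output list to have the same length as the incoming batch.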
Model export
resnet34
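import torch
from torchvision.models import resnet34

model = resnet34(pretrained=True)
model.eval()  # switch to inference mode before tracing

# Trace with a dummy input of the shape the handler produces: (N, 3, 224, 224)
example_input = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
traced_model.save("resnet34.pt")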
Create the .mar file
1. Install torch-model-archiver
git clone https://github.com/pytorch/serve.git
cd serve/model-archiver
pip install .
2. Create the .mar
torch-model-archiver --model-name resnet34 \
    --version 1.0 \
    --serialized-file resnet34.pt \
    --extra-files ./index_to_name.json,./MyHandler.py \
    --handler my_handler.py \
    --export-path model-store -f
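--model-name sets the final name of the model
--serialized-file points to the saved .pt model created above
--handler is the Python file that calls our custom handler
--export-path is the directory where the .mar file is written
-f overwrites an existing file
--extra-files passes the index_to_name.json file (along with MyHandler.py); index_to_name.json is loaded automatically by the handler and is accessed through self.mapping to return the label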
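When the archive is rebuilt often, the same command can be assembled from a script. The following is a sketch that only builds the argument list from the flags described above and prints it; the function name build_archiver_command is hypothetical, and nothing is executed.

```python
import shlex


def build_archiver_command(model_name, version, serialized_file,
                           handler, export_path, extra_files=(), force=False):
    """Assemble a torch-model-archiver argument list (nothing is executed)."""
    cmd = [
        "torch-model-archiver",
        "--model-name", model_name,
        "--version", version,
        "--serialized-file", serialized_file,
        "--handler", handler,
        "--export-path", export_path,
    ]
    if extra_files:
        # The CLI takes extra files as one comma-separated value.
        cmd += ["--extra-files", ",".join(extra_files)]
    if force:
        cmd.append("-f")  # overwrite an existing archive
    return cmd


cmd = build_archiver_command(
    model_name="resnet34",
    version="1.0",
    serialized_file="resnet34.pt",
    handler="my_handler.py",
    export_path="model-store",
    extra_files=["./index_to_name.json", "./MyHandler.py"],
    force=True,
)
print(shlex.join(cmd))
```

The list can be passed to subprocess.run(cmd, check=True) once torch-model-archiver is installed and the model-store directory exists.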
Model deployment
serve-resnet34.sh
docker run -d --rm -it -p 4000:8080 -p 4001:8081 -v /home/ygwl/Project_Image/model-server/model-store:/home/model-server/model-store pytorch/torchserve:0.2.0-cpu torchserve --start --model-store model-store --models resnet34.mar
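1. Container ports 8080 and 8081 are bound to host ports 4000 and 4001 respectively (8080/8081 are already in use on my machine).
2. The pytorch/torchserve:0.2.0-cpu image is pulled automatically if it is not present locally.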
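The port publishing in the docker run command determines which host URLs reach the server. As a small sketch, assuming the -p 4000:8080 -p 4001:8081 flags from the command and TorchServe's default split of inference API on 8080 and management API on 8081, the host-side URLs can be derived like this (all names here are illustrative):

```python
# Host port -> container port, as given by `-p 4000:8080 -p 4001:8081`.
PORT_MAP = {4000: 8080, 4001: 8081}

# TorchServe's default container ports for its two HTTP APIs.
INFERENCE_PORT = 8080
MANAGEMENT_PORT = 8081


def host_port_for(container_port, port_map=PORT_MAP):
    """Return the host port published for a given container port."""
    for host_port, mapped in port_map.items():
        if mapped == container_port:
            return host_port
    raise KeyError(f"container port {container_port} is not published")


def prediction_url(model_name, host="127.0.0.1"):
    """URL of the inference endpoint for a model, as seen from the host."""
    return f"http://{host}:{host_port_for(INFERENCE_PORT)}/predictions/{model_name}"


def list_models_url(host="127.0.0.1"):
    """URL of the management endpoint that lists registered models."""
    return f"http://{host}:{host_port_for(MANAGEMENT_PORT)}/models"


print(prediction_url("resnet34"))  # http://127.0.0.1:4000/predictions/resnet34
print(list_models_url())           # http://127.0.0.1:4001/models
```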
Calling the service for predictions
1. Call with curl
curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg
curl -X POST http://127.0.0.1:4000/predictions/resnet34 -T kitten.jpg
Result:
{
  "label": "tiger_cat",
  "index": 282
}
2. Call from Python code
import requests

image_path = r"E:\pycharm_project\Data-proessing\torch-serve\kitten.jpg"

# resnet34 served by the container started above (host port 4000)
with open(image_path, 'rb') as image:
    response = requests.post('http://172.20.112.102:4000/predictions/resnet34',
                             files={'data': image}).json()
print(response)

# densenet161, assumed here to be served by a separate instance on port 3000
with open(image_path, 'rb') as image2:
    r = requests.post("http://172.20.112.102:3000/predictions/densenet161",
                      files={'data': image2}).json()
print(r)
Result:
{'label': 'tiger_cat',
'index': 282}
{'tiger_cat': 0.4693359136581421, 'tabby': 0.4633873701095581, 'Egyptian_cat': 0.06456154584884644, 'lynx': 0.001282821292988956, 'plastic_bag': 0.00023323031200561672}
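The two responses have different shapes: the custom resnet34 handler returns a single label and index, while the default image_classifier handler used for densenet161 returns the top classes with their probabilities. A short sketch of reducing the second form to a top-1 label so the two can be compared, using the values printed above (the helper name top_prediction is illustrative):

```python
# Responses copied from the results above.
resnet34_response = {"label": "tiger_cat", "index": 282}
densenet161_response = {
    "tiger_cat": 0.4693359136581421,
    "tabby": 0.4633873701095581,
    "Egyptian_cat": 0.06456154584884644,
    "lynx": 0.001282821292988956,
    "plastic_bag": 0.00023323031200561672,
}


def top_prediction(scores):
    """Return the (label, probability) pair with the highest probability."""
    label = max(scores, key=scores.get)
    return label, scores[label]


label, prob = top_prediction(densenet161_response)
print(f"densenet161 top-1: {label} ({prob:.3f})")  # densenet161 top-1: tiger_cat (0.469)
print("models agree:", label == resnet34_response["label"])  # models agree: True
```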