TorchServe: a new tool for deploying PyTorch models

A guide through the PyTorch deployment pitfalls

Working through these pitfalls wasn't easy, so please credit the source when reposting!
Reference: https://github.com/pytorch/serve
Installing pytorch==1.7 locally gave me trouble (just took another look, and it now seems to require 1.6?), so I went with a Docker deployment instead.

Now for the hands-on part:

1. The Docker version can't be too old; I'm running 19.03.13.
2. Download the project files
git clone https://github.com/pytorch/serve.git
cd serve/docker
3. Build the Docker image (CPU version)
DOCKER_BUILDKIT=1 docker build --file Dockerfile -t torchserve:latest .

or

docker pull pytorch/torchserve:latest

All available tags: https://hub.docker.com/r/pytorch/torchserve/tags

4. Package the trained model and its environment
import torch

# load the checkpoint (BertConfig and Bert_Sentiment_Analysis come from the training project)
checkpoint = torch.load(checkpoint_dir)
bertconfig = BertConfig(vocab_size=int(vocab_size), num_hidden_layers=3)
model = Bert_Sentiment_Analysis(config=bertconfig)
# load the trained weights
model.load_state_dict(checkpoint["model_state_dict"])
# switch to evaluation mode
model.eval()
# sample model inputs (shapes must match what the model expects)
texts_tokens_ = torch.randint(0, 100, (1, 127))
positional_enc = torch.randn((1, 127, 384), dtype=torch.float32)
# trace the model and save it as TorchScript
traced_script_module = torch.jit.trace(model, (texts_tokens_, positional_enc))
traced_script_module.save("sentiment_test.pt")
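
It's worth a quick sanity check that the traced module still behaves like the original model. A minimal sketch, assuming the model returns a single tensor:

# reload the traced module and compare outputs on the sample inputs
loaded = torch.jit.load("sentiment_test.pt")
loaded.eval()
with torch.no_grad():
    out_traced = loaded(texts_tokens_, positional_enc)
    out_original = model(texts_tokens_, positional_enc)
print(torch.allclose(out_traced, out_original, atol=1e-5))  # expect True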
5. Write the model_handler.py file that hooks into TorchServe

Note: the value returned by data[0].get("data") in the preprocess method is of type bytes.

# template
import os
import warnings

import torch
from ts.torch_handler.base_handler import BaseHandler
class ModelHandler(BaseHandler):
    """
    A custom model handler implementation.
    """

    def __init__(self):
        self._context = None
        self.initialized = False

    def initialize(self, context):
        """
        Initialize model. This will be called during model loading time
        :param context: Initial context contains model server system properties.
        :return:
        """
        self._context = context
        properties = context.system_properties

        #  load the model
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read model serialize/pt file
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        self.model = torch.jit.load(model_pt_path)
        self.model.to(self.device)

        # any additional setup (vocab, tokenizer, etc.)
        ...

        self.initialized = True

    def preprocess(self, data: list):
        """
        Transform raw input into model input data.
        :param data: list of raw requests, should match batch size
        :return: list of preprocessed model input data
        """
        # Take the input data and make it inference ready
        text = data[0].get("data") or data[0].get("body")
        # guard against missing input
        if text is None:
            warnings.warn("data params is none")
            raise Exception("no data")
        # the payload arrives as bytes, so decode it first
        text = text.decode()
        # preprocess: build and return (texts_tokens_, positional_enc) for inference
        ...

    def inference(self, texts_tokens_, positional_enc):
        """
        Internal inference methods
        :param texts_tokens_: token id tensor produced by preprocess
        :param positional_enc: positional encoding tensor produced by preprocess
        :return: inference output
        """
        # Do some inference call to engine here and return output
        predictions = self.model(texts_tokens_, positional_enc)
        ...
        return predictions

    def postprocess(self, inference_output):
        """
        Return inference result.
        :param inference_output: list of inference output
        :return: list of predict results
        """
        # Take output from network and post-process to desired format
        postprocess_output = inference_output
        return postprocess_output

    def handle(self, data, context):
        """
        Invoke by TorchServe for prediction request.
        Do pre-processing of data, prediction using model and postprocessing of prediction output
        :param data: Input data for prediction
        :param context: Initial context contains model server system properties.
        :return: prediction output
        """
        texts_tokens_, positional_enc = self.preprocess(data)
        model_output = self.inference(texts_tokens_, positional_enc)
        return self.postprocess(model_output)


service = ModelHandler()

def handle(data, context):
    if not service.initialized:
        service.initialize(context)
    if data is None:
        return None
    return service.handle(data, context)
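
Once the elided preprocessing is filled in, the handler can be smoke-tested outside TorchServe by faking the context object. A rough sketch; MockContext and the values below are my own stand-ins for what TorchServe passes in, not part of its API:

class MockContext:
    def __init__(self, model_dir):
        # mirror the two attributes the handler actually reads
        self.system_properties = {"model_dir": model_dir, "gpu_id": 0}
        self.manifest = {"model": {"serializedFile": "sentiment_test.pt"}}

ctx = MockContext("/home/model-server/model-store")
# requests arrive as a list of dicts with bytes payloads, as noted above
print(handle([{"data": "这也太难吃了把?再也不来了".encode()}], ctx))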
6. Package the files needed for model inference
torch-model-archiver --model-name sentiment_test --version 1.0 --serialized-file /home/model-server/model-store/sentiment_test.pt \
--export-path /home/model-server/model-store \
--extra-files  /home/model-server/model-store/bert_word2idx.json \
--handler model_handler:handle -f
--model-name: the model's name; the inference endpoint and the name used for model management are both derived from it
--serialized-file: the packaged file containing the model's environment, code, and weights
--export-path: where to store the package produced by this run
--extra-files: any other files that model_handler.py needs
--handler: the handler entry point (module name:function name)
-f: overwrite a previously exported package of the same name

After this runs, you'll find a new file ending in .mar under /home/model-server/model-store. This is the final package that the model server will use.
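
If you want to confirm what went into the package, a .mar file is (as far as I can tell) just a zip archive, so something like this should list its contents:

import zipfile

with zipfile.ZipFile("/home/model-server/model-store/sentiment_test.mar") as mar:
    # expect sentiment_test.pt, model_handler.py, bert_word2idx.json and a manifest
    print(mar.namelist())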

7. Start the Docker service

Place the .mar file in /home/model-server/model-store on the host.

docker run --rm -it -p 3000:8080 -p 3001:8081 --name sentiment_test \
 -v /home/model-server/model-store:/home/model-server/model-store \
 torchserve:latest
# optional docker flags: --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
# detach back to the host shell
ctrl + p, then ctrl + q
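
Once the container is up, a quick health check against the inference port (8080 inside the container, mapped to 3000 here) tells you whether the server is ready; TorchServe exposes a GET /ping endpoint for this:

import requests

r = requests.get("http://localhost:3000/ping")
print(r.json())  # {"status": "Healthy"} once the server is up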
8. TorchServe APIs
8.1 Management API
# register the model and assign workers to it
curl -v -X POST "http://localhost:3001/models?initial_workers=1&synchronous=false&url=sentiment_test.mar&batch_size=8&max_batch_delay=200"

# change the number of assigned workers
curl -v -X PUT "http://localhost:3001/models/sentiment_test?min_worker=3"

# check the current status of a given model
curl http://localhost:3001/models/sentiment_test
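
The same management calls can of course be scripted. A small Python sketch mirroring the three curl commands above (port 3001 is the management port mapped earlier):

import requests

base = "http://localhost:3001"
# register the model and assign an initial worker
print(requests.post(f"{base}/models", params={
    "url": "sentiment_test.mar", "initial_workers": 1,
    "synchronous": "false", "batch_size": 8, "max_batch_delay": 200,
}).json())
# scale up to three workers
print(requests.put(f"{base}/models/sentiment_test", params={"min_worker": 3}).json())
# inspect the model's current status
print(requests.get(f"{base}/models/sentiment_test").json())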
8.2 Inference API (the logic from model_handler.py)
curl -X POST http://localhost:3000/predictions/sentiment_test -d "data=这也太难吃了把?再也不来了"
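
The same request from Python, for comparison; the exact response shape depends on what your postprocess returns:

import requests

resp = requests.post(
    "http://localhost:3000/predictions/sentiment_test",
    data={"data": "这也太难吃了把?再也不来了"},
)
print(resp.status_code, resp.text)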
9. Stop or restart the service inside the Docker container
# enter the docker container
docker exec -it [container name] /bin/bash

# stop the service
torchserve --stop

# start the service
torchserve --start --ncs --model-store /home/model-server/model-store --models sentiment_test.mar