QA Model (Part 10): TorchServe Deployment

TorchServe source code analysis

The base class: BaseHandler

"""
Base default handler to load torchscript or eager mode [state_dict] models
Also, provides handle method per torch serve custom model specification
"""
import abc
import logging
import os
import importlib.util
import time
import torch

from ..utils.util import list_classes_from_module, load_label_mapping

logger = logging.getLogger(__name__)


class BaseHandler(abc.ABC):
    """
    Base default handler to load torchscript or eager mode [state_dict] models
    Also, provides handle method per torch serve custom model specification
    """

    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False
        self.context = None
        self.manifest = None
        self.map_location = None
        self.explain = False
        self.target = 0

    def initialize(self, context):
        """Initialize function loads the model.pt file and initialized the model object.
	   First try to load torchscript else load eager mode state_dict based model.

        Args:
            context (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.

        Raises:
            RuntimeError: Raises the Runtime error when the model.py is missing

        """
        properties = context.system_properties
        self.map_location = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(
            self.map_location + ":" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else self.map_location
        )
        self.manifest = context.manifest

        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)

        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        # model def file
        model_file = self.manifest["model"].get("modelFile", "")

        if model_file:
            logger.debug("Loading eager model")
            self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
        else:
            logger.debug("Loading torchscript model")
            self.model = self._load_torchscript_model(model_pt_path)

        self.model.to(self.device)
        self.model.eval()

        logger.debug('Model file %s loaded successfully', model_pt_path)

        # Load class mapping for classifiers
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        self.mapping = load_label_mapping(mapping_file_path)

        self.initialized = True

    def _load_torchscript_model(self, model_pt_path):
        """Loads the PyTorch model and returns the NN model object.

        Args:
            model_pt_path (str): denotes the path of the model file.

        Returns:
            (NN Model Object) : Loads the model object.
        """
        return torch.jit.load(model_pt_path, map_location=self.map_location)

    def _load_pickled_model(self, model_dir, model_file, model_pt_path):
        """
        Loads the pickle file from the given model path.

        Args:
            model_dir (str): Points to the location of the model artefacts.
            model_file (.py): the file which contains the model class.
            model_pt_path (str): points to the location of the model pickle file.

        Raises:
            RuntimeError: It raises this error when the model.py file is missing.
            ValueError: Raises value error when there is more than one class in the label,
                        since the mapping supports only one label per class.

        Returns:
            serialized model file: Returns the pickled pytorch model file
        """
        model_def_path = os.path.join(model_dir, model_file)
        if not os.path.isfile(model_def_path):
            raise RuntimeError("Missing the model.py file")

        module = importlib.import_module(model_file.split(".")[0])
        model_class_definitions = list_classes_from_module(module)
        if len(model_class_definitions) != 1:
            raise ValueError(
                "Expected only one class as model definition. {}".format(
                    model_class_definitions
                )
            )

        model_class = model_class_definitions[0]
        state_dict = torch.load(model_pt_path, map_location=self.map_location)
        model = model_class()
        model.load_state_dict(state_dict)
        return model

    def preprocess(self, data):
        """
        Preprocess function to convert the request input to a tensor(Torchserve supported format).
        The user needs to override to customize the pre-processing

        Args :
            data (list): List of the data from the request input.

        Returns:
            tensor: Returns the tensor data of the input
        """
        return torch.as_tensor(data, device=self.device)

    def inference(self, data, *args, **kwargs):
        """
        The Inference Function is used to make a prediction call on the given input request.
        The user needs to override the inference function to customize it.

        Args:
            data (Torch Tensor): A Torch Tensor is passed to make the Inference Request.
            The shape should match the model input shape.

        Returns:
            Torch Tensor : The Predicted Torch Tensor is returned in this function.
        """
        marshalled_data = data.to(self.device)
        with torch.no_grad():
            results = self.model(marshalled_data, *args, **kwargs)
        return results

    def postprocess(self, data):
        """
        The post process function makes use of the output from the inference and converts into a
        Torchserve supported response output.

        Args:
            data (Torch Tensor): The torch tensor received from the prediction output of the model.

        Returns:
            List: The post process function returns a list of the predicted output.
        """

        return data.tolist()

    def handle(self, data, context):
        """Entry point for default handler. It takes the data from the input request and returns
           the predicted outcome for the input.

        Args:
            data (list): The input data that needs to be made a prediction request on.
            context (Context): It is a JSON Object containing information pertaining to
                               the model artefacts parameters.

        Returns:
            list : Returns a list of dictionary with the predicted response.
        """

        # It can be used for pre or post processing if needed as additional request
        # information is available in context
        start_time = time.time()

        self.context = context
        metrics = self.context.metrics

        data_preprocess = self.preprocess(data)

        if not self._is_explain():
            output = self.inference(data_preprocess)
            output = self.postprocess(output)
        else :
            output = self.explain_handle(data_preprocess, data)

        stop_time = time.time()
        metrics.add_time('HandlerTime', round((stop_time - start_time) * 1000, 2), None, 'ms')
        return output

    def explain_handle(self, data_preprocess, raw_data):
        """Captum explanations handler

        Args:
            data_preprocess (Torch Tensor): Preprocessed data to be used for captum
            raw_data (list): The unprocessed data to get target from the request

        Returns:
            dict : A dictionary response with the explanations response.
        """
        output_explain = None
        inputs = None
        target = 0

        logger.info("Calculating Explanations")
        row = raw_data[0]
        if isinstance(row, dict):
            logger.info("Getting data and target")
            inputs = row.get("data") or row.get("body")
            target = row.get("target")
            if not target:
                target = 0

        output_explain = self.get_insights(data_preprocess, inputs, target)
        return output_explain

    def _is_explain(self):
        if self.context and self.context.get_request_header(0, "explain"):
            if self.context.get_request_header(0, "explain") == "True":
                self.explain = True
                return True
        return False

Analysis of the class methods
  • initialize: loads the model file and initializes the handler's internal state (device, model, label mapping)
  • preprocess: receives the data passed in through the model endpoint and converts it to a tensor
  • inference: feeds the tensor to the model and returns the prediction output
  • postprocess: post-processes the model output into a TorchServe-supported response (a minimal custom handler overriding these hooks is sketched below)
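
As a concrete illustration of how these hooks are meant to be overridden, here is a minimal custom handler sketch. The class name ReaderHandler and the placeholder tokenization are hypothetical, not the handler actually used in this project; initialize, inference and handle are simply inherited from BaseHandler.

import torch

from ts.torch_handler.base_handler import BaseHandler


class ReaderHandler(BaseHandler):
    """Hypothetical handler: only preprocess and postprocess are overridden."""

    def preprocess(self, data):
        # TorchServe passes a list of request dicts; take the body of the first one.
        text = data[0].get("data") or data[0].get("body")
        if isinstance(text, (bytes, bytearray)):
            text = text.decode("utf-8")
        # Placeholder "tokenization" so the sketch stays self-contained;
        # a real handler would run the model's tokenizer here.
        token_ids = [ord(ch) % 100 for ch in text]
        return torch.tensor([token_ids], device=self.device)

    def postprocess(self, data):
        # Return a JSON-serializable list, one element per request in the batch.
        return [{"prediction": data.argmax(dim=-1).tolist()}]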

Deployment logic

Basic logic of each individual model
  • POST a question to the NER endpoint to run named-entity recognition on it and obtain an entity
  • POST an entity to the W2V endpoint to obtain entities close to it
  • POST a question together with a doc to the reader endpoint to obtain an answer (a hedged example of this call follows this list)
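
In the code later in this article the reader is only ever called with a bare question, so here is a hedged sketch of what sending a question together with a doc to the reader endpoint might look like. The JSON field names "question" and "doc" are assumptions rather than the project's actual request schema, and the URL is the same masked address used throughout.

import json
import requests


def ask_reader(question, doc):
    # Assumed request shape: a UTF-8 encoded JSON body carrying question and doc.
    payload = json.dumps({"question": question, "doc": doc}, ensure_ascii=False)
    resp = requests.post('http://xx.xx.xxx.x:xxxx/predictions/reader',
                         data=payload.encode('utf-8'))
    return json.loads(resp.text)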
Joint logic the models follow to answer a question
  • Send the query (i.e. the question) to the reader endpoint
  • After receiving the query, reader forwards the question to the NER endpoint for named-entity recognition:
            entity = json.loads(
                requests.post('http://xx.xx.xxx.x:xxxx/predictions/NER', data=query.encode('utf-8')).text)
            entity = entity['entity']
  • If the entity is in the knowledge base, return the corresponding doc; if it is not, call W2V to search for an approximate entity and return that entity's doc; otherwise raise an exception and answer "I don't know" (a possible continuation of the fallback branch is sketched after the snippet below):
            if entity and entity in self.titles:
                print('实体在知识库中')
                doc_id = self.doc_ids[self.titles.index(entity)]
                sql = "select text from documents where id = '{}'".format(doc_id)
                doc = pd.read_sql_query(sql, con=conn)['text'].tolist()[0]
                return doc, entity
            if entity:
                print(f'实体不在知识库,使用W2V近似搜索:{entity}')
                word = json.loads(
                    requests.post('http://xx.xx.xxx.x:xxxx/predictions/W2V', data=entity.encode('utf-8')).text)
                print(word)
                word = word['close entity'][0]
                print(f'近似搜索结果:{word}')
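
The snippet stops right after the approximate entity is obtained. A plausible continuation, reusing the same titles/doc_ids lookup as the exact-match branch, is sketched here; it is an assumption based on the surrounding code, not the original implementation:

                # Assumed continuation: look the approximate entity up exactly like
                # the exact match above.
                if word in self.titles:
                    doc_id = self.doc_ids[self.titles.index(word)]
                    sql = "select text from documents where id = '{}'".format(doc_id)
                    doc = pd.read_sql_query(sql, con=conn)['text'].tolist()[0]
                    return doc, word
            # Neither the entity nor its W2V neighbour is in the knowledge base:
            # raise, so the caller can fall back to the "I don't know" answer.
            raise RuntimeError('entity not found in knowledge base')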

Endpoint tests

import requests
import json
def process(question):
    question = question.encode('utf-8')
    output = requests.post('http://xx.xx.xxx.x:xxxx/predictions/reader', data=question).text
    answers = json.loads(output)
    answers['index'] = 1
    answers = [answers]
    return answers


def W2V(question):
    question = question.encode('utf-8')
    output = requests.post('http://xx.xx.xxx.x:xxxx/predictions/W2V', data=question).text
    answers = json.loads(output)
    answers['index'] = 1
    answers = [answers]
    return answers


def ner(question):
    question = question.encode('utf-8')
    output = requests.post('http://xx.xx.xxx.x:xxxx/predictions/NER', data=question).text
    answers = json.loads(output)
    answers['index'] = 1
    answers = [answers]
    return answers
  1. Test NER
    [screenshot of the NER endpoint response]

  2. Test W2V
    [screenshot of the W2V endpoint response]

  3. Test reader
    [screenshot of the reader endpoint response]
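
The screenshots themselves are not reproduced here. A minimal illustration of how the three helpers above might be invoked from a Python shell; the question string is a placeholder, and the commented response shapes only echo the keys used elsewhere in this article:

if __name__ == '__main__':
    question = '...'          # placeholder question text
    print(ner(question))      # e.g. [{'entity': ..., 'index': 1}]
    print(W2V('...'))         # e.g. [{'close entity': [...], 'index': 1}]
    print(process(question))  # reader response, also wrapped as [{..., 'index': 1}]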

Exposing named-entity and similar-entity query APIs that the front end can call at any time, with usage notes

from flask import Flask, request, jsonify

app = Flask(__name__)


# Each view reads the JSON body {"question": "..."} sent by the front end and
# forwards it to the corresponding helper defined above. The three routes must
# use distinct paths; registering them all on "/query" makes the rules collide
# so that only one of them is ever dispatched.
@app.route("/query_w2v", methods=["POST"])
def query_w2v():
    data = request.json
    answers = W2V(question=data['question'])
    print("+++")
    print(answers)
    return jsonify(answers)


@app.route("/query_entity", methods=["POST"])
def query_entity():
    data = request.json
    answers = ner(question=data['question'])
    print("+++")
    print(answers)
    return jsonify(answers)


@app.route("/query", methods=["POST"])
def query():
    data = request.json
    answers = process(question=data['question'])
    print("+++")
    print(answers)
    return jsonify(answers)
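
A hedged usage example for the front end, assuming the Flask app above is running on a placeholder host and port and keeps the distinct routes shown:

import requests

# The body matches the {"question": ...} shape the views expect.
resp = requests.post('http://127.0.0.1:5000/query', json={'question': '...'})
print(resp.json())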