TorchServe source code analysis
Parent class: BaseHandler
"""
Base default handler to load torchscript or eager mode [state_dict] models
Also, provides handle method per torch serve custom model specification
"""
import abc
import logging
import os
import importlib.util
import time
import torch
from ..utils.util import list_classes_from_module, load_label_mapping
logger = logging.getLogger(__name__)
class BaseHandler(abc.ABC):
"""
Base default handler to load torchscript or eager mode [state_dict] models
Also, provides handle method per torch serve custom model specification
"""
def __init__(self):
self.model = None
self.mapping = None
self.device = None
self.initialized = False
self.context = None
self.manifest = None
self.map_location = None
self.explain = False
self.target = 0
def initialize(self, context):
"""Initialize function loads the model.pt file and initialized the model object.
First try to load torchscript else load eager mode state_dict based model.
Args:
context (context): It is a JSON Object containing information
pertaining to the model artifacts parameters.
Raises:
RuntimeError: Raises the Runtime error when the model.py is missing
"""
properties = context.system_properties
self.map_location = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
if torch.cuda.is_available()
else self.map_location
)
self.manifest = context.manifest
model_dir = properties.get("model_dir")
serialized_file = self.manifest["model"]["serializedFile"]
model_pt_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_pt_path):
raise RuntimeError("Missing the model.pt file")
# model def file
model_file = self.manifest["model"].get("modelFile", "")
if model_file:
logger.debug("Loading eager model")
self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
else:
logger.debug("Loading torchscript model")
self.model = self._load_torchscript_model(model_pt_path)
self.model.to(self.device)
self.model.eval()
logger.debug('Model file %s loaded successfully', model_pt_path)
# Load class mapping for classifiers
mapping_file_path = os.path.join(model_dir, "index_to_name.json")
self.mapping = load_label_mapping(mapping_file_path)
self.initialized = True
def _load_torchscript_model(self, model_pt_path):
"""Loads the PyTorch model and returns the NN model object.
Args:
model_pt_path (str): denotes the path of the model file.
Returns:
(NN Model Object) : Loads the model object.
"""
return torch.jit.load(model_pt_path, map_location=self.map_location)
def _load_pickled_model(self, model_dir, model_file, model_pt_path):
"""
Loads the pickle file from the given model path.
Args:
model_dir (str): Points to the location of the model artefacts.
model_file (.py): the file which contains the model class.
model_pt_path (str): points to the location of the model pickle file.
Raises:
RuntimeError: It raises this error when the model.py file is missing.
ValueError: Raises value error when there is more than one class in the label,
since the mapping supports only one label per class.
Returns:
serialized model file: Returns the pickled pytorch model file
"""
model_def_path = os.path.join(model_dir, model_file)
if not os.path.isfile(model_def_path):
raise RuntimeError("Missing the model.py file")
module = importlib.import_module(model_file.split(".")[0])
model_class_definitions = list_classes_from_module(module)
if len(model_class_definitions) != 1:
raise ValueError(
"Expected only one class as model definition. {}".format(
model_class_definitions
)
)
model_class = model_class_definitions[0]
state_dict = torch.load(model_pt_path, map_location=self.map_location)
model = model_class()
model.load_state_dict(state_dict)
return model
def preprocess(self, data):
"""
Preprocess function to convert the request input to a tensor(Torchserve supported format).
The user needs to override to customize the pre-processing
Args :
data (list): List of the data from the request input.
Returns:
tensor: Returns the tensor data of the input
"""
return torch.as_tensor(data, device=self.device)
def inference(self, data, *args, **kwargs):
"""
The Inference Function is used to make a prediction call on the given input request.
The user needs to override the inference function to customize it.
Args:
data (Torch Tensor): A Torch Tensor is passed to make the Inference Request.
The shape should match the model input shape.
Returns:
Torch Tensor : The Predicted Torch Tensor is returned in this function.
"""
marshalled_data = data.to(self.device)
with torch.no_grad():
results = self.model(marshalled_data, *args, **kwargs)
return results
def postprocess(self, data):
"""
The post process function makes use of the output from the inference and converts into a
Torchserve supported response output.
Args:
data (Torch Tensor): The torch tensor received from the prediction output of the model.
Returns:
List: The post process function returns a list of the predicted output.
"""
return data.tolist()
def handle(self, data, context):
"""Entry point for default handler. It takes the data from the input request and returns
the predicted outcome for the input.
Args:
data (list): The input data that needs to be made a prediction request on.
context (Context): It is a JSON Object containing information pertaining to
the model artefacts parameters.
Returns:
list : Returns a list of dictionary with the predicted response.
"""
# It can be used for pre or post processing if needed as additional request
# information is available in context
start_time = time.time()
self.context = context
metrics = self.context.metrics
data_preprocess = self.preprocess(data)
if not self._is_explain():
output = self.inference(data_preprocess)
output = self.postprocess(output)
else :
output = self.explain_handle(data_preprocess, data)
stop_time = time.time()
metrics.add_time('HandlerTime', round((stop_time - start_time) * 1000, 2), None, 'ms')
return output
def explain_handle(self, data_preprocess, raw_data):
"""Captum explanations handler
Args:
data_preprocess (Torch Tensor): Preprocessed data to be used for captum
raw_data (list): The unprocessed data to get target from the request
Returns:
dict : A dictionary response with the explanations response.
"""
output_explain = None
inputs = None
target = 0
logger.info("Calculating Explanations")
row = raw_data[0]
if isinstance(row, dict):
logger.info("Getting data and target")
inputs = row.get("data") or row.get("body")
target = row.get("target")
if not target:
target = 0
output_explain = self.get_insights(data_preprocess, inputs, target)
return output_explain
def _is_explain(self):
if self.context and self.context.get_request_header(0, "explain"):
if self.context.get_request_header(0, "explain") == "True":
self.explain = True
return True
return False
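Note that `explain_handle` delegates to `self.get_insights`, which `BaseHandler` itself never defines; each concrete handler must supply it. A minimal sketch of what such an override can look like, assuming Captum's `IntegratedGradients` (the class name and attribution method here are illustrative, not TorchServe's bundled implementation):

```python
from captum.attr import IntegratedGradients

from ts.torch_handler.base_handler import BaseHandler


class ExplainableHandler(BaseHandler):
    def get_insights(self, tensor_data, _raw_inputs, target=0):
        # Attribute the prediction for `target` back to the input features.
        ig = IntegratedGradients(self.model)
        attributions = ig.attribute(tensor_data, target=target)
        return attributions.tolist()
```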
Analysis of the class methods
- initialize: initializes the handler's internal state and loads the model
- preprocess: converts the raw data received at the model's endpoint into a tensor
- inference: feeds the preprocessed tensor into the model and returns the model's output
- postprocess: post-processes the model output into a TorchServe-compatible response (a minimal custom handler built on these hooks is sketched below)
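Putting these hooks together, a custom handler only needs to override the data-dependent pieces. Everything below other than `BaseHandler` itself (the class name, the toy encoding, the payload handling) is an illustrative sketch, not production code:

```python
import torch

from ts.torch_handler.base_handler import BaseHandler


class ToyTextHandler(BaseHandler):
    """Illustrative handler for a fixed-length text classifier."""

    def preprocess(self, data):
        texts = []
        for row in data:
            # TorchServe delivers each request row under "data" or "body".
            payload = row.get("data") or row.get("body")
            if isinstance(payload, (bytes, bytearray)):
                payload = payload.decode("utf-8")
            texts.append(payload)
        # Toy fixed-length encoding so the sketch stays self-contained;
        # a real handler would run its tokenizer here.
        encoded = [([ord(c) % 100 for c in t[:32]] + [0] * 32)[:32] for t in texts]
        return torch.as_tensor(encoded, device=self.device)

    def postprocess(self, data):
        # Assumes a (batch, num_classes) output; map class indices through
        # index_to_name.json when it was loaded by initialize().
        preds = data.argmax(dim=1).tolist()
        if self.mapping:
            preds = [self.mapping.get(str(p), str(p)) for p in preds]
        return preds
```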
Deployment logic
Basic logic of each individual model
- Sending a question to the NER endpoint performs named-entity recognition and returns the entity
- Sending an entity to the W2V endpoint returns entities close to it
- Sending a question together with a doc to the reader endpoint returns the answer
How the models cooperate to answer a question
- The query (the question) is sent to the reader endpoint
- After receiving the query, reader sends the question to the NER endpoint for named-entity recognition:
```python
entity = json.loads(
    requests.post('http://xx.xx.xxx.x:xxxx/predictions/NER', data=query.encode('utf-8')).text)
entity = entity['entity']
```
- If the entity is in the knowledge base, return the corresponding doc; if it is not, call W2V to search for a similar entity and return that entity's doc; failing both, raise an exception and answer "I don't know":
```python
if entity and entity in self.titles:
    print('Entity found in the knowledge base')
    doc_id = self.doc_ids[self.titles.index(entity)]
    sql = "select text from documents where id = '{}'".format(doc_id)
    doc = pd.read_sql_query(sql, con=conn)['text'].tolist()[0]
    return doc, entity
if entity:
    print(f'Entity not in the knowledge base; trying W2V similarity search: {entity}')
    word = json.loads(
        requests.post('http://xx.xx.xxx.x:xxxx/predictions/W2V', data=entity.encode('utf-8')).text)
    print(word)
    word = word['close entity'][0]
    print(f'Similarity search result: {word}')
```
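Taken together, the flow can be sketched end to end as follows. The host/port placeholders match the snippets above; the in-memory `titles`/`docs` knowledge base and the reader payload shape are assumptions for illustration:

```python
import json

import requests

TS_HOST = 'http://xx.xx.xxx.x:xxxx'  # placeholder address, as above


def answer(question, titles, docs):
    # 1. NER: extract the entity from the question.
    resp = requests.post(f'{TS_HOST}/predictions/NER', data=question.encode('utf-8'))
    entity = json.loads(resp.text)['entity']

    # 2. Exact match against the knowledge base.
    if entity and entity in titles:
        doc = docs[titles.index(entity)]
    elif entity:
        # 3. Fallback: W2V similarity search, then look the neighbour up.
        resp = requests.post(f'{TS_HOST}/predictions/W2V', data=entity.encode('utf-8'))
        neighbour = json.loads(resp.text)['close entity'][0]
        if neighbour not in titles:
            return "I don't know"
        doc = docs[titles.index(neighbour)]
    else:
        return "I don't know"

    # 4. Reader: answer the question against the retrieved doc
    #    (the JSON payload shape here is an assumption).
    payload = json.dumps({'question': question, 'doc': doc}).encode('utf-8')
    resp = requests.post(f'{TS_HOST}/predictions/reader', data=payload)
    return json.loads(resp.text)
```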
Endpoint tests

```python
import requests
import json


def process(question):
    question = question.encode('utf-8')
    output = requests.post('http://xx.xx.xxx.x:xxxx/predictions/reader', data=question).text
    answers = json.loads(output)
    answers['index'] = 1
    answers = [answers]
    return answers


def W2V(question):
    question = question.encode('utf-8')
    output = requests.post('http://xx.xx.xxx.x:xxxx/predictions/W2V', data=question).text
    answers = json.loads(output)
    answers['index'] = 1
    answers = [answers]
    return answers


def ner(question):
    question = question.encode('utf-8')
    output = requests.post('http://xx.xx.xxx.x:xxxx/predictions/NER', data=question).text
    answers = json.loads(output)
    answers['index'] = 1
    answers = [answers]
    return answers
```
- Test NER
- Test W2V
- Test reader
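A quick smoke test of the three helpers above (the sample question is illustrative, and the exact response keys depend on each model's handler):

```python
print(ner('谁发明了电话?'))      # e.g. [{'entity': ..., 'index': 1}]
print(W2V('电话'))               # e.g. [{'close entity': [...], 'index': 1}]
print(process('谁发明了电话?'))  # e.g. [{'answer': ..., 'index': 1}]
```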
Named-entity and similar-entity query APIs exposed to the frontend, and how to use them

```python
from flask import Flask, request, jsonify

app = Flask(__name__)


# Each view needs its own URL rule: registering three views on the same
# "/query" rule would leave two of them unreachable. The views reuse the
# W2V/ner/process helpers defined above.
@app.route("/query_w2v", methods=["POST"])
def query_w2v():
    data = request.json  # expects a JSON body like {"question": "..."}
    answers = W2V(question=data['question'])
    print(answers)
    return jsonify(answers)


@app.route("/query_entity", methods=["POST"])
def query_entity():
    data = request.json
    answers = ner(question=data['question'])
    print(answers)
    return jsonify(answers)


@app.route("/query", methods=["POST"])
def query():
    data = request.json
    answers = process(question=data['question'])
    print(answers)
    return jsonify(answers)
```
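The frontend can then call any of these routes with a JSON body, for example:

```python
import requests

# Placeholder host/port as above; the question string is illustrative.
resp = requests.post('http://xx.xx.xxx.x:xxxx/query', json={'question': '...'})
print(resp.json())
```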