TorchServe 源码分析
父类 BaseHandler
"""
Base default handler to load torchscript or eager mode [state_dict] models
Also, provides handle method per torch serve custom model specification
"""
import abc
import logging
import os
import importlib.util
import time
import torch
from ..utils.util import list_classes_from_module, load_label_mapping
# Module-level logger named after this module; initialize() uses it to record,
# at DEBUG level, whether it takes the eager or the torchscript loading path.
logger = logging.getLogger(__name__)
class BaseHandler(abc.ABC):
"""
Base default handler to load torchscript or eager mode [state_dict] models
Also, provides handle method per torch serve custom model specification
"""
def __init__(self):
self.model = None
self.mapping = None
self.device = None
self.initialized = False
self.context = None
self.manifest = None
self.map_location = None
self.explain = False
self.target = 0
def initialize(self, context):
"""Initialize function loads the model.pt file and initialized the model object.
First try to load torchscript else load eager mode state_dict based model.
Args:
context (context): It is a JSON Object containing information
pertaining to the model artifacts parameters.
Raises:
RuntimeError: Raises the Runtime error when the model.py is missing
"""
properties = context.system_properties
self.map_location = "cuda" if torch.cuda.is_available() else "cpu"
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
if torch.cuda.is_available()
else self.map_location
)
self.manifest = context.manifest
model_dir = properties.get("model_dir")
serialized_file = self.manifest["model"]["serializedFile"]
model_pt_path = os.path.join(model_dir, serialized_file)
if not os.path.isfile(model_pt_path):
raise RuntimeError("Missing the model.pt file")
# model def file
model_file = self.manifest["model"].get("modelFile", "")
if model_file:
logger.debug("Loading eager model")
self.model = self._load_pickled_model(model_dir, model_file, model_pt_path)
else:
logger.debug("Loading torchscript model")
self.model = self._load_torchscript_model(<