Building the Server-Side Generation API
Model Loading and Initialization
Define a function that loads the trained model from a checkpoint:
import torch
from transformers import AutoTokenizer
# The imports below assume the module layout of the mPLUG-Owl repository
from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from mplug_owl.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor

def get_lora_model(pretrained_ckpt, use_bf16=False):
    """Model provider with tokenizer and processor.

    Args:
        pretrained_ckpt (string): The path to the pre-trained checkpoint.
        use_bf16 (bool, optional): Whether to load the model in bfloat16. Defaults to False.

    Returns:
        model: MplugOwl model
        tokenizer: MplugOwl text tokenizer
        processor: MplugOwl processor (text and image)
    """
    # Load weights in bfloat16 when requested, otherwise in float16
    model = MplugOwlForConditionalGeneration.from_pretrained(
        pretrained_ckpt,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.half,
    )
    image_processor = MplugOwlImageProcessor.from_pretrained(pretrained_ckpt)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_ckpt)
    processor = MplugOwlProcessor(image_processor, tokenizer)
    return model, tokenizer, processor
Here use_bf16 loads the model in half precision (bfloat16), which saves a large amount of GPU memory (roughly 50% compared with full precision) but may also cost some inference accuracy.
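For reference, a minimal usage sketch is shown below; the checkpoint path is only a placeholder, and moving the model to the GPU in eval mode mirrors what the wrapper class in the next step does:

# Minimal usage sketch: load in bfloat16, move to GPU, switch to inference mode
model, tokenizer, processor = get_lora_model(
    pretrained_ckpt='/path/to/mplug-owl-checkpoint/',  # placeholder path
    use_bf16=True,
)
model = model.to('cuda').eval()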
Wrapping the Model Access Interface
Wrap the model loading/deployment and inference calls as methods of a class model_interface:
class model_interface:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        ckpt_dir = '/root/zsk/MAGAer13/mplug-owl-bloomz-7b-multilingual/'
        # self.model, self.tokenizer, self.processor = get_lora_model(pretrained_ckpt=ckpt_dir, use_bf16=True)
        self.model, self.tokenizer, self.processor = get_lora_model(pretrained_ckpt=ckpt_dir, use_bf16=False)
        # Move the loaded model to the selected device
        self.model.to(self.device)
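The remaining piece of the wrapper is an inference method inside model_interface. Below is a minimal sketch based on the generation pattern published in the mPLUG-Owl repository; the method name do_generate, the prompt template, and the generation parameters are illustrative assumptions rather than part of the original code:

    def do_generate(self, image_path, question,
                    max_length=512, top_k=5, do_sample=True):
        """Run one image-grounded generation and return the decoded answer."""
        from PIL import Image

        # Prompt format follows the conversation template used by mPLUG-Owl
        prompt = (
            "The following is a conversation between a curious human and AI assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
            "Human: <image>\n"
            f"Human: {question}\n"
            "AI: "
        )
        image = Image.open(image_path).convert('RGB')

        # Tokenize the text and preprocess the image into model inputs
        inputs = self.processor(text=[prompt], images=[image], return_tensors='pt')
        # Match the model's floating-point dtype and device
        inputs = {k: v.to(self.model.dtype) if torch.is_floating_point(v) else v
                  for k, v in inputs.items()}
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        generate_kwargs = {'do_sample': do_sample, 'top_k': top_k, 'max_length': max_length}
        with torch.no_grad():
            output = self.model.generate(**inputs, **generate_kwargs)
        return self.tokenizer.decode(output.tolist()[0], skip_special_tokens=True)

With such a method in place, the server-side API only needs to instantiate model_interface once and call do_generate per request, so the weights are loaded a single time.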