为了能够使得模型能够更调用并与前端进行通讯,这里我们选择flask作为后端框架,这样springboot端即可通过访问对应的网址从而实现模型的调用,在之前chat类的基础上,只需要根据springboot端的请求从而实现对应的功能即可实现python端的数据返回给后端springboot。
具体而言,使用parse_args()函数解析命令行参数,并将结果存储在args变量中。使用Config类和args参数创建配置对象cfg,用于读取和管理应用程序的配置信息。根据配置对象cfg获取模型配置信息model_config。根据模型配置信息中的模型架构(model_config.arch),从注册(registry)中获取相应的模型类。使用模型类和模型配置创建模型对象model。将模型移动到指定的GPU设备上(cuda:args.gpu_id)。根据配置对象cfg获取可视化处理器(vis_processor)的配置信息。
根据可视化处理器配置信息中的名称(vis_processor_cfg.name),从注册表中获取相应的处理器类。使用可视化处理器类和可视化处理器配置创建可视化处理器对象vis_processor。
使用模型对象model、可视化处理器对象vis_processor和指定的设备('cuda:args.gpu_id')创建Chat对象chat。最后启动flask端,app.run(debug=False)。
print('Initializing Chat')
args = parse_args()
cfg = Config(args)
use_amp = cfg.run_cfg.get("amp", False)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
print(model_config)
model = model_cls.from_config(model_config)
model = model.to('cuda:{}'.format(args.gpu_id))
vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')
app.run(debug=False)
接收函数:
@torch.no_grad()
@app.route('/upload',methods=['GET'])
def upload_img():
smiles = request.args.get('smiles', None)
print(smiles)
reset_message()
if smiles is None:
return jsonify({'success': False})
with torch.cuda.amp.autocast(use_amp):
llm_message = chat.upload_img(smiles, chat_state, img_list)
print(type(llm_message))
return llm_message
smiles = request.args.get('smiles', None)
: 从请求的查询参数中获取名为'smiles'的值,赋给变量smiles
。如果查询参数中不存在'smiles',则将smiles
的值设为None
。reset_message()
: 调用一个名为reset_message()
的函数,用于重置消息。如果smiles
是None
,则返回一个JSON响应,内容为{'success': False}
,表示操作失败。with torch.cuda.amp.autocast(use_amp):
: 用于自动混合精度计算。它会在下面的代码块中启用自动混合精度计算。llm_message = chat.upload_img(smiles, chat_state, img_list)
: 调用名为chat.upload_img()
的函数,传入
@torch.no_grad()
@app.route('/ask', methods=['POST'])
def ask():
user_message = request.json.get('user_message', None)
print(user_message)
if len(user_message) == 0:
return None
else:
chat.ask(user_message,chat_state)
with torch.cuda.amp.autocast(use_amp):
llm_message = chat.answer(conv=chat_state,
img_list=img_list,
num_beams=1,
temperature=1,
max_new_tokens=300,
max_length=1000)[0]
print(llm_message)
return llm_message
smiles
、chat_state
和img_list
作为参数,并将返回的结果赋给变量llm_message
。user_message = request.json.get('user_message', None)
: 从请求的JSON数据中获取名为'user_message'的值,赋给变量user_message
。如果JSON数据中不存在'user_message',则将user_message
的值设为None
。if len(user_message) == 0:
: 检查user_message
的长度是否为0。用户消息为空,则返回None
。调用名为chat.ask()
的函数,传入user_message
和chat_state
作为参数,用于处理用户消息。with torch.cuda.amp.autocast(use_amp):
:用于自动混合精度计算。它会在下面的代码块中启用自动混合精度计算。llm_message = chat.answer(conv=chat_state, img_list=img_list, num_beams=1, temperature=1, max_new_tokens=300, max_length=1000)[0]
: 调用名为chat.answer()
的函数,传入chat_state
、img_list
以及其他参数,用于生成回答消息。返回的结果是一个列表,取第一个元素赋给变量llm_message
。 返回llm_message
作为响应。
重置函数:
def reset_message():
img_list = []
chat_state = CONV_VISION.copy()
直接将消息状态和列表清空,实现重置的效果。