from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments
from deep_training.nlp.models.chatglm import setup_model_profile, ChatGLMConfig
from deep_training.nlp.models.lora.v2 import LoraArguments
from transformers import HfArgumentParser
from typing import Optional, List, Tuple
from data_utils import train_info_args, NN_DataHelper
from models import MyTransformer,ChatGLMTokenizer
import os
import gradio as gr
from webui.context import ctx
from webui.device import torch_gc
# CSS value handed to gr.Blocks(css=...) when the UI is built in main().
css = "style.css"
# Directory that reload_javascript() scans for ".js" files to inject into the page.
script_path = "scripts"
# Reference to Gradio's original TemplateResponse, captured before reload_javascript()
# replaces it, so the replacement wrapper can delegate to the original.
_gradio_template_response_orig = gr.routes.templates.TemplateResponse
# Load the model (runs at import time; everything below executes as a module-level side effect).
# The seed entry is overwritten with None before the arguments are parsed.
train_info_args['seed'] = None
parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, LoraArguments))
# The fourth parsed value (LoraArguments) is discarded here; lora_args is loaded from a checkpoint below.
model_args, training_args, data_args, _ = parser.parse_dict(train_info_args)
setup_model_profile()
dataHelper = NN_DataHelper(model_args, training_args, data_args)
tokenizer: ChatGLMTokenizer
# Only the tokenizer (first returned value) is kept; the other three returned values are ignored.
tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config(tokenizer_class_name=ChatGLMTokenizer, config_class_name=ChatGLMConfig)
# NOTE(review): the config is read from './best_ckpt' while the LoRA arguments and weights
# below are read from './last_ckpt' - confirm that mixing the two checkpoints is intended.
config = ChatGLMConfig.from_pretrained('./best_ckpt')
# config = ChatGLMConfig.from_pretrained('./best_ckpt')
config.initializer_weight = False
lora_args = LoraArguments.from_pretrained('./last_ckpt')
# lora_args = LoraArguments.from_pretrained('./best_ckpt')
# Sanity check: the saved LoRA arguments must be in inference mode and the config must not set pre_seq_len.
# NOTE(review): assert statements are skipped when Python runs with -O.
assert lora_args.inference_mode == True and config.pre_seq_len is None
pl_model = MyTransformer(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args)
# Load the LoRA weights
pl_model.backbone.from_pretrained(pl_model.backbone.model, pretrained_model_name_or_path='./last_ckpt',
lora_config=lora_args)
# pl_model.backbone.from_pretrained(pl_model.backbone.model, pretrained_model_name_or_path = './best_ckpt', lora_config = lora_args)
model = pl_model.get_glm_model()
# Adjust as needed (half precision on a CUDA device)
model.half().cuda()
model = model.eval()
def infer(query,
          history: Optional[List[Tuple]],
          max_length, top_p, temperature):
    """Generate one chat reply for ``query`` using the module-level model.

    Args:
        query: The user's input for this turn.
        history: Earlier conversation turns forwarded to ``model.chat``;
            ``None`` is treated as an empty history.
        max_length: Forwarded to ``model.chat`` as ``max_length``.
        top_p: Forwarded to ``model.chat`` as ``top_p``.
        temperature: Forwarded to ``model.chat`` as ``temperature``.

    Returns:
        A ``(query, output)`` tuple, where ``output`` is the first value
        returned by ``model.chat``.

    Raises:
        RuntimeError: If the module-level ``model`` is falsy (not loaded).
    """
    if not model:
        # Raising a plain string is itself a TypeError in Python 3; raise a
        # real exception so the intended message reaches the caller.
        raise RuntimeError("Model not loaded")
    if history is None:
        history = []
    # The updated history returned by model.chat is not used here; only the
    # generated reply is passed back.
    output, _ = model.chat(
        tokenizer, query=query, history=history,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature
    )
    print(output)
    # torch_gc comes from webui.device - presumably releases cached GPU
    # memory after generation; verify against that module.
    torch_gc()
    return query, output
def predict(query, max_length, top_p, temperature):
    """Run one chat turn against the shared context and refresh the UI.

    Calls ``ctx.limit_round()``, asks ``infer`` for a reply using
    ``ctx.history``, records the new turn with ``ctx.append`` and calls
    ``torch_gc()``.

    Returns:
        ``(ctx.history, "")`` - the history for the chatbot component and an
        empty string that clears the input textbox.
    """
    ctx.limit_round()
    sampling_options = dict(
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
    )
    # infer returns (query, output); only the generated reply is needed.
    reply = infer(query=query, history=ctx.history, **sampling_options)[1]
    ctx.append(query, reply)
    torch_gc()
    # The empty string clears the input textbox.
    return ctx.history, ""
def clear_history():
    """Clear the shared context and return an update that empties the chat display."""
    ctx.clear()
    emptied_chat = gr.update(value=[])
    return emptied_chat
def apply_max_round_click(max_round):
    """Store ``max_round`` on the shared context as ``ctx.max_rounds``."""
    setattr(ctx, "max_rounds", max_round)
def reload_javascript():
    """Inject every ``.js`` file found in ``script_path`` into Gradio's pages.

    Each script is read and wrapped in a ``<script>`` tag. Gradio's
    ``TemplateResponse`` is then replaced with a wrapper that inserts those
    tags immediately before ``</head>`` in every rendered template.

    Raises:
        OSError: If ``script_path`` cannot be listed or a script file cannot
            be read.
    """
    # os.listdir returns entries in arbitrary order; sort so the scripts are
    # always injected in the same, predictable order.
    script_files = sorted(
        os.path.join(script_path, name)
        for name in os.listdir(script_path)
        if name.endswith(".js")
    )
    script_tags = []
    for path in script_files:
        with open(path, "r", encoding="utf8") as js_file:
            script_tags.append(f"\n<script>{js_file.read()}</script>")
    javascript = "".join(script_tags)
    # todo: theme
    # Encode once here instead of on every response.
    head_replacement = f'{javascript}</head>'.encode("utf8")

    def template_response(*args, **kwargs):
        # Delegate to the original Gradio response, then patch its body.
        res = _gradio_template_response_orig(*args, **kwargs)
        res.body = res.body.replace(b'</head>', head_replacement)
        # Re-initialise headers after the body changed (presumably so that
        # Content-Length matches the new body - confirm against Starlette).
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response
# Build the chat UI, wire its event handlers, wrap it in a tabbed page and launch the server.
# NOTE(review): leading indentation is missing throughout this function, so the nesting of the
# `with gr.Row()` / `with gr.Column()` blocks below cannot be read from this text - restore the
# intended layout nesting before running.
def main():
# Build the UI
reload_javascript()
with gr.Blocks(css=css, analytics_enabled=False) as chat_interface:
# Placeholder text shown in the input textbox.
prompt = "输入你的内容..."
with gr.Row():
with gr.Column(scale=3):
gr.Markdown("""<h2><center>ChatGLM WebUI</center></h2>""")
with gr.Row():
with gr.Column(variant="panel"):
with gr.Row():
# Generation settings; their values are passed to predict() by submit.click below.
max_length = gr.Slider(minimum=4, maximum=4096, step=4, label='Max Length', value=2048)
top_p = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Top P', value=0.7)
with gr.Row():
temperature = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Temperature',
value=0.95)
with gr.Row():
# Maximum number of rounds; applied to ctx.max_rounds only when the button next to it is clicked.
max_rounds = gr.Slider(minimum=1, maximum=100, step=1, label="最大对话轮数(调小可以显著改善爆显存,但是会丢失上下文)",
value=20)
apply_max_rounds = gr.Button("✔", elem_id="del-btn")
with gr.Row():
with gr.Column(variant="panel"):
with gr.Row():
clear = gr.Button("清空对话(上下文)")
with gr.Row():
save_his_btn = gr.Button("保存对话")
load_his_btn = gr.UploadButton("读取对话", file_types=['file'], file_count='single')
with gr.Column(scale=7):
chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800)
with gr.Row():
input_message = gr.Textbox(placeholder=prompt, show_label=False, lines=2, elem_id="chat-input")
# NOTE(review): this reuses elem_id "del-btn", which the apply_max_rounds button above also uses.
clear_input = gr.Button("🗑️", elem_id="del-btn")
with gr.Row():
submit = gr.Button("发送", elem_id="c_generate")
# Send: predict() returns (ctx.history, ""), which updates the chatbot and clears the textbox.
submit.click(predict, inputs=[
input_message,
max_length,
top_p,
temperature
], outputs=[
chatbot,
input_message
])
clear.click(clear_history, outputs=[chatbot])
# Empty the input textbox without sending anything.
clear_input.click(lambda x: "", inputs=[input_message], outputs=[input_message])
save_his_btn.click(ctx.save_history)
load_his_btn.upload(ctx.load_history, inputs=[
load_his_btn,
], outputs=[
chatbot
])
apply_max_rounds.click(apply_max_round_click, inputs=[max_rounds])
# (blocks, tab label, tab id) triples rendered as tabs below; only the chat tab exists.
interfaces = [
(chat_interface, "Chat", "chat"),
]
with gr.Blocks(css=css, analytics_enabled=False, title="ChatGLM") as demo:
with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id="tab_" + ifid):
interface.render()
ui = demo
# NOTE(review): share=True is set while the server is bound to 127.0.0.1 - confirm that
# requesting a Gradio share link is intended for this deployment.
ui.launch(
server_name="127.0.0.1",
# server_name="0.0.0.0" if cmd_opts.listen else None,
server_port=17860,
share=True
# share=cmd_opts.share
)
# Script entry point: build and launch the web UI only when this file is executed directly.
# NOTE: the model-loading statements at module level still run on import.
if __name__ == "__main__":
main()
# 需要配合一些代码一块使用. (Note: this script must be used together with additional supporting code.)