模型运行页面(ChatGLM + LoRA 推理 WebUI)代码

from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments
from deep_training.nlp.models.chatglm import setup_model_profile, ChatGLMConfig
from deep_training.nlp.models.lora.v2 import LoraArguments
from transformers import HfArgumentParser
from typing import Optional, List, Tuple
from data_utils import train_info_args, NN_DataHelper
from models import MyTransformer,ChatGLMTokenizer

import os

import gradio as gr

from webui.context import ctx
from webui.device import torch_gc
# Stylesheet handed to gr.Blocks(css=...) in main(); presumably resolved as a
# path relative to the working directory — confirm against the Gradio version used.
css = "style.css"
# Directory scanned by reload_javascript() for *.js files to inject into pages.
script_path = "scripts"
# Unpatched Gradio TemplateResponse, captured at import time so that
# reload_javascript() always wraps the original rather than a previous patch.
_gradio_template_response_orig = gr.routes.templates.TemplateResponse

# Load the model.
# NOTE(review): everything below runs at import time, so merely importing this
# module parses the config, loads the checkpoints and moves the model to the GPU.
# Clear the seed before parsing; presumably no fixed seed is wanted at inference — confirm.
train_info_args['seed'] = None
# Parse the training config dict into typed argument objects. The LoRA arguments
# parsed here are discarded; the ones saved with the checkpoint are loaded below.
parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, LoraArguments))
model_args, training_args, data_args, _ = parser.parse_dict(train_info_args)

# deep_training's ChatGLM setup hook (what it configures is not visible here).
setup_model_profile()

dataHelper = NN_DataHelper(model_args, training_args, data_args)
tokenizer: ChatGLMTokenizer
# Only the tokenizer is kept; the other three returned values are discarded.
tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config(tokenizer_class_name=ChatGLMTokenizer, config_class_name=ChatGLMConfig)

# NOTE(review): the model config is read from ./best_ckpt while the LoRA weights
# below come from ./last_ckpt — confirm mixing the two checkpoints is intended.
config = ChatGLMConfig.from_pretrained('./best_ckpt')
# config = ChatGLMConfig.from_pretrained('./best_ckpt')
# Presumably skips random weight initialization since weights are loaded afterwards — confirm.
config.initializer_weight = False

lora_args = LoraArguments.from_pretrained('./last_ckpt')
# lora_args = LoraArguments.from_pretrained('./best_ckpt')

# Sanity check: the saved LoRA arguments must be in inference mode and the config
# must not set pre_seq_len (presumably ruling out P-tuning checkpoints — confirm).
# NOTE: assert statements are stripped when Python runs with -O.
assert lora_args.inference_mode == True and config.pre_seq_len is None

pl_model = MyTransformer(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args)
# Load the LoRA weights from ./last_ckpt into the backbone.
pl_model.backbone.from_pretrained(pl_model.backbone.model, pretrained_model_name_or_path='./last_ckpt',
                                      lora_config=lora_args)
# pl_model.backbone.from_pretrained(pl_model.backbone.model, pretrained_model_name_or_path = './best_ckpt', lora_config = lora_args)

# Underlying ChatGLM model; infer() calls its .chat() method.
model = pl_model.get_glm_model()
# Adjust as needed: converts to fp16 and moves to the default CUDA device.
model.half().cuda()
model = model.eval()


def infer(query,
          history: Optional[List[Tuple]],
          max_length, top_p, temperature):
    """Run one chat turn against the module-level ``model``.

    Args:
        query: The user's message.
        history: Previous turns forwarded to ``model.chat``; ``None`` is
            treated as an empty history.
        max_length: Maximum length forwarded to ``model.chat``.
        top_p: Top-p sampling value forwarded to ``model.chat``.
        temperature: Sampling temperature forwarded to ``model.chat``.

    Returns:
        A ``(query, output)`` tuple. The updated history returned by
        ``model.chat`` is discarded; callers keep their own history.

    Raises:
        RuntimeError: If the module-level ``model`` has not been loaded.
    """
    # ``raise "..."`` is itself a TypeError in Python 3 (exceptions must derive
    # from BaseException), so raise a real exception; test for None explicitly.
    if model is None:
        raise RuntimeError("Model not loaded")

    if history is None:
        history = []
    output, _ = model.chat(
        tokenizer, query=query, history=history,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature
    )
    print(output)
    # Presumably releases cached GPU memory after each generation (see webui.device).
    torch_gc()
    return query, output

def predict(query, max_length, top_p, temperature):
    """Handle a "send" click: generate a reply and record the turn.

    Calls ``ctx.limit_round()`` (presumably trimming the shared context to its
    round limit), generates a reply from the current ``ctx.history``, appends
    the new (query, reply) turn to the context, then calls ``torch_gc()``.

    Returns:
        ``(ctx.history, "")`` — the history for the chatbot component and an
        empty string that clears the input textbox.
    """
    ctx.limit_round()
    gen_kwargs = dict(max_length=max_length, top_p=top_p, temperature=temperature)
    _echoed_query, reply = infer(query=query, history=ctx.history, **gen_kwargs)
    ctx.append(query, reply)
    torch_gc()
    cleared_input = ""
    return ctx.history, cleared_input


def clear_history():
    """Reset the shared conversation context and empty the chatbot.

    Returns:
        A Gradio update that sets the chatbot component's value to ``[]``.
    """
    ctx.clear()
    no_messages = []
    return gr.update(value=no_messages)


def apply_max_round_click(max_round):
    """Store the "max rounds" slider value on the shared context as ``max_rounds``."""
    setattr(ctx, "max_rounds", max_round)

def reload_javascript():
    """Patch Gradio so every rendered page includes the custom JS files.

    Reads every ``*.js`` file in ``script_path`` in sorted filename order (so
    the injection order is deterministic across platforms). Each file is
    wrapped in a ``<script>`` tag, and ``gr.routes.templates.TemplateResponse``
    is replaced with a wrapper that inserts those tags before ``</head>``.

    The wrapper always delegates to ``_gradio_template_response_orig``
    (captured at import time). Calling this function again therefore re-reads
    the scripts without stacking patches.

    Raises:
        OSError: If ``script_path`` or one of the script files cannot be read.
    """
    # os.listdir returns entries in arbitrary order; sort so the scripts are
    # always injected (and therefore executed) in a predictable order.
    script_names = sorted(name for name in os.listdir(script_path) if name.endswith(".js"))

    chunks = []
    for name in script_names:
        with open(os.path.join(script_path, name), "r", encoding="utf8") as js_file:
            chunks.append(f"\n<script>{js_file.read()}</script>")
    javascript = "".join(chunks)

    # TODO: theme support, e.g. append
    # f"\n<script>set_theme('{theme}');</script>\n" when a theme option is set.

    # Encode once instead of on every request.
    injected = f"{javascript}</head>".encode("utf8")

    def template_response(*args, **kwargs):
        res = _gradio_template_response_orig(*args, **kwargs)
        res.body = res.body.replace(b'</head>', injected)
        # The body length changed; presumably Starlette's init_headers()
        # recomputes Content-Length to match.
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response

def main() -> None:
    """Build the ChatGLM chat WebUI and launch the Gradio server.

    Injects the custom JS (via reload_javascript()) and lays out the chat
    page: generation sliders, history controls, chatbot and input box. It
    then wires the buttons to predict / clear_history / ctx.save_history /
    ctx.load_history / apply_max_round_click. Finally it wraps the page in a
    tabbed container and launches it on 127.0.0.1:17860.

    NOTE(review): ``share=True`` asks Gradio to also create a public share
    link, which exposes the model beyond localhost even though
    ``server_name`` is 127.0.0.1 — confirm this is intended.
    """
    # Build the UI.
    reload_javascript()

    with gr.Blocks(css=css, analytics_enabled=False) as chat_interface:
        # Placeholder text for the message input box.
        prompt = "输入你的内容..."
        with gr.Row():
            # Left column: generation settings and history controls.
            with gr.Column(scale=3):
                gr.Markdown("""<h2><center>ChatGLM WebUI</center></h2>""")
                with gr.Row():
                    with gr.Column(variant="panel"):
                        # Generation parameters; their values are passed to predict() on send.
                        with gr.Row():
                            max_length = gr.Slider(minimum=4, maximum=4096, step=4, label='Max Length', value=2048)
                            top_p = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Top P', value=0.7)
                        with gr.Row():
                            temperature = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Temperature',
                                                    value=0.95)

                        # Max conversation rounds kept in context; per the label, smaller values
                        # ease GPU out-of-memory errors but lose context. Applied only when the
                        # check-mark button is clicked.
                        with gr.Row():
                            max_rounds = gr.Slider(minimum=1, maximum=100, step=1, label="最大对话轮数(调小可以显著改善爆显存,但是会丢失上下文)",
                                                   value=20)
                            # NOTE(review): elem_id "del-btn" is also used by clear_input below,
                            # giving duplicate DOM ids — presumably shared for CSS styling.
                            apply_max_rounds = gr.Button("✔", elem_id="del-btn")

                with gr.Row():
                    with gr.Column(variant="panel"):
                        with gr.Row():
                            clear = gr.Button("清空对话(上下文)")

                        with gr.Row():
                            save_his_btn = gr.Button("保存对话")
                            load_his_btn = gr.UploadButton("读取对话", file_types=['file'], file_count='single')

            # Right column: the conversation and the message input.
            with gr.Column(scale=7):
                # NOTE(review): Component.style() is Gradio 3.x API (removed in 4.x) — pin the version.
                chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800)
                with gr.Row():
                    input_message = gr.Textbox(placeholder=prompt, show_label=False, lines=2, elem_id="chat-input")
                    clear_input = gr.Button("🗑️", elem_id="del-btn")

                with gr.Row():
                    submit = gr.Button("发送", elem_id="c_generate")

        # Send: predict() returns (history, "") -> refreshes the chatbot and clears the input box.
        submit.click(predict, inputs=[
            input_message,
            max_length,
            top_p,
            temperature
        ], outputs=[
            chatbot,
            input_message
        ])

        clear.click(clear_history, outputs=[chatbot])
        # Clear only the input textbox; the conversation is untouched.
        clear_input.click(lambda x: "", inputs=[input_message], outputs=[input_message])

        save_his_btn.click(ctx.save_history)
        # The uploaded file is passed to ctx.load_history; its return value populates the chatbot.
        load_his_btn.upload(ctx.load_history, inputs=[
            load_his_btn,
        ], outputs=[
            chatbot
        ])

        apply_max_rounds.click(apply_max_round_click, inputs=[max_rounds])

        # (Blocks, tab label, tab id) triples rendered as tabs below.
        interfaces = [
            (chat_interface, "Chat", "chat"),
        ]

    # Outer container: render each interface inside its own tab.
    with gr.Blocks(css=css, analytics_enabled=False, title="ChatGLM") as demo:
        with gr.Tabs(elem_id="tabs") as tabs:
            for interface, label, ifid in interfaces:
                with gr.TabItem(label, id=ifid, elem_id="tab_" + ifid):
                    interface.render()

    ui = demo

    ui.launch(
        server_name="127.0.0.1",
        # server_name="0.0.0.0" if cmd_opts.listen else None,
        server_port=17860,
        share=True
        # share=cmd_opts.share
    )


# Start the UI only when this file is executed as a script.
if __name__ == "__main__":
    main()

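需要配合其他代码一起使用:data_utils.py(提供 train_info_args、NN_DataHelper)、models.py(提供 MyTransformer、ChatGLMTokenizer)、webui/context.py(提供 ctx)和 webui/device.py(提供 torch_gc),以及 style.css 样式文件和 scripts/ 目录下的 .js 脚本。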

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值