项目前后端交互

参数说明
def main(
        prompt_text: str,
        system_prompt: str,
        top_p: float = 0.8,
        temperature: float = 0.95,
        repetition_penalty: float = 1.0,
        max_new_tokens: int = 1024,
        retry: bool = False
):

Maximum length 参数

通常用于限制输入序列的最大长度。

因为 ChatGLM-6B 是2048长度推理的,一般这个保持默认就行,太大可能会导致性能下降。

Temperature 参数

用于控制模型输出的结果的随机性。

设置为0,对每个 prompt 都生成固定的输出。

较低的值,输出更集中,更有确定性。

较高的值,输出更随机,更有创意。

Top-p 参数

用于控制模型生成文本的概率分布。

较小的 Top-p 值会导致模型更加倾向于选择高频词汇。

而较大的 Top-p 值则会使模型更加注重选择低频词汇。

合适的 Top-p 值能够平衡生成文本的准确性和多样性。

初始化占位符和会话历史

placeholder = st.empty()
with placeholder.container():
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

创建一个占位符用于展示对话内容。
检查st.session_state中是否存在chat_history,如果不存在则初始化为空列表。

处理空的提示文本

if prompt_text == "" and retry == False:
    print("\n== Clean ==\n")
    st.session_state.chat_history = []
    return

如果提示文本为空且不重试,则清空会话历史并返回。

显示历史对话

history: list[Conversation] = st.session_state.chat_history
for conversation in history:
    conversation.show()

处理重试逻辑

if retry:
    print("\n== Retry ==\n")
    last_user_conversation_idx = None
    for idx, conversation in enumerate(history):
        if conversation.role == Role.USER:
            last_user_conversation_idx = idx
    if last_user_conversation_idx is not None:
        prompt_text = history[last_user_conversation_idx].content
        del history[last_user_conversation_idx:]

如果需要重试,找到最后一个用户对话,将其内容作为新的提示文本,并删除该对话及其后的所有对话。

处理用户输入并生成对话

if prompt_text:
    prompt_text = prompt_text.strip()
    append_conversation(Conversation(Role.USER, prompt_text), history)
    placeholder = st.empty()
    message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
    markdown_placeholder = message_placeholder.empty()

    output_text = ''
    for response in client.generate_stream(
            system_prompt,
            tools=None,
            history=history,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            stop_sequences=[str(Role.USER)],
            repetition_penalty=repetition_penalty,
    ):
        token = response.token
        if response.token.special:
            print("\n==Output:==\n", output_text)
            match token.text.strip():
                case '':
                    break
                case _:
                    st.error(f'Unexpected special token: {token.text.strip()}')
                    break
        output_text += response.token.text
        markdown_placeholder.markdown(postprocess_text(output_text + '▌'))

    append_conversation(Conversation(
        Role.ASSISTANT,
        postprocess_text(output_text),
    ), history, markdown_placeholder)

如果提示文本不为空,则将其添加到会话历史中,并创建新的占位符。
使用client.generate_stream方法生成对话响应,设置各种参数以控制生成行为。
逐步将生成的令牌文本添加到输出文本中,并实时更新显示。
最终将生成的对话添加到会话历史中。

prompt_text = st.chat_input(
    'Chat with ChatGLM3!',  # 提示用户输入的文本
    key='chat_input',       # Streamlit组件的唯一键,用于标识这个输入框
)

response = demo_chat.main(
    retry=retry,                     # 是否重试生成对话,布尔值
    top_p=0.8,                       # 核采样参数,控制生成文本的多样性,0.8表示前80%的概率分布进行采样
    temperature=temperature,         # 控制生成文本的随机性,值越高生成的文本越随机
    prompt_text=prompt_text,         # 用户输入的提示文本
    system_prompt=system_prompt,     # 系统提示文本,提供对话的上下文或风格
    repetition_penalty=repetition_penalty,  # 重复惩罚,避免生成重复内容
    max_new_tokens=max_new_token     # 生成文本的最大长度,控制生成对话的长度
)
if response:
    handle_response(response)         # 处理生成的对话响应

prompt_text

用户在聊天界面输入的文本,这个文本会作为生成对话的初始提示。
st.chat_input 是 Streamlit 用于创建聊天输入框的函数。‘Chat with ChatGLM3!’ 是输入框的提示文字,key=‘chat_input’ 是这个输入框的唯一标识符。

retry

布尔值,指示是否在生成对话失败时重试生成。这个参数决定了在生成的对话不满足条件时,是否重新生成对话。

top_p

核采样参数,控制生成文本的多样性。默认值为 0.8,表示从概率分布的前 80% 采样。较低的值生成的文本越确定,较高的值生成的文本越多样化。

temperature

控制生成文本的随机性。较高的值会生成更随机的结果,较低的值会生成更确定的结果。默认值通常设为 0.95。

system_prompt

系统提示文本,提供对话的上下文或风格信息,以引导生成更符合预期的响应。

repetition_penalty

重复惩罚参数,用于避免生成的文本中出现过多重复内容。默认值为 1.0,值越高重复的概率越低。

max_new_tokens

生成文本的最大长度,以控制生成对话的长度,防止生成过长的响应。默认值通常设为 1024。

代码逻辑

获取用户输入

prompt_text = st.chat_input(
    'Chat with ChatGLM3!',
    key='chat_input',
)

使用 Streamlit 的 chat_input 组件获取用户输入的对话提示文本

调用生成对话的函数

response = demo_chat.main(
    retry=retry,
    top_p=0.8,
    temperature=temperature,
    prompt_text=prompt_text,
    system_prompt=system_prompt,
    repetition_penalty=repetition_penalty,
    max_new_tokens=max_new_token
)

调用 demo_chat.main 函数生成对话响应,并传入多个配置参数,以控制生成对话的行为。

处理生成的响应

if response:
    handle_response(response)

检查是否成功生成响应,如果 response 不为空,则调用 handle_response 函数对响应进行处理。

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值