How do I get GLM's open-source streaming output to stop???

I'm using Tsinghua's GLM-4, repo: GitHub - THUDM/GLM: GLM (General Language Model) (the web page for running it is my own, attached below)

Streaming output works, but generation never stops on its own. I've already set up EOS-token detection, yet the problem persists 😥
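
(For what it's worth, my check only compares against tokenizer.eos_token_id. A quick way to see what the checkpoint itself treats as stop tokens, run after the loading code below, assuming generation_config carries that field:

print("tokenizer eos:", tokenizer.eos_token_id)
print("generation_config eos:", model.generation_config.eos_token_id)

On glm-4-9b-chat I believe the second line prints a list of several ids, not just one, but I may be wrong.)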

Hoping someone can help 🙏

My code is below:

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model
MODEL_PATH = "./glm-4-9b-chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto"
).eval()

# Initialize conversation history
conversation_history = []

def generate_text_stream(query):
    global conversation_history

    # Add user input to conversation history
    conversation_history.append({"role": "user", "content": query})

    # Limit conversation history length
    MAX_HISTORY_LENGTH = 10
    if len(conversation_history) > MAX_HISTORY_LENGTH:
        conversation_history = conversation_history[-MAX_HISTORY_LENGTH:]

    # Prepare inputs
    inputs = tokenizer.apply_chat_template(
        conversation_history,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate token by token (the full prefix is re-run every step; no KV cache)
    generated_ids = []
    with torch.no_grad():
        for _ in range(512):  # cap generation at 512 new tokens
            # Turn the ids generated so far into a tensor
            generated_ids_tensor = torch.tensor(generated_ids, device=device, dtype=torch.long)

            # Append the generated tokens to the prompt ids
            input_ids = torch.cat([inputs['input_ids'][0], generated_ids_tensor])

            # Extend the attention mask so it matches the grown input_ids
            # (reusing the original, shorter mask is a shape mismatch)
            attention_mask = torch.cat(
                [inputs['attention_mask'],
                 torch.ones((1, len(generated_ids)), device=device,
                            dtype=inputs['attention_mask'].dtype)],
                dim=-1,
            )

            # Predict the next token (greedy decoding)
            outputs = model(
                input_ids=input_ids.unsqueeze(0),  # add the batch dimension
                attention_mask=attention_mask,
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1).item()

            # Check for the EOS token
            if next_token_id == tokenizer.eos_token_id:
                print("Detected EOS token. Stopping...")
                break

            # Append the new token and stream it to the frontend
            generated_ids.append(next_token_id)
            token = tokenizer.decode([next_token_id], skip_special_tokens=True)
            yield token

    # Add model's response to conversation history
    full_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    conversation_history.append({"role": "assistant", "content": full_response})


# Custom CSS for styling
custom_css = """
.gradio-container {
    background: url('https://ts1.cn.mm.bing.net/th/id/R-C.7caedee3aba35f84a8d2cc9beb077b88?rik=rznMzp8MiiIpSQ&riu=http%3a%2f%2fimage.thepaper.cn%2fwww%2fimage%2f30%2f422%2f605.jpg&ehk=WqbEXSXPkKZSjWlVY%2fgWzSIpLQ34mE1JqfA999GeH9A%3d&risl=&pid=ImgRaw&r=0');
    background-size: cover;
    background-position: center;
    background-attachment: fixed;
    font-family: '宋体', Tahoma, Geneva, Verdana, sans-serif;
    color: #333333;
}

h1, h2, h3 {
    color: black;
    text-align: center;
    font-family: '宋体', Tahoma, Geneva, Verdana, sans-serif;
    font-weight: bold;
    font-size: 32px; /* bumped up two sizes */
}

.gr-textbox {
    background-color: rgba(255, 255, 255, 0.2); /* semi-transparent background */
    border-radius: 12px; /* rounded corners */
    border: 1px solid #cccccc; /* border */
    box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1); /* drop shadow */
    color: #333333;
    font-size: 16px;
}

.gr-textbox label {
    font-weight: bold;
    color: #4CAF50;
}

.gr-button {
    background-color: rgba(255, 255, 255, 0.4); /* semi-transparent background */
    color: #333333;
    border: 1px solid #cccccc;
    border-radius: 20px;
    font-size: 16px;
    padding: 10px 20px;
    cursor: pointer;
    transition: background-color 0.3s ease, color 0.3s ease;
}

.gr-button:hover {
    background-color: #4CAF50;
    color: white;
}

footer {
    display: none !important;
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css, title="AI Text Generation System") as demo:
    gr.Markdown("# AI Text Generation System")
    gr.Markdown("How to use: enter a question and wait patiently; the system streams the answer as it is generated")

    with gr.Row():
        question_input = gr.Textbox(label="Enter your question", placeholder="Type your question here...", lines=2)
        confirm_button = gr.Button("Generate Answer")

    output_textbox = gr.Textbox(label="Generated output", interactive=False, lines=10)

    # Stream text generation
    def stream_response(query):
        response = ""
        for token in generate_text_stream(query):
            response += token
            yield gr.update(value=response)  # Update the textbox with the current response

    confirm_button.click(
        fn=stream_response,
        inputs=question_input,
        outputs=output_textbox
    )

demo.launch(server_name="0.0.0.0", server_port=12347)
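
Side note: I know transformers also ships TextIteratorStreamer, which lets model.generate apply the checkpoint's own stop tokens while streaming. Something roughly like the sketch below (untested; max_new_tokens and do_sample are my guesses) should be equivalent, but I'd still like to understand why my manual loop never stops:

from threading import Thread
from transformers import TextIteratorStreamer

def generate_with_builtin_streamer(inputs):
    # generate() applies the model's configured stop tokens itself
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=512, do_sample=False)
    Thread(target=model.generate, kwargs=gen_kwargs).start()
    for piece in streamer:
        yield piece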

The key part:

# Initialize conversation history
conversation_history = []

def generate_text_stream(query):
    global conversation_history

    # Add user input to conversation history
    conversation_history.append({"role": "user", "content": query})

    # Limit conversation history length
    MAX_HISTORY_LENGTH = 10
    if len(conversation_history) > MAX_HISTORY_LENGTH:
        conversation_history = conversation_history[-MAX_HISTORY_LENGTH:]

    # Prepare inputs
    inputs = tokenizer.apply_chat_template(
        conversation_history,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate token by token (the full prefix is re-run every step; no KV cache)
    generated_ids = []
    with torch.no_grad():
        for _ in range(512):  # cap generation at 512 new tokens
            # Turn the ids generated so far into a tensor
            generated_ids_tensor = torch.tensor(generated_ids, device=device, dtype=torch.long)

            # Append the generated tokens to the prompt ids
            input_ids = torch.cat([inputs['input_ids'][0], generated_ids_tensor])

            # Extend the attention mask so it matches the grown input_ids
            # (reusing the original, shorter mask is a shape mismatch)
            attention_mask = torch.cat(
                [inputs['attention_mask'],
                 torch.ones((1, len(generated_ids)), device=device,
                            dtype=inputs['attention_mask'].dtype)],
                dim=-1,
            )

            # Predict the next token (greedy decoding)
            outputs = model(
                input_ids=input_ids.unsqueeze(0),  # add the batch dimension
                attention_mask=attention_mask,
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1).item()

            # Check for the EOS token
            if next_token_id == tokenizer.eos_token_id:
                print("Detected EOS token. Stopping...")
                break

            # Append the new token and stream it to the frontend
            generated_ids.append(next_token_id)
            token = tokenizer.decode([next_token_id], skip_special_tokens=True)
            yield token

    # Add model's response to conversation history
    full_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    conversation_history.append({"role": "assistant", "content": full_response})
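
My current guess (please correct me if I'm wrong): GLM-4's chat format seems to end an assistant turn with special tokens such as <|user|> or <|observation|> rather than the tokenizer's single eos_token, so next_token_id == tokenizer.eos_token_id may simply never match. Would checking every id from generation_config be the right fix? A sketch of what I mean (eos_token_id can apparently be an int or a list, depending on the checkpoint):

# Collect every configured stop id (int or list, depending on the checkpoint)
eos_ids = model.generation_config.eos_token_id
stop_ids = set(eos_ids) if isinstance(eos_ids, (list, tuple)) else {eos_ids}

# ...then inside the loop, instead of comparing to tokenizer.eos_token_id:
if next_token_id in stop_ids:
    print("Detected a stop token. Stopping...")
    break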

Thank you very much!!!!!
