I'm using Tsinghua's GLM-4 (GLM-4-9B-Chat). Repo: GitHub - THUDM/GLM: GLM (General Language Model). The web page that runs it is my own and is attached below.
Streaming output works, but the model never stops generating on its own. I've already set up EOS-token detection, yet the problem persists 😥
Could someone please help? 🙏
Here is my code:
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model
MODEL_PATH = "./glm-4-9b-chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto"
).eval()

# Initialize conversation history
conversation_history = []
def generate_text_stream(query):
    global conversation_history
    # Add user input to conversation history
    conversation_history.append({"role": "user", "content": query})

    # Limit conversation history length
    MAX_HISTORY_LENGTH = 10
    if len(conversation_history) > MAX_HISTORY_LENGTH:
        conversation_history = conversation_history[-MAX_HISTORY_LENGTH:]

    # Prepare inputs
    inputs = tokenizer.apply_chat_template(
        conversation_history,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate text token by token
    generated_ids = []
    with torch.no_grad():
        for _ in range(512):  # cap generation at 512 tokens
            # Convert generated_ids to a tensor
            generated_ids_tensor = torch.tensor(generated_ids, device=device, dtype=torch.long)
            # Concatenate the prompt ids with the tokens generated so far
            input_ids = torch.cat([inputs['input_ids'][0], generated_ids_tensor])
            # Forward pass to get logits for the next token
            outputs = model(
                input_ids=input_ids.unsqueeze(0),  # add batch dimension
                attention_mask=inputs['attention_mask']
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1).item()
            # Check whether it is the EOS token
            if next_token_id == tokenizer.eos_token_id:
                print("Detected EOS token. Stopping...")
                break
            # Append the generated token to generated_ids
            generated_ids.append(next_token_id)
            token = tokenizer.decode([next_token_id], skip_special_tokens=True)
            yield token  # stream each token to the frontend
            # Check generation length (redundant with the loop cap, kept as a safety net)
            if len(generated_ids) >= 512:
                print("Reached max token limit. Stopping...")
                break

    # Add the model's response to conversation history
    full_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    conversation_history.append({"role": "assistant", "content": full_response})
# Custom CSS for styling
custom_css = """
.gradio-container {
    background: url('https://ts1.cn.mm.bing.net/th/id/R-C.7caedee3aba35f84a8d2cc9beb077b88?rik=rznMzp8MiiIpSQ&riu=http%3a%2f%2fimage.thepaper.cn%2fwww%2fimage%2f30%2f422%2f605.jpg&ehk=WqbEXSXPkKZSjWlVY%2fgWzSIpLQ34mE1JqfA999GeH9A%3d&risl=&pid=ImgRaw&r=0');
    background-size: cover;
    background-position: center;
    background-attachment: fixed;
    font-family: '宋体', Tahoma, Geneva, Verdana, sans-serif;
    color: #333333;
}
h1, h2, h3 {
    color: black;
    text-align: center;
    font-family: '宋体', Tahoma, Geneva, Verdana, sans-serif;
    font-weight: bold;
    font-size: 32px; /* two sizes larger */
}
.gr-textbox {
    background-color: rgba(255, 255, 255, 0.2); /* semi-transparent background */
    border-radius: 12px; /* rounded corners */
    border: 1px solid #cccccc; /* border */
    box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1); /* drop shadow */
    color: #333333;
    font-size: 16px;
}
.gr-textbox label {
    font-weight: bold;
    color: #4CAF50;
}
.gr-button {
    background-color: rgba(255, 255, 255, 0.4); /* semi-transparent background */
    color: #333333;
    border: 1px solid #cccccc;
    border-radius: 20px;
    font-size: 16px;
    padding: 10px 20px;
    cursor: pointer;
    transition: background-color 0.3s ease, color 0.3s ease;
}
.gr-button:hover {
    background-color: #4CAF50;
    color: white;
}
footer {
    display: none !important;
}
"""
# Create Gradio interface
with gr.Blocks(css=custom_css, title="AI文本生成系统") as demo:
    gr.Markdown("# AI文本生成系统")
    gr.Markdown("使用方法:输入问题后请耐心等待,系统会逐步生成相应的文本回答")
    with gr.Row():
        question_input = gr.Textbox(label="请输入您的问题", placeholder="在这里输入问题...", lines=2)
        confirm_button = gr.Button("生成回答")
    output_textbox = gr.Textbox(label="生成内容", interactive=False, lines=10)

    # Stream text generation
    def stream_response(query):
        response = ""
        for token in generate_text_stream(query):
            response += token
            yield gr.update(value=response)  # update the textbox with the current response

    confirm_button.click(
        fn=stream_response,
        inputs=question_input,
        outputs=output_textbox
    )

demo.launch(server_name="0.0.0.0", server_port=12347)
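One thing I'm not sure about is whether tokenizer.eos_token_id even matches what the model actually emits at the end of a turn. If I understand the transformers API correctly, model.generation_config.eos_token_id can be a single int or a list of ids, and I believe GLM-4-9B-Chat configures several stop tokens. A quick check I plan to run (these are standard transformers attributes, nothing GLM-specific):

# Compare the tokenizer's single EOS id against the model's full
# stop-token configuration (may be an int or a list of ids).
print("tokenizer.eos_token_id:", tokenizer.eos_token_id)
print("generation_config.eos_token_id:", model.generation_config.eos_token_id)
# Decode the configured ids to see which special tokens they correspond to.
ids = model.generation_config.eos_token_id
ids = ids if isinstance(ids, (list, tuple)) else [ids]
print([tokenizer.decode([i]) for i in ids])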
The key part is the token-by-token generation loop inside generate_text_stream above.
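In case it helps with diagnosis, here is a variant of that loop I'm thinking about trying. Two changes, both guesses on my part: stop on any id in model.generation_config.eos_token_id (I've read that GLM-4 can end a turn with special tokens like <|user|> and <|observation|>, not only <|endoftext|>), and rebuild the attention mask every step so it covers the grown sequence (in my code above the mask stays at prompt length). A minimal sketch, not a confirmed fix:

# Sketch: stop on ANY configured stop token and keep the mask in sync.
eos_ids = model.generation_config.eos_token_id
if not isinstance(eos_ids, (list, tuple)):
    eos_ids = [eos_ids]
eos_ids = set(eos_ids)

generated_ids = []
with torch.no_grad():
    for _ in range(512):
        generated_ids_tensor = torch.tensor(generated_ids, device=device, dtype=torch.long)
        input_ids = torch.cat([inputs['input_ids'][0], generated_ids_tensor]).unsqueeze(0)
        outputs = model(
            input_ids=input_ids,
            # mask covers the full (prompt + generated) sequence
            attention_mask=torch.ones_like(input_ids)
        )
        next_token_id = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
        if next_token_id in eos_ids:  # any stop token ends generation
            break
        generated_ids.append(next_token_id)
        yield tokenizer.decode([next_token_id], skip_special_tokens=True)

Alternatively, would it be more robust to drop the manual loop entirely and call model.generate with transformers' TextIteratorStreamer in a background thread, letting generate handle the stop tokens? Any advice appreciated.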
Thanks a lot!!!