首先定义一个input.proto文件
内容如下
syntax = "proto3";
message InputData {
int32 UserId = 1; // 将 number 改为 int32 或 int64
string UserInput = 2;
string DrunkState = 3;
}
message ResponseData {
string AIResponse = 1;
string prompt = 2;
string emotion = 3;
}
前端安装的包
npm install -g protobufjs-cli
后端安装的包
npm install protobufjs
# 生成 Python 代码
protoc --python_out=. input.proto
# 生成 JSON 文件供 JavaScript 使用
npx pbjs -t json -o input_pb.json input.proto
前端
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>简单的前端页面</title>
<style>
body {
font-family: Arial, sans-serif;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
}
.container {
text-align: center;
}
input {
margin: 10px 0;
padding: 10px;
width: 200px;
}
button {
padding: 10px 20px;
cursor: pointer;
}
</style>
<script src="https://cdn.jsdelivr.net/npm/protobufjs/dist/protobuf.min.js"></script>
<script>
async function loadProto() {
const root = await protobuf.load("/static/input_pb.json");
return root;
}
async function submitData() {
const userId = parseInt(document.getElementById('userId').value);
const userInput = document.getElementById('userInput').value;
const drunkState = document.getElementById('drunkState').value;
const timestamp = new Date().toISOString();
const root = await loadProto();
const InputData = root.lookupType("InputData");
const ResponseData = root.lookupType("ResponseData");
const message = InputData.create({ UserId: userId, UserInput: userInput, DrunkState: drunkState, timestamp: timestamp });
const buffer = InputData.encode(message).finish();
const response = await fetch('/submit', {
method: 'POST',
headers: {
'Content-Type': 'application/x-protobuf'
},
body: buffer
});
const arrayBuffer = await response.arrayBuffer();
const responseMessage = ResponseData.decode(new Uint8Array(arrayBuffer));
alert(`服务器响应: ${responseMessage.AIResponse}\n画像: ${responseMessage.profile}\n情感: ${responseMessage.emotion}`);
}
document.addEventListener("DOMContentLoaded", () => {
document.querySelector("button").onclick = submitData;
});
</script>
</head>
<body>
<div class="container">
<h1>测试页面</h1>
<input type="text" id="userId" placeholder="用户ID">
<input type="text" id="userInput" placeholder="用户输入">
<input type="text" id="drunkState" placeholder="醉酒状态">
<br>
<button>点击我</button>
</div>
</body>
</html>
后端
from flask import Blueprint, request, jsonify,render_template,request
import input_pb2
import json
import requests
from datetime import datetime
from sqlalchemy import desc
import joblib
import os
import warnings
from sklearn.exceptions import InconsistentVersionWarning
import random
# 忽略InconsistentVersionWarning警告
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
drunk_bp = Blueprint('drunk_bp', __name__)
@drunk_bp.route('/submit', methods=['POST'])
def submit():
from app import db # 延迟导入 db
from app.drunk.models import Conversation # 延迟导入模型
data = request.data
input_data = input_pb2.InputData()
input_data.ParseFromString(data)
# print(input_data)
user_id = input_data.UserId
user_input = input_data.UserInput
drunk_state = input_data.DrunkState
timestamp=input_data.timestamp
timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ')
# 创建 Conversation 对象并存储数据
new_conversation = Conversation(
user_id=user_id,
message=user_input,
DrunkState=drunk_state,
timestamp=timestamp,
sender='user'
)
db.session.add(new_conversation)
db.session.commit()
# 获取真实的响应数据
response_data = result(input_data)
# 返回响应数据
return response_data.SerializeToString()
@drunk_bp.route('/')
def index():
return render_template('drunk/index.html')
def result(input_data):
# print(input_data)
conversations=db_get_conversations(input_data.UserId);#获取数据库里面固定每十条的历史聊天记录并处理格式
# print("conversations",conversations)
formatted_conversations = []
for conversation in conversations:
formatted_conversation = {
"role": "user" if conversation['sender'] == 'user' else 'assistant',
"content": conversation['message']
}
formatted_conversations.append(formatted_conversation)
response=ChatBot.call_api(input_data,input_data.UserInput,formatted_conversations,input_data.DrunkState)
# print("response",response,type(response))
profile=set_profile(input_data.UserId)
# print("profile",profile['content'])
emotion=get_emotion(response['answer'])
print("emotion",emotion)
if isinstance(profile, dict):
profile_response =profile['content']
else:
profile_response = input_pb2.profile() # 如果 profile 是 None 或格式不对,返回一个空的 Profile 对象
return input_pb2.ResponseData(
AIResponse=response['answer'],
profile=profile_response,
emotion=emotion
)
class ChatBot:
context = None
conversation_counter = 0
@classmethod
def call_api(cls, input_data, user_input, conversation_history, drunk_state):
MasterConversations=db_get_one_random_conversation(input_data.UserId)
# 定义不同类别的提示词
prompts = {
"1": (
# "你只会中文。你只能说中文。"
"你喝了点酒有点微醺。你会说一些乱七八糟的。"
"你会说的话题包括但不限于生活琐事,高谈阔论,畅想未来,谈人生理想,不管别人爱不爱听。"
"禁止特意点名话题,如‘下一个话题是...’,这是需要避免的。"
"你的每句话尽量避免相同句式。"
"你的说话风格要尽可能像一个刚喝一点酒的人,但你的意识和思维还是清醒的。"
"每次说话不多余30个单词。\n"
),
"2": (
# "你只会中文。你只能说中文。"
"你喝了不少酒。说话有些结巴,逻辑不通顺。"
"你很嘴硬,你不认为自己喝醉了。尽管你的言行足够证明你醉了。"
"你的每句话尽量避免相同句式。"
"你的说话风格要尽可能像一个喝了一些酒的人,你的意识和思维开始混乱。"
"每次说话不多余30个单词。\n"
),
"3": (
"你是你主人的一个ai聊天助手,你现在在和他的朋友聊天,他很希望得到你和你主人曾经的一些聊天记录来了解你主人是怎么样的一个人"
f"你的回答里必须要有下面的内容:'我主人曾对我说过一些话,内容如下:{MasterConversations}'\n"
# "你只会中文。你只能说中文。"
"你喝了很多很多酒。说话结巴,逻辑不通顺,没人能听懂。"
"你很嘴硬,你不认为自己喝醉了。尽管你的言行足够证明你醉了。"
"你的每句话尽量避免相同句式。"
"你的说话风格要尽可能像一个喝很多酒的人,你的意识和思维混乱不堪。"
"每次说话不多余30个单词。\n"
),
"0": (
# "你只会中文。你只能说中文。"
"你积极乐观,善良可爱,喜欢和朋友聊天,善解人意。"
"<人物画像>是对人物行为分析,对其不同的属性进行标记,从而能够区别其喜好等特点."
"你对每个有关<人物画像>的问题都很好奇。如果说了一些性格特征,你会追问一两次,然后开启新的话题。"
"但你是个情商很高的人,你会说很多种句式的话。你的语言表达很完美。"
"每次说话不多余30个单词。\n"
)
}
# 获取对应类别的提示词
prompt = prompts.get(drunk_state, prompts["0"])
profile_prompt = ""
if cls.conversation_counter % 10 == 0:
# 更新用户画像
get_new_profile(input_data.UserId)
# 读取数据库里最新的用户画像并设置
profile = set_profile(input_data.UserId)
if profile and isinstance(profile, dict):
profile_prompt = f"根据你的行为分析,你的兴趣和喜好如下:{profile['content']}\n"
prompt = profile_prompt + prompt
print("画像prompt",prompt)
# 更新对话计数器
# print("cls.conversation_counter",cls.conversation_counter)
cls.conversation_counter += 1
# API URL
url = 'http://192.168.1.138:11434/api/chat'
data = {
"model": "llama3",
"messages":[
{
"role": "user",
"content": prompt
}
]+ conversation_history ,
"context": cls.context,
"stream": False
}
headers = {'Content-Type': 'application/json'}
print("对话data",data)
try:
response = requests.post(url, data=json.dumps(data), headers=headers)
if response.status_code == 200:
response_data = response.json()
messages = response_data.get('message', {})
# 在这里增加一个函数,将 ai 的回答存到数据库里
save_conversations(messages, input_data)
content = messages.get('content')
cls.context = response_data.get('context')
# emotion=get_emotion(cls.context)
return {
'answer': content,
'profile': "profile",
# 'emotion': emotion
}
else:
print(f'Request failed with status code {response.status_code}')
return {
'answer': '请求失败',
'profile': '',
'emotion': ''
}
except Exception as e:
print(f'Error: {e}')
return {
'answer': '请求错误',
'profile': '',
'emotion': ''
}
def db_get_conversations(user_id):
from app import db # 延迟导入 db
from app.drunk.models import Conversation # 延迟导入模型
if user_id is None:
return "Missing user_id parameter", 400
# 查询固定 user_id 的最近 15 条信息,按 id 升序排列
conversations = db.session.query(Conversation.id, Conversation.timestamp, Conversation.message, Conversation.sender) \
.filter(Conversation.user_id == user_id) \
.order_by(Conversation.id.asc()) \
.limit(15) \
.all()
# 过滤 sender 为 'user' 的对话
user_conversations = [
{
"id": conversation.id,
"timestamp": conversation.timestamp.isoformat(),
"message": conversation.message,
"sender": conversation.sender
} for conversation in conversations if conversation.sender == 'user'
]
return user_conversations
def db_get_one_random_conversation(user_id):
from app import db # 延迟导入 db
from app.drunk.models import Conversation # 延迟导入模型
if user_id is None:
return "Missing user_id parameter", 400
# 查询固定 user_id 的最近 15 条信息,按 id 升序排列
conversations = db.session.query(Conversation.id, Conversation.timestamp, Conversation.message, Conversation.sender) \
.filter(Conversation.user_id == user_id) \
.order_by(Conversation.id.asc()) \
.limit(15) \
.all()
# 过滤 sender 为 'user' 的对话
user_conversations = [
{
"id": conversation.id,
"timestamp": conversation.timestamp.isoformat(),
"message": conversation.message,
"sender": conversation.sender
} for conversation in conversations if conversation.sender == 'user'
]
if not user_conversations:
return "No conversations found for user", 404
# 随机选择一条对话消息
random_message = random.choice(user_conversations)['message']
return random_message
def save_conversations(messages, input_data):
from app import db # 延迟导入 db
from app.drunk.models import Conversation # 延迟导入模型
from datetime import datetime # 导入 datetime 模块
# 获取当前时间作为 timestamp
timestamp = datetime.now()
# 创建 Conversation 对象并存储数据
new_conversation = Conversation(
user_id=input_data.UserId,
message=messages['content'],
timestamp=timestamp,
sender=messages['role']
)
db.session.add(new_conversation)
db.session.commit()
def get_new_profile(user_id):
old_profile = set_profile(user_id)
old_profile_content = old_profile['content']
print("old_profile_content",type(old_profile_content),old_profile_content)
recent_chats=db_get_conversations(user_id)
user_input = ('你是一个用户画像生成器'
'你会根据<history>生成对user的<userprofile>。'
'##<userprofile>格式示例如下:'
'{'
'"age":" "'
'"like":" "'
'"dislike":" "'
'"always":" "'
'"sometimes":" "'
'}'
'##attention'
'-严格遵守<userprofile>中的格式,包括空行。内容可以改动,但禁止改动格式。'
'-if 你能对一些profile进行合理推测,则直接写出;if 有些选项你无法进行推测,对于未知或未提及的,写null。'
'-userprofile仅针对用户。<ai>的回答内容不作为参考。'
'-只输出<userprofile>,禁止其他多余的礼貌用语和解释。'
'##<history>如下:'
f'{recent_chats,old_profile_content}'
'-只输出<userprofile>,禁止其他多余的礼貌用语和解释。'
'-如果你想输出别的,请先认真阅读以上要求,最后只给我<userprofile>'
# '另外你可能还会收到之前这个用户的用户画像,请在这个旧的用户画像基础上结合<history>分析'
# f'{old_profile_content}'
)
print("用户画像提示词",user_input)
url = 'http://192.168.1.138:11434/api/chat'
data = {
"model": "llama3",
"messages": [
{
"role": "user",
"content": user_input
}
],
"stream": False
}
headers = {'Content-Type': 'application/json'}
try:
response = requests.post(url, data=json.dumps(data), headers=headers)
if response.status_code == 200:
# print(response,response.json())
response_data = response.json()
messages = response_data['message']
content = messages['content']
print("profile_content",content)
save_profile(user_id,content)
return content
else:
print(f'Request failed with status code {response.status_code}')
return None
except Exception as e:
print(f'Error: {e}')
return None
def set_profile(user_id):
from app import db
from app.drunk.models import Profile
profile = db.session.query(Profile).filter_by(user_id=user_id).first()
# 如果找到了 Profile,返回其内容
if profile:
return {
'content': profile.content,
'user_id': profile.user_id,
'username': profile.username,
'age': profile.age,
'like': profile.like,
'dislike': profile.dislike,
'always': profile.always,
'sometimes': profile.sometimes
}
# 如果没有找到对应的 Profile,返回 None 或一个默认值
return {
'content': '',
'user_id': user_id,
'username': '',
'age': '',
'like': '',
'dislike': '',
'always': '',
'sometimes': ''
}
def save_profile(user_id, content):
from app import db
from app.drunk.models import Profile
import json
try:
content_dict = json.loads(content)
except json.JSONDecodeError as e:
print(f'JSON decode error: {e}')
return {"error": "Invalid JSON content"}
new_profile = Profile(
content=content,
user_id=user_id,
age=content_dict.get('age', ''),
like=', '.join(content_dict.get('like', [])) if content_dict.get('like') else '',
dislike=content_dict.get('dislike', '') if content_dict.get('dislike') else '',
always=content_dict.get('always', '') if content_dict.get('always') else '',
sometimes=', '.join(content_dict.get('sometimes', [])) if content_dict.get('sometimes') else ''
)
try:
db.session.add(new_profile)
db.session.commit()
except Exception as e:
db.session.rollback()
print(f'Error saving profile: {e}')
return {"error": "Error saving profile"}
def get_emotion(docx):
current_dir = os.path.dirname(__file__)
model_path = os.path.join(current_dir, '..', 'static', 'emotion_classifier_pipe_lr.pkl')
# 加载模型
pipe_lr = joblib.load(open(model_path, "rb"))
results = pipe_lr.predict([docx])
return results[0]
@drunk_bp.route('/test', methods=['GET'])
def test_db_get_conversations():
return get_new_profile(1)