这段代码是 chat/manager.py
文件中的实现。它主要涉及到 WebSocket 连接的管理、聊天消息的处理和与聊天逻辑的集成。以下,我会详细讲解代码的每一部分,然后分析它所处的系统架构中的层次,以及它为何要被设计为当前结构。
代码的层次结构与目的
这段代码属于中间逻辑层,也可以称为服务管理层。它负责管理 WebSocket 连接的会话状态、消息的存储、消息的流式处理,以及与业务逻辑的实际集成。
- WebSocket 连接管理:这个类管理着所有的 WebSocket 连接,例如如何接受连接、保持连接、发送消息和关闭连接。
- 聊天会话管理:它维护了会话的状态,包括历史消息、流式结果的处理、缓存管理等。
- 与后端业务逻辑的集成:通过调用不同的服务类和缓存机制,它将业务逻辑的处理(如对话生成、文件处理等)整合进聊天的实际流程。
设计上,这样做的目的是将WebSocket 连接管理与具体的业务逻辑分开,避免在 WebSocket 管理的过程中涉及到过多具体的细节,保持代码的模块化和可维护性。
代码分析
这段代码主要由两个类 ChatHistory
和 ChatManager
组成,它们各自承担了不同的职责。
1. ChatHistory
类
这个类负责管理聊天的历史记录。它实现了一个订阅-通知机制,可以在消息到来时通知有关对象。
__init__(self)
: 初始化聊天历史对象,使用defaultdict(list)
来存储不同对话的历史。add_message(self, client_id: str, chat_id: str, message: ChatMessage)
: 这是一个关键方法,用于将新消息加入聊天历史中。它的主要流程是:- 记录消息的元数据,如
flow_id
和chat_id
。 - 将消息存储到数据库中(持久化操作)。
- 如果不是文件消息 (
FileResponse
),则调用notify()
,以便通知所有观察者聊天记录有更新。
- 记录消息的元数据,如
empty_history(self, client_id: str, chat_id: str)
: 清空某个client_id
和chat_id
相关的聊天历史。
2. ChatManager
类
这个类是代码的核心,主要负责WebSocket 连接的管理和聊天的处理逻辑。它包括以下部分:
2.1 属性
active_connections
: 记录当前所有活跃的 WebSocket 连接,以client_id
和chat_id
作为键,WebSocket
对象作为值。chat_history
: 一个ChatHistory
对象,管理所有聊天的历史记录。cache_manager
: 一个缓存管理器,用于缓存数据。这里的缓存管理器和观察者模式结合使用,当缓存更新时会调用update()
方法。active_clients
: 存储当前活跃的聊天客户端,用于管理每一个正在使用 WebSocket 的聊天会话。stream_queue
: 维护每个连接的流式输出队列,保证消息可以以流的方式发出。
2.2 方法
- WebSocket 连接管理
connect()
: 建立新的 WebSocket 连接,并将其添加到active_connections
中。disconnect()
: 根据客户端 ID 或连接键,移除对应的 WebSocket 连接。send_message()
/send_json()
: 向指定客户端发送文本或 JSON 格式的消息。close_connection()
: 关闭 WebSocket 连接并移除相关记录。
- 聊天逻辑处理
accept_client()
和clear_client()
: 这些方法用于管理active_clients
中的活跃客户端,接受新的客户端或者移除断开的客户端。dispatch_client()
: 处理聊天客户端的核心逻辑:- 接受客户端的连接并创建一个
ChatClient
对象来管理会话。 - 进入一个循环,不断监听来自客户端的消息,通过 WebSocket 进行交互。
- 调用
chat_client.handle_message(payload)
来处理接收到的消息。
- 接受客户端的连接并创建一个
handle_websocket()
: 这是 WebSocket 核心处理函数。- 首先通过
connect()
方法建立 WebSocket 连接。 - 然后进入一个循环,处理来自客户端的 JSON 数据,判断当前消息是否是新消息、是否有文件上传等。
- 最终根据情况调用
_process_when_payload()
来处理消息,完成处理任务并发送响应。
- 首先通过
_process_when_payload()
: 这是一个内部辅助函数,用于处理有效的消息内容。它会检查聊天上下文中的状态,执行必要的步骤,调用外部方法进行实际业务逻辑的处理。
- 流式数据处理
init_langchain_object_task()
:用于初始化 LangChain 对象的任务,实际上是为了准备好聊天过程中使用的各种模型和对象。在 WebSocket 连接开始时就需要这个对象来为整个对话过程服务。
2.3 使用的工具
- 缓存管理器:缓存用来存储聊天上下文,以提高访问速度。
- 观察者模式:
ChatHistory
继承自Subject
,使得聊天历史可以在发生改变时通知相关的观察者。 - 多线程池:通过
ThreadPoolManager
进行多线程管理,处理复杂、耗时的聊天生成任务,保证 WebSocket 主线程不被阻塞。
为什么这样设计
- 模块化管理:将
ChatHistory
和ChatManager
进行拆分,分别负责聊天历史和 WebSocket 连接管理,符合单一职责原则,使得代码更加易于维护和扩展。 - 高并发管理:由于 WebSocket 是一个长连接协议,必须考虑到高并发情况下如何处理大量的连接。代码中通过字典存储
active_connections
和stream_queue
,实现了对大量连接的快速访问和管理。 - 异步与并行化:使用
asyncio
和ThreadPoolManager
,保证长时间处理不会阻塞主线程,这对于聊天生成任务尤为重要,因为这些任务可能需要调用耗时的模型来生成对话。 - 观察者模式:将聊天历史作为观察者对象,这使得系统中的其他部分能够对聊天历史的变化做出反应,比如向用户实时推送消息。
- 缓存策略:通过
InMemoryCache
提供的缓存功能,可以快速存取聊天上下文和生成的对话内容,减少重复计算,提升性能。
总结
ChatManager
类:这个类主要负责 WebSocket 连接的管理、聊天的核心处理逻辑,以及与后端模型或业务逻辑的交互。它处于整个聊天系统的服务层,封装了与聊天有关的底层管理逻辑和与客户端的通信逻辑。- 层级关系:
- API 层:FastAPI 中的
chat.py
定义了 WebSocket 路由和请求处理函数。 - 服务层:
ChatManager
和ChatClient
是服务层,负责处理具体的业务逻辑,如管理 WebSocket 连接、与聊天模型交互等。 - 数据层:数据库交互部分,通过
ChatMessageDao
、UserDao
等与数据库打交道,实现持久化存储。
- API 层:FastAPI 中的
ChatHistory
ChatHistory
类在这段代码中扮演着一个核心角色,它用于存储和管理用户的聊天历史记录,并确保在聊天过程中能够持续跟踪和管理消息的流动。下面是对 ChatHistory
类的详细解释。
类结构
class ChatHistory(Subject):
def __init__(self):
super().__init__()
self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
def add_message(
self,
client_id: str,
chat_id: str,
message: ChatMessage,
):
"""Add a message to the chat history."""
t1 = time.time()
from bisheng.database.models.message import ChatMessage
message.flow_id = client_id
message.chat_id = chat_id
if chat_id and (message.message or message.intermediate_steps
or message.files) and message.type != 'stream':
msg = message.copy()
msg.message = json.dumps(msg.message) if isinstance(msg.message, dict) else msg.message
files = json.dumps(msg.files) if msg.files else ''
msg.__dict__.pop('files')
db_message = ChatMessage(files=files, **msg.__dict__)
logger.info(f'chat={db_message} time={time.time() - t1}')
with session_getter() as seesion:
seesion.add(db_message)
seesion.commit()
seesion.refresh(db_message)
message.message_id = db_message.id
if not isinstance(message, FileResponse):
self.notify()
def empty_history(self, client_id: str, chat_id: str):
"""Empty the chat history for a client."""
self.history[get_cache_key(client_id, chat_id)] = []
详细解析
1. 继承自 Subject
类
ChatHistory
继承自 Subject
类,Subject
通常是一个设计模式中的概念,常用于观察者模式中。在这里,ChatHistory
可能是作为一个“被观察者”来使用的,意味着它可以通知其他部分的程序(例如聊天客户端)有新的消息到来,或者历史记录发生了变化。
2. __init__
方法
def __init__(self):
super().__init__()
self.history: Dict[str, List[ChatMessage]] = defaultdict(list)
- 在
__init__
方法中,super().__init__()
表示调用父类Subject
的初始化方法。 self.history
是一个字典,用来存储特定client_id
和chat_id
下的聊天历史。它的键是通过client_id
和chat_id
生成的缓存键(get_cache_key(client_id, chat_id)
),值是一个ChatMessage
类型的消息列表。defaultdict(list)
确保如果访问的键不存在时,返回的是一个空的列表。
3. add_message
方法
def add_message(
self,
client_id: str,
chat_id: str,
message: ChatMessage,
):
"""Add a message to the chat history."""
t1 = time.time()
from bisheng.database.models.message import ChatMessage
message.flow_id = client_id
message.chat_id = chat_id
if chat_id and (message.message or message.intermediate_steps
or message.files) and message.type != 'stream':
msg = message.copy()
msg.message = json.dumps(msg.message) if isinstance(msg.message, dict) else msg.message
files = json.dumps(msg.files) if msg.files else ''
msg.__dict__.pop('files')
db_message = ChatMessage(files=files, **msg.__dict__)
logger.info(f'chat={db_message} time={time.time() - t1}')
with session_getter() as seesion:
seesion.add(db_message)
seesion.commit()
seesion.refresh(db_message)
message.message_id = db_message.id
if not isinstance(message, FileResponse):
self.notify()
add_message
方法的主要作用是将传入的ChatMessage
消息对象保存到聊天历史中,并将其存入数据库。- 参数说明:
client_id
:标识客户端的唯一 ID。chat_id
:标识当前聊天会话的唯一 ID。message
:要保存的消息对象,是一个ChatMessage
类型的实例。
步骤解析:
- 处理消息内容:首先,代码会根据
message.message
和message.files
的内容来判断是否需要保存。如果消息类型不是流式消息(type != 'stream'
),才会进行保存。 - 消息复制和数据格式化:
- 使用
message.copy()
创建一个消息的副本。 - 将
message.message
和message.files
转换为 JSON 格式字符串,因为存入数据库的数据通常是字符串格式。 - 从消息对象中删除不需要保存到数据库的
files
属性(通过msg.__dict__.pop('files')
)。
- 使用
- 保存到数据库:
- 创建一个
ChatMessage
数据库模型对象db_message
,将消息内容保存到数据库中。 - 使用
session_getter()
来获取数据库会话(假设session_getter()
是一个数据库会话的获取方法),然后将消息对象db_message
添加到数据库会话中,并提交和刷新数据库,以便获得数据库中已保存的消息 ID(message.message_id
)。
- 创建一个
- 通知更新:如果消息类型不是
FileResponse
(文件响应),调用self.notify()
通知历史记录有了更新。这个通知机制通常会触发其他部分的更新逻辑。
4. empty_history
方法
def empty_history(self, client_id: str, chat_id: str):
"""Empty the chat history for a client."""
self.history[get_cache_key(client_id, chat_id)] = []
empty_history
方法用来清空某个特定客户端(client_id
)和聊天会话(chat_id
)的聊天历史。- 它通过
get_cache_key(client_id, chat_id)
生成的键,找到对应的历史记录并将其清空(设置为空列表)。
总结
ChatHistory
类的主要职责是管理和维护聊天历史记录,它将消息保存到内存和数据库,并能够清空历史记录。- 保存消息:在
add_message
方法中,聊天消息会被保存到内存(self.history
)和数据库(ChatMessage
表)。 - 通知机制:通过继承自
Subject
类,ChatHistory
具备了向观察者发送通知的能力,确保其他部分(如聊天客户端)能够在消息变动时进行响应。 - 聊天历史清理:
empty_history
方法可以用于清空某个客户端和会话的聊天历史,通常用于会话结束或重新初始化时。
这种设计结构有助于将聊天历史的存储和管理与聊天逻辑的其他部分分离开,提升了系统的可维护性和扩展性。
ChatManager
ChatManager
类详细解析
ChatManager
类是 chat/manager.py
文件中的核心组件,负责管理和处理与 WebSocket 连接相关的所有聊天逻辑。它不仅处理 WebSocket 连接的建立与关闭,还管理聊天历史记录、缓存、并发任务处理等关键功能。以下是对 ChatManager
类的详细解析,包括其各个方法的功能、工作流程以及在系统架构中的位置和设计原因。
类定义
class ChatManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.chat_history = ChatHistory()
self.cache_manager = cache_manager
self.cache_manager.attach(self.update)
self.in_memory_cache = InMemoryCache()
self.task_manager: List[asyncio.Task] = []
# 已连接的客户端
self.active_clients: Dict[str, ChatClient] = {}
# 记录流式输出结果
self.stream_queue: Dict[str, Queue] = {}
属性说明
active_connections: Dict[str, WebSocket]
- 用途:存储当前所有活跃的 WebSocket 连接。
- 键:通过
client_id
和chat_id
生成的唯一缓存键(get_cache_key(client_id, chat_id)
)。 - 值:对应的
WebSocket
对象。
chat_history = ChatHistory()
- 用途:管理聊天历史记录,通过
ChatHistory
类实现消息的存储和通知机制。
- 用途:管理聊天历史记录,通过
cache_manager = cache_manager
- 用途:引用缓存管理器,用于管理和更新缓存数据。
- 附加操作:通过
self.cache_manager.attach(self.update)
将update
方法作为观察者附加到缓存管理器上,实现缓存更新时的通知机制。
in_memory_cache = InMemoryCache()
- 用途:管理内存中的缓存数据,提供快速的数据存取能力。
task_manager: List[asyncio.Task] = []
- 用途:存储当前正在执行的异步任务,便于管理和取消任务。
active_clients: Dict[str, ChatClient] = {}
- 用途:记录已连接的聊天客户端,键为唯一的客户端键(
client_key
),值为ChatClient
对象。
- 用途:记录已连接的聊天客户端,键为唯一的客户端键(
stream_queue: Dict[str, Queue] = {}
- 用途:维护每个连接的流式输出队列,确保消息能够以流的方式发送给客户端。
方法详解
1. update
方法
def update(self):
if self.cache_manager.current_client_id in self.active_connections:
self.last_cached_object_dict = self.cache_manager.get_last()
# Add a new ChatResponse with the data
chat_response = FileResponse(
message=None,
type='file',
data=self.last_cached_object_dict['obj'],
data_type=self.last_cached_object_dict['type'],
)
self.chat_history.add_message(self.cache_manager.current_client_id,
self.cache_manager.current_chat_id, chat_response)
- 功能:当缓存管理器更新时,此方法被触发。它检查当前客户端是否在活跃连接中,如果是,则获取最新的缓存数据,并将其作为
FileResponse
消息添加到聊天历史中。 - 详细流程:
- 检查
current_client_id
是否在active_connections
中。 - 获取最新的缓存数据
last_cached_object_dict
。 - 构建一个
FileResponse
消息对象。 - 将该消息添加到
chat_history
中。
- 检查
2. 连接管理方法
a. connect
方法
async def connect(self, client_id: str, chat_id: str, websocket: WebSocket):
await websocket.accept()
self.active_connections[get_cache_key(client_id, chat_id)] = websocket
self.stream_queue[get_cache_key(client_id, chat_id)] = Queue()
- 功能:接受并记录新的 WebSocket 连接。
- 参数:
client_id
: 客户端的唯一标识符。chat_id
: 聊天会话的唯一标识符。websocket
:WebSocket
对象。
- 详细流程:
- 调用
await websocket.accept()
接受 WebSocket 连接。 - 将 WebSocket 连接记录到
active_connections
中。 - 为该连接创建一个新的消息队列
stream_queue
。
- 调用
b. reuse_connect
方法
def reuse_connect(self, client_id: str, chat_id: str, websocket: WebSocket):
self.active_connections[get_cache_key(client_id, chat_id)] = websocket
self.stream_queue[get_cache_key(client_id, chat_id)] = Queue()
- 功能:复用已有的连接,将新的 WebSocket 对象关联到现有的
client_id
和chat_id
。 - 详细流程:
- 更新
active_connections
中的 WebSocket 对象。 - 创建新的消息队列
stream_queue
。
- 更新
c. disconnect
方法
def disconnect(self, client_id: str, chat_id: str, key: str = None):
if key:
logger.debug('disconnect_ws key={}', key)
self.active_connections.pop(key, None)
else:
logger.info('disconnect_ws key={}', get_cache_key(client_id, chat_id))
self.active_connections.pop(get_cache_key(client_id, chat_id), None)
- 功能:断开并移除 WebSocket 连接。
- 参数:
client_id
: 客户端的唯一标识符。chat_id
: 聊天会话的唯一标识符。key
: 可选参数,直接指定缓存键以断开连接。
- 详细流程:
- 如果提供了
key
,直接根据key
从active_connections
中移除对应的 WebSocket。 - 如果没有提供
key
,根据client_id
和chat_id
生成缓存键,并移除对应的 WebSocket。
- 如果提供了
3. 消息发送方法
a. send_message
方法
async def send_message(self, client_id: str, chat_id: str, message: str):
websocket = self.active_connections[get_cache_key(client_id, chat_id)]
await websocket.send_text(message)
- 功能:向指定客户端发送文本消息。
- 参数:
client_id
: 客户端的唯一标识符。chat_id
: 聊天会话的唯一标识符。message
: 要发送的文本消息。
- 详细流程:
- 根据
client_id
和chat_id
生成缓存键,获取对应的WebSocket
对象。 - 使用
await websocket.send_text(message)
发送文本消息。
- 根据
b. send_json
方法
async def send_json(self, client_id: str, chat_id: str, message: ChatMessage, add=True):
message.flow_id = client_id
message.chat_id = chat_id
websocket = self.active_connections[get_cache_key(client_id, chat_id)]
# 增加消息记录
if add:
self.chat_history.add_message(client_id, chat_id, message)
await websocket.send_json(message.dict())
- 功能:向指定客户端发送 JSON 格式的消息。
- 参数:
client_id
: 客户端的唯一标识符。chat_id
: 聊天会话的唯一标识符。message
: 要发送的ChatMessage
对象。add
: 可选参数,是否将消息添加到聊天历史中(默认True
)。
- 详细流程:
- 设置
message
的flow_id
和chat_id
。 - 获取对应的
WebSocket
对象。 - 如果
add
为True
,调用chat_history.add_message
将消息保存到历史记录中。 - 使用
await websocket.send_json(message.dict())
发送 JSON 格式的消息。
- 设置
4. 连接关闭方法
async def close_connection(self,
flow_id: str,
chat_id: str,
code: int,
reason: str,
key_list: List[str] = None):
"""close and clean ws"""
if websocket := self.active_connections[get_cache_key(flow_id, chat_id)]:
try:
await websocket.close(code=code, reason=reason)
self.disconnect(flow_id, chat_id)
if key_list:
for key in key_list:
self.disconnect(flow_id, chat_id, key)
except RuntimeError as exc:
# This is to catch the following error:
# Unexpected ASGI message 'websocket.close', after sending 'websocket.close'
if 'after sending' in str(exc):
logger.error(exc)
- 功能:关闭指定的 WebSocket 连接,并清理相关记录。
- 参数:
flow_id
: 流程的唯一标识符。chat_id
: 聊天会话的唯一标识符。code
: 关闭连接的状态码。reason
: 关闭连接的原因。key_list
: 可选参数,一组缓存键,用于批量断开连接。
- 详细流程:
- 根据
flow_id
和chat_id
生成缓存键,获取对应的WebSocket
对象。 - 调用
await websocket.close(code=code, reason=reason)
关闭 WebSocket 连接。 - 调用
disconnect(flow_id, chat_id)
从active_connections
中移除该连接。 - 如果提供了
key_list
,循环调用disconnect
方法断开列表中的所有连接。 - 捕获
RuntimeError
,记录错误日志(例如,当连接已经关闭时)。
- 根据
5. 心跳方法
async def ping(self, client_id: str, chat_id: str):
ping_pong = ChatMessage(
is_bot=True,
message='pong',
intermediate_steps='',
)
await self.send_json(client_id, chat_id, ping_pong, False)
- 功能:向指定客户端发送心跳消息,用于保持连接活跃或验证连接状态。
- 参数:
client_id
: 客户端的唯一标识符。chat_id
: 聊天会话的唯一标识符。
- 详细流程:
- 构建一个
ChatMessage
对象,内容为'pong'
。 - 调用
send_json
方法发送该消息,参数add=False
表示不将此消息添加到聊天历史中。
- 构建一个
6. 缓存管理方法
a. set_cache
方法
def set_cache(self, client_id: str, langchain_object: Any) -> bool:
"""
Set the cache for a client.
"""
self.in_memory_cache.set(client_id, langchain_object)
return client_id in self.in_memory_cache
- 功能:为指定的客户端设置缓存数据。
- 参数:
client_id
: 客户端的唯一标识符。langchain_object
: 要缓存的对象,通常与聊天逻辑相关。
- 详细流程:
- 调用
in_memory_cache.set
方法设置缓存。 - 返回
client_id
是否成功设置在缓存中。
- 调用
b. accept_client
和 clear_client
方法
async def accept_client(self, client_key: str, chat_client: ChatClient, websocket: WebSocket):
await websocket.accept()
self.active_clients[client_key] = chat_client
def clear_client(self, client_key: str):
if client_key not in self.active_clients:
logger.warning('close_client client_key={} not in active_clients', client_key)
return
logger.info('close_client client_key={}', client_key)
self.active_clients.pop(client_key, None)
accept_client
方法:- 功能:接受新的客户端连接并记录相关的
ChatClient
对象。 - 参数:
client_key
: 客户端的唯一键。chat_client
:ChatClient
对象,管理具体的聊天会话。websocket
:WebSocket
对象。
- 详细流程:
- 调用
await websocket.accept()
接受 WebSocket 连接。 - 将
ChatClient
对象记录到active_clients
中。
- 调用
- 功能:接受新的客户端连接并记录相关的
clear_client
方法:- 功能:清除指定的客户端记录。
- 参数:
client_key
: 客户端的唯一键。
- 详细流程:
- 检查
client_key
是否存在于active_clients
中。 - 如果存在,移除该客户端的记录。
- 如果不存在,记录警告日志。
- 检查
c. close_client
方法
async def close_client(self, client_key: str, code: int, reason: str):
if chat_client := self.active_clients.get(client_key):
try:
await chat_client.websocket.close(code=code, reason=reason)
self.clear_client(client_key)
except RuntimeError as exc:
# This is to catch the following error:
# Unexpected ASGI message 'websocket.close', after sending 'websocket.close'
if 'after sending' in str(exc):
logger.error(exc)
- 功能:关闭特定客户端的 WebSocket 连接,并清除相关记录。
- 参数:
client_key
: 客户端的唯一键。code
: 关闭连接的状态码。reason
: 关闭连接的原因。
- 详细流程:
- 获取
ChatClient
对象。 - 调用
await chat_client.websocket.close(code=code, reason=reason)
关闭连接。 - 调用
clear_client(client_key)
移除客户端记录。 - 捕获并记录
RuntimeError
异常,防止因重复关闭连接导致错误。
- 获取
7. dispatch_client
方法
async def dispatch_client(
self,
request: Request, # 原始请求体
client_id: str,
chat_id: str,
login_user: UserPayload,
work_type: WorkType,
websocket: WebSocket,
graph_data: dict = None):
client_key = uuid.uuid4().hex
chat_client = ChatClient(request,
client_key,
client_id,
chat_id,
login_user.user_id,
login_user,
work_type,
websocket,
graph_data=graph_data)
await self.accept_client(client_key, chat_client, websocket)
logger.debug(
f'act=accept_client client_key={client_key} client_id={client_id} chat_id={chat_id}')
try:
while True:
try:
json_payload_receive = await asyncio.wait_for(websocket.receive_json(),
timeout=2.0)
except asyncio.TimeoutError:
continue
try:
payload = json.loads(json_payload_receive) if json_payload_receive else {}
except TypeError:
payload = json_payload_receive
# client内部处理自己的业务逻辑
# TODO zgq:这里可以增加线程池防止阻塞
await chat_client.handle_message(payload)
except WebSocketDisconnect as e:
logger.info('act=rcv_client_disconnect {}', str(e))
except IgnoreException:
# client 内部自己关闭了ws链接,并无异常的情况
pass
except Exception as e:
# Handle any exceptions that might occur
logger.exception(str(e))
await self.close_client(client_key,
code=status.WS_1011_INTERNAL_ERROR,
reason='后端未知错误类型')
finally:
try:
await self.close_client(client_key,
code=status.WS_1000_NORMAL_CLOSURE,
reason='Client disconnected')
except Exception as e:
logger.exception(e)
self.clear_client(client_key)
- 功能:管理和分发客户端的消息处理流程,确保每个客户端的消息能够被正确地处理。
- 参数:
request
: 原始的 HTTP 请求对象。client_id
: 客户端的唯一标识符。chat_id
: 聊天会话的唯一标识符。login_user
: 当前登录的用户对象,包含用户的身份信息。work_type
: 聊天工作的类型。websocket
:WebSocket
对象。graph_data
: 流程数据,可能用于初始化聊天逻辑。
- 详细流程:
- 生成唯一客户端键:使用
uuid.uuid4().hex
生成一个唯一的client_key
。 - 创建
ChatClient
对象:实例化一个ChatClient
,传入所有必要的参数。 - 接受客户端连接:调用
await self.accept_client(client_key, chat_client, websocket)
接受并记录客户端连接。 - 日志记录:记录客户端连接的日志信息。
- 消息接收循环:
- 进入一个无限循环,持续监听来自客户端的消息。
- 使用
asyncio.wait_for
设置超时时间(2秒),等待客户端发送 JSON 格式的消息。 - 解析收到的消息载荷
payload
。 - 调用
chat_client.handle_message(payload)
处理接收到的消息。
- 异常处理:
- 捕获
WebSocketDisconnect
异常,记录日志。 - 捕获
IgnoreException
异常,表示客户端主动关闭连接,无需处理。 - 捕获其他异常,记录日志并关闭连接,发送错误代码和原因。
- 捕获
- 最终处理:
- 无论是否发生异常,最后尝试关闭客户端连接并清除客户端记录。
- 生成唯一客户端键:使用
8. handle_websocket
方法
async def handle_websocket(
self,
flow_id: str,
chat_id: str,
websocket: WebSocket,
user_id: int,
gragh_data: dict = None,
):
# 建立连接,并存储映射,兼容不复用ws 场景
key_list = set([get_cache_key(flow_id, chat_id)])
await self.connect(flow_id, chat_id, websocket)
# autogen_pool = ThreadPoolManager(max_workers=1, thread_name_prefix='autogen')
context_dict = {
get_cache_key(flow_id, chat_id): {
'status': 'init',
'has_file': False,
'flow_id': flow_id,
'chat_id': chat_id
}
}
payload = {}
base_param = {
'user_id': user_id,
'flow_id': flow_id,
'chat_id': chat_id,
'type': 'end',
'category': 'system'
}
try:
while True:
try:
json_payload_receive = await asyncio.wait_for(websocket.receive_json(),
timeout=2.0)
except asyncio.TimeoutError:
json_payload_receive = ''
try:
payload = json.loads(json_payload_receive) if json_payload_receive else {}
except TypeError:
payload = json_payload_receive
# websocket multi use
if payload and 'flow_id' in payload:
chat_id = payload.get('chat_id')
flow_id = payload.get('flow_id')
key = get_cache_key(flow_id, chat_id)
if key not in key_list:
gragh_data, message = self.preper_reuse_connection(
flow_id, chat_id, websocket)
context_dict.update({
key: {
'status': 'init',
'has_file': False,
'flow_id': flow_id,
'chat_id': chat_id
}
})
if message:
logger.info('act=new_chat message={}', message)
erro_resp = ChatResponse(intermediate_steps=message, **base_param)
erro_resp.category = 'error'
await self.send_json(flow_id, chat_id, erro_resp, add=False)
continue
logger.info('act=new_chat_init_success key={}', key)
key_list.add(key)
if not payload.get('inputs'):
continue
# 判断当前是否是空循环
process_param = {
'autogen_pool': thread_pool,
'user_id': user_id,
'payload': payload,
'graph_data': gragh_data,
'context_dict': context_dict
}
if payload:
await self._process_when_payload(flow_id, chat_id, **process_param)
else:
for v in context_dict.values():
if v['status'] != 'init':
await self._process_when_payload(v['flow_id'], v['chat_id'],
**process_param)
# 处理任务状态
complete_normal = await thread_pool.as_completed(key_list)
complete = complete_normal
# if async_task and async_task.done():
# logger.debug(f'async_task_complete result={async_task.result()}')
if complete:
for future_key, future in complete:
try:
future.result()
logger.debug('task_complete key={}', future_key)
except Exception as e:
if isinstance(e, concurrent.futures.CancelledError):
continue
logger.exception('feature_key={} {}', future_key, e)
erro_resp = ChatResponse(**base_param)
context = context_dict.get(future_key)
if context.get('status') == 'init':
erro_resp.intermediate_steps = f'LLM 技能执行错误. error={str(e)}'
elif context.get('has_file'):
erro_resp.intermediate_steps = f'文档解析失败,点击输入框上传按钮重新上传\n\n{str(e)}'
else:
erro_resp.intermediate_steps = f'Input data is parsed fail. error={str(e)}'
context['status'] = 'init'
await self.send_json(context.get('flow_id'), context.get('chat_id'),
erro_resp)
erro_resp.type = 'close'
await self.send_json(context.get('flow_id'), context.get('chat_id'),
erro_resp)
except WebSocketDisconnect as e:
logger.info('act=rcv_client_disconnect {}', str(e))
except Exception as e:
# Handle any exceptions that might occur
logger.exception(str(e))
await self.close_connection(flow_id=flow_id,
chat_id=chat_id,
code=status.WS_1011_INTERNAL_ERROR,
reason='后端未知错误类型',
key_list=key_list)
finally:
thread_pool.cancel_task(key_list) # 将进行中的任务进行cancel
try:
await self.close_connection(flow_id=flow_id,
chat_id=chat_id,
code=status.WS_1000_NORMAL_CLOSURE,
reason='Client disconnected',
key_list=key_list)
except Exception as e:
logger.exception(e)
self.disconnect(flow_id, chat_id)
- 功能:处理 WebSocket 连接的主要逻辑,包括消息接收、消息处理、任务管理和异常处理。
- 参数:
flow_id
: 流程的唯一标识符。chat_id
: 聊天会话的唯一标识符。websocket
:WebSocket
对象。user_id
: 当前用户的 ID。gragh_data
: 流程数据,用于初始化聊天逻辑。
- 详细流程:
- 建立连接并初始化:
- 生成缓存键并添加到
key_list
集合。 - 调用
await self.connect(flow_id, chat_id, websocket)
接受 WebSocket 连接并记录。 - 初始化
context_dict
,用于跟踪聊天会话的状态。 - 初始化
payload
和base_param
,这些参数用于构建响应消息。
- 生成缓存键并添加到
- 主处理循环:
- 接收消息:
- 使用
asyncio.wait_for
设置超时(2秒),等待客户端发送 JSON 格式的消息。 - 如果超时,继续等待下一个消息。
- 解析收到的消息载荷
payload
。
- 使用
- 处理多用途连接:
- 检查
payload
中是否包含新的flow_id
,如果是,则可能是复用连接。 - 如果是新的会话,调用
preper_reuse_connection
方法准备复用连接,并更新context_dict
。 - 如果消息中没有
inputs
,则跳过处理。
- 检查
- 消息处理:
- 构建
process_param
,包括autogen_pool
、user_id
、payload
、graph_data
和context_dict
。 - 根据是否有
payload
,调用_process_when_payload
方法处理消息。 - 如果没有
payload
,则循环处理所有非初始化状态的会话。
- 构建
- 任务完成检查:
- 使用
thread_pool.as_completed(key_list)
检查任务是否完成。 - 对于完成的任务,获取结果,如果有异常,构建错误响应并发送给客户端。
- 使用
- 接收消息:
- 异常处理:
- WebSocketDisconnect:记录客户端断开连接的日志。
- IgnoreException:忽略特定的异常,表示客户端主动关闭连接。
- 其他异常:记录日志,关闭连接并发送错误状态码和原因。
- 最终处理:
- 取消所有未完成的任务。
- 尝试关闭 WebSocket 连接,发送正常关闭状态码和原因。
- 清除客户端记录。
- 建立连接并初始化:
9. _process_when_payload
方法
async def _process_when_payload(self, flow_id: str, chat_id: str,
autogen_pool: ThreadPoolManager, **kwargs):
"""
Process the incoming message and send the response.
"""
# set start
user_id = kwargs.get('user_id')
graph_data = kwargs.get('graph_data')
payload = kwargs.get('payload')
key = get_cache_key(flow_id, chat_id)
context = kwargs.get('context_dict').get(key)
status_ = context.get('status')
if payload and status_ != 'init':
logger.error('act=input_before_complete payload={} status={}', payload, status_)
if not payload:
payload = context.get('payload')
context['payload'] = payload
is_begin = bool(status_ == 'init' and 'action' not in payload)
base_param = {'user_id': user_id, 'flow_id': flow_id, 'chat_id': chat_id}
start_resp = ChatResponse(type='begin', category='system', **base_param)
if is_begin:
await self.send_json(flow_id, chat_id, start_resp)
# 判断下是否是首次创建会话
if chat_id:
res = ChatMessageDao.get_messages_by_chat_id(chat_id=chat_id)
if len(res) <= 1: # 说明是新建会话
websocket = self.active_connections[key]
login_user = UserPayload(**{
'user_id': user_id,
'user_name': UserDao.get_user(user_id).user_name,
})
AuditLogService.create_chat_flow(login_user, get_request_ip(websocket),
flow_id)
start_resp.type = 'start'
# should input data
step_resp = ChatResponse(type='end', category='system', **base_param)
langchain_obj_key = get_cache_key(flow_id, chat_id)
if status_ == 'init':
has_file, graph_data = await self.preper_payload(payload, graph_data,
langchain_obj_key, flow_id, chat_id,
start_resp, step_resp)
status_ = 'init_object'
context.update({'status': status_})
context.update({'has_file': has_file})
# build in thread
if not self.in_memory_cache.get(langchain_obj_key) and status_ == 'init_object':
thread_pool.submit(key,
self.init_langchain_object_task,
flow_id,
chat_id,
user_id,
graph_data,
trace_id=chat_id)
status_ = 'waiting_object'
context.update({'status': status_})
# run in thread
if payload and self.in_memory_cache.get(langchain_obj_key):
action, over = await self.preper_action(flow_id, chat_id, langchain_obj_key, payload,
start_resp, step_resp)
logger.debug(
f"processing_message message={payload.get('inputs')} action={action} over={over}")
if not over:
# task_service: 'TaskService' = get_task_service()
# async_task = asyncio.create_task(
# task_service.launch_task(Handler().dispatch_task, self, client_id,
# chat_id, action, payload, user_id))
from bisheng_langchain.chains.autogen.auto_gen import AutoGenChain
from bisheng.chat.handlers import Handler
params = {
'session': self,
'client_id': flow_id,
'chat_id': chat_id,
'action': action,
'payload': payload,
'user_id': user_id,
'trace_id': chat_id
}
if isinstance(self.in_memory_cache.get(langchain_obj_key), AutoGenChain):
# autogen chain
logger.info(f'autogen_submit {langchain_obj_key}')
autogen_pool.submit(key,
Handler(stream_queue=self.stream_queue[key]).dispatch_task,
**params)
else:
thread_pool.submit(key,
Handler(stream_queue=self.stream_queue[key]).dispatch_task,
**params)
status_ = 'init'
context.update({'status': status_})
context.update({'payload': {}}) # clean message
- 功能:处理接收到的有效负载(payload),执行相应的聊天逻辑,并发送响应。
- 参数:
flow_id
: 流程的唯一标识符。chat_id
: 聊天会话的唯一标识符。autogen_pool
: 线程池管理器,用于提交并发任务。**kwargs
: 其他关键字参数,包括user_id
、payload
、graph_data
、context_dict
等。
- 详细流程:
- 获取参数:
- 从
kwargs
中提取user_id
、graph_data
、payload
、context_dict
等信息。 - 获取缓存键
key
。 - 获取当前会话的上下文
context
。
- 从
- 状态检查:
- 检查当前会话的状态
status_
是否为'init'
,以及payload
是否包含'action'
。 - 如果状态不为
'init'
,记录错误日志。
- 检查当前会话的状态
- 处理
payload
:- 如果没有
payload
,从上下文中获取之前的payload
。 - 记录当前的
payload
到上下文中。 - 判断是否为会话的开始 (
is_begin
)。
- 如果没有
- 发送开始响应:
- 构建一个
ChatResponse
对象start_resp
,类型为'begin'
。 - 如果是会话的开始,发送
start_resp
消息。 - 检查是否是首次创建会话,如果是,则记录审计日志。
- 将
start_resp
类型更新为'start'
。
- 构建一个
- 准备处理数据:
- 构建另一个
ChatResponse
对象step_resp
,类型为'end'
。 - 生成缓存键
langchain_obj_key
。 - 如果状态为
'init'
,调用preper_payload
方法处理payload
,更新graph_data
和has_file
。 - 将状态更新为
'init_object'
,并更新上下文。
- 构建另一个
- 初始化 LangChain 对象:
- 如果缓存中没有 LangChain 对象,并且状态为
'init_object'
,则提交任务到线程池以初始化 LangChain 对象。 - 将状态更新为
'waiting_object'
,并更新上下文。
- 如果缓存中没有 LangChain 对象,并且状态为
- 处理动作:
- 如果有
payload
并且缓存中存在 LangChain 对象,调用preper_action
方法处理动作,返回action
和over
。 - 根据
action
和over
的值,决定是否需要提交新的任务到线程池处理。 - 根据
langchain_obj_key
的类型(例如AutoGenChain
),选择合适的处理方式,提交任务到对应的线程池。 - 将状态更新为
'init'
,并清除payload
。
- 如果有
- 获取参数:
10. 辅助方法
a. preper_reuse_connection
方法
def preper_reuse_connection(self, flow_id: str, chat_id: str, websocket: WebSocket):
# 设置复用的映射关系
message = ''
with session_getter() as session:
gragh_data = session.get(Flow, flow_id)
if not gragh_data:
message = '该技能已被删除'
if gragh_data.status != 2:
message = '当前技能未上线,无法直接对话'
gragh_data = gragh_data.data
self.reuse_connect(flow_id, chat_id, websocket)
return gragh_data, message
- 功能:准备复用已有的连接,检查流程的有效性,并返回流程数据和可能的错误消息。
- 参数:
flow_id
: 流程的唯一标识符。chat_id
: 聊天会话的唯一标识符。websocket
:WebSocket
对象。
- 详细流程:
- 初始化
message
为一个空字符串。 - 使用数据库会话获取流程对象
gragh_data
。 - 检查流程是否存在,若不存在,设置错误消息。
- 检查流程的状态是否为 2(假设 2 表示“上线”),若不是,设置错误消息。
- 获取流程的数据
gragh_data.data
。 - 调用
reuse_connect
方法复用 WebSocket 连接。 - 返回
gragh_data
和message
。
- 初始化
b. preper_payload
方法
async def preper_payload(self, payload, graph_data, langchain_obj_key, client_id, chat_id,
start_resp: ChatResponse, step_resp: ChatResponse):
has_file = False
has_variable = False
if 'inputs' in payload and ('data' in payload['inputs']
or 'file_path' in payload['inputs']):
node_data = payload['inputs'].get('data', '') or [payload['inputs']]
graph_data = self.refresh_graph_data(graph_data, node_data)
# 上传文件就重新build,有点粗, 改为只有document loader 需要
node_loader = False
for nod in node_data:
if any('Loader' in x['id'] for x in find_next_node(graph_data, nod['id'])):
node_loader = True
break
if node_loader:
self.set_cache(langchain_obj_key, None) # rebuild object
has_file = any(['InputFile' in nd.get('id', '') for nd in node_data])
has_variable = any(['VariableNode' in nd.get('id', '') for nd in node_data])
if has_file:
step_resp.intermediate_steps = '文件上传完成,开始解析'
await self.send_json(client_id, chat_id, start_resp)
await self.send_json(client_id, chat_id, step_resp, add=False)
await self.send_json(client_id, chat_id, start_resp)
logger.info('input_file start_log')
await asyncio.sleep(-1) # 快速的跳过
elif has_variable:
await self.send_json(client_id, chat_id, start_resp)
logger.info('input_variable start_log')
await asyncio.sleep(-1) # 快速的跳过
return has_file, graph_data
- 功能:处理和准备接收到的消息载荷
payload
,检查是否包含文件或变量,并根据情况更新流程数据和状态。 - 参数:
payload
: 接收到的消息载荷。graph_data
: 当前流程的图数据。langchain_obj_key
: 缓存键,用于标识 LangChain 对象。client_id
: 客户端的唯一标识符。chat_id
: 聊天会话的唯一标识符。start_resp
: 开始响应消息对象。step_resp
: 步骤响应消息对象。
- 详细流程:
- 初始化
has_file
和has_variable
为False
。 - 检查
payload
中是否包含inputs
,并且是否包含data
或file_path
。 - 获取
node_data
,即消息中的数据或文件路径。 - 调用
refresh_graph_data
方法更新graph_data
,处理上传的文件。 - 检查是否有需要重新构建的节点(例如,文件加载器)。
- 如果有文件加载器,调用
set_cache
方法清除缓存,准备重新构建对象。 - 检查
node_data
中是否包含文件或变量节点,并更新has_file
和has_variable
。 - 根据是否包含文件或变量,发送相应的响应消息并记录日志。
- 返回
has_file
和更新后的graph_data
。
- 初始化
c. preper_action
方法
async def preper_action(self, client_id, chat_id, langchain_obj_key, payload,
start_resp: ChatResponse, step_resp: ChatResponse):
langchain_obj = self.in_memory_cache.get(langchain_obj_key)
batch_question = []
action = ''
over = False
if isinstance(langchain_obj, Report):
action = 'report'
step_resp.intermediate_steps = '文件解析完成,开始生成报告'
await self.send_json(client_id, chat_id, step_resp)
elif payload.get('action') == 'stop':
action = 'stop'
elif 'action' in payload:
action = 'autogen'
elif 'clear_history' in payload and payload['clear_history']:
self.chat_history.empty_history(client_id, chat_id)
action = 'clear_history'
over = True
elif 'data' in payload['inputs'] or 'file_path' in payload['inputs']:
action = 'auto_file'
batch_question = self.in_memory_cache.get(langchain_obj_key + '_question')
payload['inputs']['questions'] = batch_question
if not batch_question:
# no question
file_msg = payload['inputs']
file_msg.pop('id', '')
file_msg.pop('data', '')
file = ChatMessage(flow_id=client_id,
chat_id=chat_id,
is_bot=False,
message=file_msg,
type='end',
user_id=step_resp.user_id)
self.chat_history.add_message(client_id, chat_id, file)
step_resp.message = ''
step_resp.intermediate_steps = '文件解析完成'
await self.send_json(client_id, chat_id, step_resp)
start_resp.type = 'close'
await self.send_json(client_id, chat_id, start_resp)
over = True
else:
step_resp.intermediate_steps = '文件解析完成,开始执行'
await self.send_json(client_id, chat_id, step_resp, add=False)
await asyncio.sleep(-1) # 快速的跳过
return action, over
- 功能:根据消息载荷
payload
确定要执行的动作,并更新相应的响应消息。 - 参数:
client_id
: 客户端的唯一标识符。chat_id
: 聊天会话的唯一标识符。langchain_obj_key
: 缓存键,用于标识 LangChain 对象。payload
: 接收到的消息载荷。start_resp
: 开始响应消息对象。step_resp
: 步骤响应消息对象。
- 详细流程:
- 获取缓存中的 LangChain 对象
langchain_obj
。 - 初始化
batch_question
、action
和over
。 - 根据 langchain_obj的类型和 payload中的内容,确定要执行的动作:
- 报告生成 (
Report
):- 设置
action
为'report'
。 - 更新
step_resp
的intermediate_steps
,发送响应消息。
- 设置
- 停止动作 (
'stop'
):- 设置
action
为'stop'
。
- 设置
- 自动生成 (
'autogen'
):- 设置
action
为'autogen'
。
- 设置
- 清空历史 (
'clear_history'
):- 调用
chat_history.empty_history
清空聊天历史。 - 设置
action
为'clear_history'
,并设置over=True
。
- 调用
- 文件自动处理 (
'auto_file'
):- 设置
action
为'auto_file'
。 - 获取
batch_question
,并将其添加到payload['inputs']['questions']
。 - 如果没有问题(
batch_question
为空),构建文件消息,添加到聊天历史,并发送关闭响应。 - 如果有问题,更新
step_resp
,发送响应消息。
- 设置
- 报告生成 (
- 使用
await asyncio.sleep(-1)
快速跳过,实际实现中这可能是为了模拟阻塞或等待某些条件。 - 返回
action
和over
的状态。
- 获取缓存中的 LangChain 对象
11. init_langchain_object_task
方法
async def init_langchain_object_task(self, flow_id, chat_id, user_id, graph_data):
key_node = get_cache_key(flow_id, chat_id)
logger.info(f'init_langchain build_begin key={key_node}')
with session_getter() as session:
db_user = session.get(User, user_id) # 用来支持节点判断用户权限
artifacts = {}
start_time = time.time()
graph = await build_flow_no_yield(graph_data=graph_data,
artifacts=artifacts,
process_file=True,
flow_id=UUID(flow_id).hex,
chat_id=chat_id,
user_name=db_user.user_name)
await graph.abuild()
logger.info(f'init_langchain build_end timecost={time.time() - start_time}')
question = []
for node in graph.vertices:
if node.vertex_type in {'InputNode', 'AudioInputNode', 'FileInputNode'}:
question_parse = await node.get_result()
if isinstance(question_parse, list):
question.extend(question_parse)
else:
question.append(question_parse)
self.set_cache(key_node + '_question', question)
input_nodes = graph.get_input_nodes()
for node in input_nodes:
# 只存储chain
if node.base_type == 'inputOutput' and node.vertex_type != 'Report':
continue
self.set_cache(key_node, await node.get_result())
self.set_cache(key_node + '_artifacts', artifacts)
return flow_id, chat_id
- 功能:异步初始化 LangChain 对象,用于处理复杂的聊天逻辑,如自然语言理解、对话生成等。
- 参数:
flow_id
: 流程的唯一标识符。chat_id
: 聊天会话的唯一标识符。user_id
: 当前用户的 ID。graph_data
: 流程数据,用于初始化 LangChain 对象。
- 详细流程:
- 生成缓存键:使用
get_cache_key(flow_id, chat_id)
生成唯一的缓存键key_node
。 - 获取用户信息:从数据库中获取
User
对象db_user
,用于支持节点判断用户权限。 - 构建流程图:
- 调用
build_flow_no_yield
方法,传入流程数据和其他参数,构建流程图对象graph
。 - 调用
await graph.abuild()
,异步构建流程图。
- 调用
- 日志记录:记录流程图构建的开始和结束时间,计算耗时。
- 提取问题:
- 遍历
graph.vertices
,查找InputNode
、AudioInputNode
、FileInputNode
类型的节点。 - 调用
await node.get_result()
获取解析后的问题,添加到question
列表中。
- 遍历
- 更新缓存:
- 将解析后的问题列表
question
添加到缓存中。 - 遍历
graph.get_input_nodes()
,获取输入节点,并将其结果存入缓存中。
- 将解析后的问题列表
- 返回:返回
flow_id
和chat_id
。
- 生成缓存键:使用
12. refresh_graph_data
方法
def refresh_graph_data(self, graph_data: dict, node_data: List[dict]):
tweak = process_node_data(node_data)
"""upload file to make flow work"""
return process_tweaks(graph_data, tweaks=tweak)
- 功能:根据上传的文件或数据更新流程图数据。
- 参数:
graph_data
: 当前流程的图数据。node_data
: 节点数据,可能包含上传的文件或其他输入数据。
- 详细流程:
- 调用
process_node_data(node_data)
处理节点数据,生成tweak
。 - 调用
process_tweaks(graph_data, tweaks=tweak)
更新graph_data
,返回更新后的图数据。
- 调用
ChatManager 类的设计层次与架构
系统架构中的层次
在一个典型的多层架构系统中,ChatManager
类位于服务层(Service Layer)**或**应用逻辑层(Application Logic Layer)。这一层负责处理具体的业务逻辑,协调不同的组件和服务,确保系统的各个部分协同工作。
- API 层(Controller Layer):
- 负责处理 HTTP 请求和 WebSocket 连接,将请求转发给服务层。
- 示例:
get_app_chat_list
和chat
路由处理函数。
- 服务层(Service Layer):
- 负责具体的业务逻辑处理,如管理 WebSocket 连接、处理消息、管理聊天历史、协调缓存和数据库操作等。
- 示例:
ChatManager
类。
- 数据层(Data Layer):
- 负责与数据库的交互,进行数据的持久化和查询。
- 示例:
ChatMessageDao
、UserDao
等数据访问对象(DAO)。
- 缓存层(Cache Layer):
- 负责管理系统的缓存,提高数据访问的效率。
- 示例:
cache_manager
、InMemoryCache
。
- 工具层(Utility Layer):
- 提供系统所需的各种工具函数和辅助功能。
- 示例:
get_cache_key
、build_flow_no_yield
等。
为什么这样设计
- 单一职责原则:
- 每个类和方法只负责特定的职责,使得代码更易于理解、测试和维护。
- 例如,
ChatHistory
只负责管理聊天历史记录,ChatManager
负责管理 WebSocket 连接和聊天逻辑。
- 模块化与可维护性:
- 将不同的功能模块分离,使得各个部分可以独立开发和维护。
- 例如,
ChatManager
类与ChatHistory
类、缓存管理器、数据库访问对象等分离,降低了模块之间的耦合度。
- 高并发与异步处理:
- 使用
asyncio
和线程池(ThreadPoolManager
)来处理高并发和耗时任务,确保 WebSocket 连接的高效处理。 - 例如,
dispatch_client
方法中使用asyncio.wait_for
监听消息,并通过线程池处理耗时任务,避免阻塞主线程。
- 使用
- 缓存与性能优化:
- 使用缓存机制(如
InMemoryCache
)来加速数据访问,提高系统性能。 - 通过
cache_manager
和in_memory_cache
实现高效的数据存取和共享。
- 使用缓存机制(如
- 扩展性与灵活性:
- 通过设计模式(如观察者模式)和依赖注入(如 FastAPI 的
Depends
),使得系统具有良好的扩展性和灵活性。 - 例如,
ChatHistory
通过继承Subject
类实现观察者模式,可以在消息添加时自动通知其他组件。
- 通过设计模式(如观察者模式)和依赖注入(如 FastAPI 的
- 错误处理与稳定性:
- 通过异常处理机制,确保在 WebSocket 连接过程中出现错误时能够优雅地关闭连接并记录错误日志,提升系统的稳定性。
- 例如,
handle_websocket
方法中捕获并处理各种异常,确保连接不会因为未处理的错误而导致系统崩溃。
- 日志记录与监控:
- 通过使用日志记录(如
loguru
),实现对系统运行状态和关键事件的监控,便于问题追踪和调试。 - 例如,在关键步骤记录日志信息,如连接建立、消息处理、任务完成等。
- 通过使用日志记录(如
ChatManager
类的主要职责
- WebSocket 连接管理:
- 接受、记录、断开 WebSocket 连接。
- 管理每个连接的消息队列和状态。
- 聊天消息处理:
- 接收并解析来自客户端的消息。
- 根据消息内容执行相应的聊天逻辑,如生成回复、处理文件上传等。
- 将处理后的消息发送回客户端,并记录到聊天历史中。
- 并发任务管理:
- 使用异步任务和线程池处理耗时的聊天生成任务,确保系统能够高效地处理大量并发请求。
- 管理任务的创建、监控和取消,确保任务的正确执行和系统的稳定性。
- 聊天历史与缓存管理:
- 使用
ChatHistory
管理聊天记录,确保消息的持久化和历史查询能力。 - 使用缓存管理器优化数据访问,提升系统性能。
- 使用
- 错误处理与异常管理:
- 捕获并处理 WebSocket 和聊天逻辑中的各种异常,确保系统的稳定性和用户体验。
- 日志记录与审计:
- 记录关键事件和错误日志,支持系统监控和问题排查。
- 通过审计日志服务记录聊天流程的操作,满足审计和合规要求。
总结
ChatManager
类在整个聊天系统中充当着**服务管理器(Service Manager)**的角色,负责管理和协调与 WebSocket 连接相关的所有业务逻辑。通过模块化设计、异步处理、高效的缓存机制和严格的错误处理,它确保了系统能够高效、稳定地处理大量的实时聊天请求。
为何这样设计:
- 高效管理 WebSocket 连接:通过记录和管理所有活跃连接,确保消息能够准确发送到对应的客户端。
- 模块化和可维护性:将不同的职责分离到不同的类和方法中,增强了代码的可读性和可维护性。
- 并发与异步处理:利用
asyncio
和线程池,实现高并发情况下的高效消息处理和任务执行。 - 缓存与性能优化:使用缓存机制减少数据库查询次数,提高系统响应速度。
- 错误处理与日志记录:通过全面的异常捕获和详细的日志记录,提升系统的稳定性和可追溯性。
通过这样的设计,ChatManager
能够有效地管理复杂的聊天逻辑,确保系统在高并发和实时通信的场景下依然保持高性能和稳定性。
dispatch_client
当然,下面我将详细讲解你提供的 dispatch_client
函数。我们将从函数的整体目的开始,然后逐步解析每一部分的具体实现和背后的设计思想。
dispatch_client
函数概述
async def dispatch_client(
self,
request: Request, # 原始请求体
client_id: str,
chat_id: str,
login_user: UserPayload,
work_type: WorkType,
websocket: WebSocket,
graph_data: dict = None):
client_key = uuid.uuid4().hex
chat_client = ChatClient(request,
client_key,
client_id,
chat_id,
login_user.user_id,
login_user,
work_type,
websocket,
graph_data=graph_data)
await self.accept_client(client_key, chat_client, websocket)
logger.debug(
f'act=accept_client client_key={client_key} client_id={client_id} chat_id={chat_id}')
try:
while True:
try:
json_payload_receive = await asyncio.wait_for(websocket.receive_json(),
timeout=2.0)
except asyncio.TimeoutError:
continue
try:
payload = json.loads(json_payload_receive) if json_payload_receive else {}
except TypeError:
payload = json_payload_receive
# client内部处理自己的业务逻辑
# TODO zgq:这里可以增加线程池防止阻塞
await chat_client.handle_message(payload)
except WebSocketDisconnect as e:
logger.info('act=rcv_client_disconnect {}', str(e))
except IgnoreException:
# client 内部自己关闭了ws链接,并无异常的情况
pass
except Exception as e:
# Handle any exceptions that might occur
logger.exception(str(e))
await self.close_client(client_key,
code=status.WS_1011_INTERNAL_ERROR,
reason='后端未知错误类型')
finally:
try:
await self.close_client(client_key,
code=status.WS_1000_NORMAL_CLOSURE,
reason='Client disconnected')
except Exception as e:
logger.exception(e)
self.clear_client(client_key)
函数的整体目的
dispatch_client
函数的主要职责是管理和处理与特定客户端(基于 client_id
和 chat_id
)的 WebSocket 连接。这包括:
- 建立和记录连接:生成唯一的客户端键,创建
ChatClient
实例,并接受 WebSocket 连接。 - 消息处理循环:持续监听和接收来自客户端的消息,并通过
ChatClient
处理这些消息。 - 异常处理:捕获和处理各种可能的异常,确保连接的稳定性和资源的正确释放。
- 清理资源:在连接关闭时,确保相关资源和记录被正确清理。
详细解析
1. 函数签名和参数
async def dispatch_client(
self,
request: Request, # 原始请求体
client_id: str,
chat_id: str,
login_user: UserPayload,
work_type: WorkType,
websocket: WebSocket,
graph_data: dict = None):
- self:表示这是一个类的方法,通常是
ChatManager
类的一个实例方法。 - request:原始的 HTTP 请求对象,包含了请求的所有信息。
- client_id:标识客户端的唯一 ID,用于区分不同的客户端。
- chat_id:标识当前聊天会话的唯一 ID,用于区分不同的聊天对话。
- login_user:包含用户身份信息的对象(
UserPayload
),用于识别和授权用户。 - work_type:聊天工作的类型,可能用于区分不同的聊天模式或逻辑(例如 GPT-3、GPT-4 等)。
- websocket:FastAPI 的
WebSocket
对象,表示与客户端的 WebSocket 连接。 - graph_data:可选参数,包含与聊天相关的流程数据或图数据。
2. 生成唯一的客户端键和创建 ChatClient
实例
client_key = uuid.uuid4().hex
chat_client = ChatClient(request,
client_key,
client_id,
chat_id,
login_user.user_id,
login_user,
work_type,
websocket,
graph_data=graph_data)
-
生成唯一的客户端键 (
client_key
):
- 使用
uuid.uuid4().hex
生成一个唯一的十六进制字符串,确保每个客户端连接都有一个唯一的标识符。这对于在后续管理连接和消息时非常重要。
- 使用
-
创建
ChatClient
实例:
ChatClient
是一个管理具体聊天会话的类,负责处理来自客户端的消息、生成回复等逻辑。- 传递的参数包括请求对象、客户端键、客户端和聊天会话的 ID、用户信息、工作类型、WebSocket 连接和流程数据。
3. 接受客户端连接并记录
await self.accept_client(client_key, chat_client, websocket)
logger.debug(
f'act=accept_client client_key={client_key} client_id={client_id} chat_id={chat_id}')
-
await self.accept_client(...)
:
- 调用
ChatManager
类中的accept_client
方法,接受 WebSocket 连接并将ChatClient
实例记录到active_clients
字典中。
- 调用
-
日志记录
:
- 使用
logger.debug
记录连接接受的日志,包括client_key
、client_id
和chat_id
,方便后续调试和监控。
- 使用
4. 消息处理循环
try:
while True:
try:
json_payload_receive = await asyncio.wait_for(websocket.receive_json(),
timeout=2.0)
except asyncio.TimeoutError:
continue
try:
payload = json.loads(json_payload_receive) if json_payload_receive else {}
except TypeError:
payload = json_payload_receive
# client内部处理自己的业务逻辑
# TODO zgq:这里可以增加线程池防止阻塞
await chat_client.handle_message(payload)
-
无限循环:
- 使用
while True
创建一个无限循环,持续监听和接收来自客户端的消息。
- 使用
-
接收消息:
-
await asyncio.wait_for(websocket.receive_json(), timeout=2.0)
:
- 尝试在 2 秒内接收来自客户端的 JSON 消息。
- 如果在 2 秒内没有接收到消息,会抛出
asyncio.TimeoutError
,并跳过当前循环迭代,继续监听下一个消息。
-
-
异常处理:
-
asyncio.TimeoutError
:
- 捕获超时异常,表示在指定时间内未接收到消息。
- 使用
continue
跳过当前循环,继续等待下一个消息。
-
解析消息
:
- 尝试将接收到的 JSON 字符串解析为 Python 字典对象。
- 如果解析失败(例如消息不是有效的 JSON 格式),则直接使用原始消息内容。
-
-
处理消息:
- 调用
chat_client.handle_message(payload)
方法,处理解析后的消息。 - 这个方法通常包含业务逻辑,如生成回复、调用外部服务等。
- 调用
5. 异常处理
except WebSocketDisconnect as e:
logger.info('act=rcv_client_disconnect {}', str(e))
except IgnoreException:
# client 内部自己关闭了ws链接,并无异常的情况
pass
except Exception as e:
# Handle any exceptions that might occur
logger.exception(str(e))
await self.close_client(client_key,
code=status.WS_1011_INTERNAL_ERROR,
reason='后端未知错误类型')
WebSocketDisconnect
:- 捕获
WebSocketDisconnect
异常,表示客户端断开了连接。 - 记录断开连接的日志信息,通常包括断开原因和相关信息。
- 捕获
IgnoreException
:- 捕获
IgnoreException
,这是一种自定义异常,表示客户端主动关闭连接且没有异常情况。 - 不做任何处理,直接忽略。
- 捕获
- 其他异常:
- 捕获所有其他类型的异常,表示在消息处理过程中发生了未知错误。
- 记录异常的详细信息。
- 调用
close_client
方法,使用WS_1011_INTERNAL_ERROR
状态码关闭 WebSocket 连接,并提供关闭原因。
6. 清理资源
finally:
try:
await self.close_client(client_key,
code=status.WS_1000_NORMAL_CLOSURE,
reason='Client disconnected')
except Exception as e:
logger.exception(e)
self.clear_client(client_key)
finally
块:- 无论前面的代码是否抛出异常,都会执行
finally
块中的内容,确保资源的正确释放。
- 无论前面的代码是否抛出异常,都会执行
- 关闭连接:
- 调用
close_client
方法,使用WS_1000_NORMAL_CLOSURE
状态码关闭连接,并提供关闭原因Client disconnected
。 - 捕获并记录关闭连接过程中可能
- 调用
出现的异常,确保即使在关闭连接时出现问题,也不会影响系统的稳定性。
-
清理客户端记录
:
- 调用
clear_client
方法,移除client_key
对应的客户端记录,从而释放资源。
- 调用
相关方法的详细解释
1. accept_client
async def accept_client(self, client_key: str, chat_client: ChatClient, websocket: WebSocket):
await websocket.accept()
self.active_clients[client_key] = chat_client
-
功能:接受 WebSocket 连接并记录
ChatClient
实例。 -
参数
:
client_key
:唯一的客户端标识符。chat_client
:ChatClient
实例,负责管理具体的聊天会话。websocket
:WebSocket 连接对象。
-
详细流程
:
- 调用
await websocket.accept()
,接受 WebSocket 连接。 - 将
chat_client
实例记录到active_clients
字典中,键为client_key
,值为chat_client
实例。
- 调用
2. close_client
async def close_client(self, client_key: str, code: int, reason: str):
if chat_client := self.active_clients.get(client_key):
try:
await chat_client.websocket.close(code=code, reason=reason)
self.clear_client(client_key)
except RuntimeError as exc:
# This is to catch the following error:
# Unexpected ASGI message 'websocket.close', after sending 'websocket.close'
if 'after sending' in str(exc):
logger.error(exc)
-
功能:关闭特定客户端的 WebSocket 连接,并清理记录。
-
参数
:
client_key
:唯一的客户端标识符。code
:WebSocket 关闭连接的状态码。reason
:关闭连接的原因。
-
详细流程
:
- 从
active_clients
字典中获取对应的ChatClient
实例。 - 如果存在,调用
chat_client.websocket.close(code=code, reason=reason)
关闭 WebSocket 连接。 - 调用
clear_client(client_key)
清除客户端记录。 - 捕获并记录
RuntimeError
异常,避免因重复关闭连接导致的错误。
- 从
3. clear_client
def clear_client(self, client_key: str):
if client_key not in self.active_clients:
logger.warning('close_client client_key={} not in active_clients', client_key)
return
logger.info('close_client client_key={}', client_key)
self.active_clients.pop(client_key, None)
-
功能:清除指定客户端的记录。
-
参数
:
client_key
:唯一的客户端标识符。
-
详细流程
:
- 检查
client_key
是否存在于active_clients
字典中。 - 如果不存在,记录警告日志并返回。
- 如果存在,记录信息日志并从
active_clients
字典中移除对应的ChatClient
实例。
- 检查
ChatClient
类的作用
虽然用户提供的代码片段中没有 ChatClient
类的具体实现,但根据上下文可以推测其主要职责:
- 消息处理:接收和解析来自客户端的消息,执行相应的业务逻辑。
- 回复生成:基于接收到的消息,生成回复并通过 WebSocket 发送回客户端。
- 状态管理:维护当前聊天会话的状态,包括上下文、历史记录等。
- 错误处理:处理消息处理过程中的异常,确保聊天会话的稳定性。
设计思想与最佳实践
- 唯一标识符:
- 使用
uuid.uuid4().hex
生成唯一的client_key
,确保每个客户端连接都有唯一的标识,方便后续管理和追踪。
- 使用
- 异步处理:
- 利用
asyncio
实现异步消息接收和处理,确保高并发情况下的高效性能。
- 利用
- 超时机制:
- 使用
asyncio.wait_for
设置超时时间(2秒),防止长时间无消息接收导致资源占用。 - 通过捕获
asyncio.TimeoutError
,实现轮询机制,持续监听消息。
- 使用
- 异常处理:
- 捕获并处理各种类型的异常,确保系统的稳定性和健壮性。
- 特别处理
WebSocketDisconnect
和IgnoreException
,确保正常的断开连接和自定义异常不影响整体流程。
- 资源清理:
- 在连接关闭时,确保所有相关资源(如
ChatClient
实例)被正确清理,防止资源泄漏。
- 在连接关闭时,确保所有相关资源(如
- 日志记录:
- 通过
logger
记录关键事件和异常信息,便于调试和监控系统运行状态。
- 通过
- 模块化设计:
- 将连接管理和消息处理逻辑封装到
ChatManager
类中,确保代码的模块化和可维护性。 - 使用
ChatClient
类进一步封装具体的聊天逻辑,遵循单一职责原则。
- 将连接管理和消息处理逻辑封装到
总结
dispatch_client
函数在 ChatManager
类中起着核心作用,负责管理与客户端的 WebSocket 连接,持续监听和处理消息,并确保连接的稳定性和资源的正确释放。通过使用异步编程、异常处理、日志记录和模块化设计,确保了系统在高并发和实时通信场景下的高效性和可靠性。