python异步调用大模型,实现语音电话功能

python异步调用大模型,实现语音电话功能

本文使用fastapi框架,异步传入用户语音输入,并流式返回大模型输出,实现语音电话的功能

接口定义

首先定义一个websocket接口

@router.websocket("/chat/voice_call")
async def voice_chat(ws: WebSocket,db=Depends(get_db), redis=Depends(get_redis)):
    """WebSocket endpoint for a voice call.

    Accepts the connection and hands the socket, together with the injected
    ``db`` and ``redis`` dependencies, to ``voice_call_handler``.
    """
    await ws.accept()  # finish the handshake before any frame is exchanged
    await voice_call_handler(ws,db,redis)

再定义接口数据帧交互格式

  • 数据帧格式
# Frame schema: every value names the type expected for that field.
data = {
    "audio": [str], # slice of the audio data, base64-encoded
    "meta_info": {
        "session_id": [str], # session id
        "encoding": [str] # compression type; only "raw" for now
    },
    "is_close":[bool] # True when the client is about to end the connection, False while it stays open
}
  • 数据帧示例
# Example of one frame following the schema above (audio payload truncated).
ws_data = {
    "audio" : "/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA...",
    "meta_info":{
        "session_id":"28445e6d-e8c1-46a6-b980-fbf39b918def",
        "encoding": 'raw'
    },
    "is_close" : False
}
  • 返回说明

返回有两种形式,一种是返回文本信息,一种是返回二进制流音频信息

  • 文本信息
| 参数名称 | 参数类型 | 参数说明 |
| --- | --- | --- |
| type | string | 返回帧类型:"error" 表示出现错误;正常返回的文本帧类型为 "text"(见下文语音合成函数) |
| code | int | 200 为正常返回,500 为异常返回 |
| msg | string | 返回帧的信息 |
  • 错误帧示例
{"type": "error", "code": 500, "msg": "wrong frame"}

整体思路

在这里插入图片描述

工具函数

def get_session_content(session_id,redis,db):
    """Load a session's content and return it as a parsed JSON value.

    Redis is consulted first; on a miss the ``Session`` row with the same id
    is read from the database.

    Args:
        session_id: Redis key of the session, also matched against ``Session.id``.
        redis: Client exposing ``get(key)``; a missing key must yield ``None``.
        db: Database session exposing ``query(...)``.

    Returns:
        Whatever ``json.loads`` produces for the stored content string.

    Raises:
        HTTPException: 404 when the session is in neither store.
    """
    # One GET instead of EXISTS followed by GET: the key can expire between
    # the two calls, in which case GET returns None and json.loads would fail
    # with a TypeError instead of falling back to the database.
    session_content_str = redis.get(session_id)
    if session_content_str is None:
        session_db = db.query(Session).filter(Session.id == session_id).first()
        if not session_db:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
        session_content_str = session_db.content
    return json.loads(session_content_str)

def parseChunkDelta(chunk):
    """Extract the incremental text from one line of a streamed LLM response.

    Args:
        chunk: One raw response line as UTF-8 bytes, normally of the form
            ``data: {...json...}``.

    Returns:
        The ``choices[0]['delta']['content']`` string, or ``""`` when the line
        carries no text: a blank line, an OpenAI-style ``[DONE]`` terminator,
        no choices, no ``delta``, or a missing/null ``content``.

    Raises:
        json.JSONDecodeError: If the payload is not valid JSON.
    """
    decoded_data = chunk.decode('utf-8').strip()
    # Remove the "data:" field name explicitly instead of slicing a fixed six
    # characters; the space after the colon is optional.
    if decoded_data.startswith('data:'):
        decoded_data = decoded_data[5:].strip()
    if not decoded_data or decoded_data == '[DONE]':
        return ""
    parsed_data = json.loads(decoded_data)
    if not isinstance(parsed_data, dict):
        return ""
    choices = parsed_data.get('choices') or [{}]
    delta_content = choices[0].get('delta') or {}
    # "or" also maps a null content to the empty string, which callers concatenate.
    return delta_content.get('content') or ""

def split_string_with_punctuation(current_sentence,text,is_first):
    """Cut streamed text into sentences at punctuation marks.

    Characters of ``text`` are appended to ``current_sentence`` one by one.
    While ``is_first`` is true the first cut may also happen at a comma or an
    ASCII period; afterwards only sentence-ending marks cut.

    Args:
        current_sentence: Unfinished fragment carried over from the last call.
        text: Newly received text.
        is_first: True until the first cut of the current reply has been made.

    Returns:
        ``(sentences, current_sentence, is_first)``: the completed sentences
        (punctuation included), the remaining fragment, and the updated flag.
    """
    # Full-width marks are written as escapes so they cannot be folded into
    # their ASCII look-alikes; the previous literals listed the ASCII forms
    # twice, so full-width , ? ! never produced a cut.
    sentence_end = '\u3002?!\uff1f\uff01'
    first_clause_end = sentence_end + ',.\uff0c'
    result = []
    for char in text:
        current_sentence += char
        if char in (first_clause_end if is_first else sentence_end):
            result.append(current_sentence)
            current_sentence = ''
            is_first = False
    return result, current_sentence, is_first

def vad_preprocess(audio, frame_size=1280):
    """Split one fixed-size VAD frame off the front of a base64 audio string.

    The voice-activity detector must be fed strings of exactly ``frame_size``
    characters.

    Args:
        audio: Buffered base64 text.
        frame_size: Number of characters per frame (1280 by default).

    Returns:
        ``(frame, remainder)``. When fewer than ``frame_size`` characters are
        buffered, ``frame`` is ``frame_size`` ``'A'`` characters and
        ``remainder`` is empty.
    """
    if len(audio) < frame_size:
        # Same filler frame as before, but returned as a pair so callers can
        # always unpack ``frame, rest = vad_preprocess(...)``.
        return 'A' * frame_size, ''
    return audio[:frame_size], audio[frame_size:]

VAD类

import webrtcvad
import base64

class VAD:
    """Voice-activity detector built on ``webrtcvad.Vad`` that takes base64 input."""

    def __init__(self, vad_sensitivity=1, frame_duration=30, vad_buffer_size=7, min_act_time=1, RATE=16000,**kwargs):
        """Store the settings and create the underlying detector.

        Args:
            vad_sensitivity: Value passed to ``webrtcvad.Vad``.
            frame_duration: Frame length in milliseconds, used with ``RATE``
                to derive ``vad_chunk_size``.
            vad_buffer_size: Stored as ``self.vad_buffer_size``.
            min_act_time: Minimum activity time, in seconds.
            RATE: Sample rate handed to the detector on every call.
            **kwargs: Accepted and ignored.
        """
        self.RATE = RATE
        self.vad = webrtcvad.Vad(vad_sensitivity)
        self.min_act_time = min_act_time    # minimum activity time, in seconds
        self.vad_chunk_size = int(self.RATE * frame_duration / 1000)
        self.vad_buffer_size = vad_buffer_size

    def is_speech(self,data):
        """Decode base64 ``data`` and return the detector's verdict for it."""
        return self.vad.is_speech(base64.b64decode(data), self.RATE)

创建异步队列、异步事件以及future

# Queues connecting the pipeline stages, in processing order.
audio_q = asyncio.Queue()  # audio frames
asr_result_q = asyncio.Queue()  # speech-recognition results
llm_response_q = asyncio.Queue()  # LLM response frames
split_result_q = asyncio.Queue()  # sentences produced by the splitter

# Each event marks the end of one stage.
input_finished_event = asyncio.Event()  # user input has ended
asr_finished_event = asyncio.Event()  # speech recognition has ended
llm_finished_event = asyncio.Event()  # LLM stage has ended
split_finished_event = asyncio.Event()  # sentence splitting has ended
voice_call_end_event = asyncio.Event()  # the voice call is over

future = asyncio.Future()  # carries the session_id sent by the client

用户输入处理函数

async def voice_call_audio_producer(ws,audio_queue,future,input_finished_event):
    """Read client frames from the WebSocket and queue VAD-sized audio frames.

    The first frame that carries ``meta_info.session_id`` resolves ``future``.
    Audio text is buffered and pushed to ``audio_queue`` in 1280-character
    frames. A frame with a truthy ``is_close`` ends the loop.

    Args:
        ws: WebSocket providing ``receive_text()``.
        audio_queue: Queue that receives the base64 audio frames.
        future: Future resolved with the session id.
        input_finished_event: Set when this producer stops for any reason.
    """
    logger.debug("音频数据生产函数启动")
    audio_data = ""
    try:
        while not input_finished_event.is_set():
            voice_call_data_json = json.loads(await ws.receive_text())
            try:
                if not future.done(): # the first complete frame supplies the session_id
                    future.set_result(voice_call_data_json['meta_info']['session_id'])
                if voice_call_data_json["is_close"]:
                    break
                audio_data += voice_call_data_json["audio"]
            except KeyError:
                # A frame without the expected keys is treated as a heartbeat.
                # It is skipped inside the loop; previously the handler sat
                # outside the loop, so a single heartbeat ended the producer.
                logger.info("收到心跳包")
                continue
            # ">=" so that a buffer of exactly one frame is forwarded immediately.
            while len(audio_data) >= 1280:
                vad_frame,audio_data = vad_preprocess(audio_data)
                await audio_queue.put(vad_frame) # store the frame in the audio queue
    finally:
        # Also reached when receive_text() raises (e.g. on disconnect), so the
        # downstream stages are always told that no more input will arrive.
        input_finished_event.set()

语音识别函数

async def voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event):
    """Run VAD and streaming ASR over queued audio frames.

    Speech frames are passed to ``asr.streaming_recognize`` and their text is
    appended to the current utterance. Once the silence counter reaches 25 the
    recogniser is finalised and a non-empty utterance is put on
    ``asr_result_q``. ``asr_finished_event`` is set on exit.
    """
    logger.debug("音频数据消费者函数启动")
    vad = VAD()
    current_message = ""
    vad_count = 0
    poll_interval = 0.1  # seconds between re-checks of the exit condition
    try:
        while not (input_finished_event.is_set() and audio_q.empty()):
            try:
                # A bounded wait: an unbounded get() would block forever when
                # the input ends while the queue is already empty.
                audio_data = await asyncio.wait_for(audio_q.get(), timeout=poll_interval)
            except asyncio.TimeoutError:
                continue
            if vad.is_speech(audio_data):
                if vad_count > 0:
                    vad_count -= 1  # speech pays the silence counter back one step
                asr_result = asr.streaming_recognize(audio_data)
                current_message += ''.join(asr_result['text'])
            else:
                vad_count += 1
                if vad_count >= 25: # enough silent frames: the user is considered done
                    # NOTE(review): the result of this finalising call is not
                    # used; confirm it cannot carry trailing text.
                    asr_result = asr.streaming_recognize(audio_data, is_end=True)
                    if current_message:
                        logger.debug(f"检测到静默,用户输入为:{current_message}")
                        await asr_result_q.put(current_message)
                    current_message = ""
                    vad_count = 0
    finally:
        # Set even on failure so the next stage is not left waiting.
        asr_finished_event.set()

大模型调用函数

async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event):
    """Send each recognised utterance to the LLM and queue the streamed reply.

    For every text taken from ``asr_result_q`` the session history is loaded,
    the utterance is appended as a ``user`` message and the conversation is
    posted to ``Config.MINIMAX_LLM.URL`` with ``stream=True``. Each non-empty
    response line becomes ``{'message': <delta>, 'is_end': False}`` on
    ``llm_response_q``; ``{'message': "", 'is_end': True}`` closes a
    successful reply. ``llm_finished_event`` is set on exit.
    """
    logger.debug("asr结果消费以及llm返回生产函数启动")
    loop = asyncio.get_running_loop()
    poll_interval = 0.1  # seconds between re-checks of the exit condition
    headers = {
        'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
        'Content-Type': 'application/json'
    }
    try:
        while not (asr_finished_event.is_set() and asr_result_q.empty()):
            try:
                # A bounded wait: an unbounded get() would block forever when
                # ASR finishes while the queue is already empty.
                current_message = await asyncio.wait_for(asr_result_q.get(), timeout=poll_interval)
            except asyncio.TimeoutError:
                continue
            # Load the history only once an utterance is available, so that it
            # includes whatever was stored while this coroutine was waiting.
            session_content = get_session_content(session_id,redis,db)
            messages = json.loads(session_content["messages"])
            messages.append({'role': 'user', "content": current_message})
            # Persist the user turn; otherwise only assistant replies are ever
            # written back and later requests lose what the user said.
            session_content["messages"] = json.dumps(messages,ensure_ascii=False)
            redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
            payload = json.dumps({
                "model": llm_info["model"],
                "stream": True,
                "messages": messages,
                "max_tokens":10000,
                "temperature": llm_info["temperature"],
                "top_p": llm_info["top_p"]
            })
            # requests is blocking: run it in the default executor so the event
            # loop keeps receiving audio and sending speech meanwhile.
            response = await loop.run_in_executor(
                None,
                lambda: requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True),
            )
            try:
                if response.status_code == 200:
                    lines = response.iter_lines()
                    while True:
                        # next() with a None default: a StopIteration must not
                        # escape into the executor's future.
                        chunk = await loop.run_in_executor(None, next, lines, None)
                        if chunk is None:
                            break
                        if chunk:
                            chunk_data = parseChunkDelta(chunk)
                            llm_frame = {'message':chunk_data,'is_end':False}
                            await llm_response_q.put(llm_frame)
                    llm_frame = {'message':"",'is_end':True}
                    await llm_response_q.put(llm_frame)
                else:
                    logger.error("LLM request failed with status %s", response.status_code)
            finally:
                response.close()
    finally:
        # Set even on failure so the next stage is not left waiting.
        llm_finished_event.set()

断句函数(注:下方代码实际为语音合成函数,与下一节重复;断句逻辑 voice_call_llm_response_consumer 见文末完整代码)

async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
    """Synthesise each queued sentence and send audio plus text to the client.

    For every sentence from ``split_result_q`` the audio returned by
    ``tts.synthesize`` is sent as a binary frame, followed by a JSON text
    frame ``{"type": "text", "code": 200, "msg": <sentence>}``. When the
    splitter has finished and the queue is drained, the socket is closed.
    ``voice_call_end_event`` is set on exit.
    """
    logger.debug("语音合成及返回函数启动")
    poll_interval = 0.1  # seconds between re-checks of the exit condition
    try:
        while not (split_finished_event.is_set() and split_result_q.empty()):
            try:
                # A bounded wait: an unbounded get() would block forever when
                # splitting finishes while the queue is already empty.
                sentence = await asyncio.wait_for(split_result_q.get(), timeout=poll_interval)
            except asyncio.TimeoutError:
                continue
            _sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
            text_response = {"type": "text", "code": 200, "msg": sentence}
            await ws.send_bytes(audio) # binary audio stream
            await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # text frame
            logger.debug(f"websocket返回:{sentence}")
        # The sleep was previously created but never awaited, so no pause happened.
        await asyncio.sleep(0.5)
        await ws.close()
    finally:
        # Set even on failure so the call handler can finish.
        voice_call_end_event.set()

语音合成函数

async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
    """Synthesise each queued sentence and send audio plus text to the client.

    For every sentence from ``split_result_q`` the audio returned by
    ``tts.synthesize`` is sent as a binary frame, followed by a JSON text
    frame ``{"type": "text", "code": 200, "msg": <sentence>}``. When the
    splitter has finished and the queue is drained, the socket is closed.
    ``voice_call_end_event`` is set on exit.
    """
    logger.debug("语音合成及返回函数启动")
    poll_interval = 0.1  # seconds between re-checks of the exit condition
    try:
        while not (split_finished_event.is_set() and split_result_q.empty()):
            try:
                # A bounded wait: an unbounded get() would block forever when
                # splitting finishes while the queue is already empty.
                sentence = await asyncio.wait_for(split_result_q.get(), timeout=poll_interval)
            except asyncio.TimeoutError:
                continue
            _sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
            text_response = {"type": "text", "code": 200, "msg": sentence}
            await ws.send_bytes(audio) # binary audio stream
            await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # text frame
            logger.debug(f"websocket返回:{sentence}")
        # The sleep was previously created but never awaited, so no pause happened.
        await asyncio.sleep(0.5)
        await ws.close()
    finally:
        # Set even on failure so the call handler can finish.
        voice_call_end_event.set()

语音电话处理函数

async def voice_call_handler(ws, db, redis):
    """Wire up and supervise the stages of one voice call.

    Creates the queues and events that link the stages, starts the audio
    producer and consumer, waits for the session id from the first client
    frame, loads ``tts_info`` and ``llm_info`` from the session, starts the
    LLM, splitting and TTS stages, and waits until the call has ended.
    """
    logger.debug("voice_call websocket 连接建立")
    audio_q = asyncio.Queue()  # audio frames
    asr_result_q = asyncio.Queue()  # speech-recognition results
    llm_response_q = asyncio.Queue()  # LLM response frames
    split_result_q = asyncio.Queue()  # sentences produced by the splitter

    input_finished_event = asyncio.Event()  # user input has ended
    asr_finished_event = asyncio.Event()  # speech recognition has ended
    llm_finished_event = asyncio.Event()  # LLM stage has ended
    split_finished_event = asyncio.Event()  # sentence splitting has ended
    voice_call_end_event = asyncio.Event()  # the voice call is over

    future = asyncio.get_running_loop().create_future()  # carries the session_id
    # Keep references to the tasks: the event loop itself only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    tasks = [
        asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,input_finished_event)), # audio producer
        asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event)), # audio consumer
    ]
    try:
        session_id = await future # session_id from the first client frame
        # Load the session once and read both settings from it.
        session_content = get_session_content(session_id,redis,db)
        tts_info = json.loads(session_content["tts_info"])
        llm_info = json.loads(session_content["llm_info"])

        tasks.append(asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event))) # LLM stage
        tasks.append(asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event))) # sentence splitting
        tasks.append(asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event))) # TTS and reply

        # Wake up as soon as the call ends instead of polling every 3 seconds.
        await voice_call_end_event.wait()
    finally:
        # Stop whatever is still running (e.g. after an error above) and
        # collect the results so no task exception goes unretrieved.
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        try:
            await ws.close()
        except RuntimeError:
            # The TTS stage normally closes the socket already; Starlette
            # reports a second close as a RuntimeError.
            pass
        logger.debug("voice_call websocket 连接断开")

完整代码

最后贴一版完整代码

注:无法直接使用,仅提供思路

import webrtcvad
import base64

class VAD:
    """Voice-activity detector built on ``webrtcvad.Vad`` that takes base64 input."""

    def __init__(self, vad_sensitivity=1, frame_duration=30, vad_buffer_size=7, min_act_time=1, RATE=16000,**kwargs):
        """Store the settings and create the underlying detector.

        Args:
            vad_sensitivity: Value passed to ``webrtcvad.Vad``.
            frame_duration: Frame length in milliseconds, used with ``RATE``
                to derive ``vad_chunk_size``.
            vad_buffer_size: Stored as ``self.vad_buffer_size``.
            min_act_time: Minimum activity time, in seconds.
            RATE: Sample rate handed to the detector on every call.
            **kwargs: Accepted and ignored.
        """
        self.RATE = RATE
        self.vad = webrtcvad.Vad(vad_sensitivity)
        self.min_act_time = min_act_time    # minimum activity time, in seconds
        self.vad_chunk_size = int(self.RATE * frame_duration / 1000)
        self.vad_buffer_size = vad_buffer_size

    def is_speech(self,data):
        """Decode base64 ``data`` and return the detector's verdict for it."""
        return self.vad.is_speech(base64.b64decode(data), self.RATE)
    
def get_session_content(session_id,redis,db):
    """Load a session's content and return it as a parsed JSON value.

    Redis is consulted first; on a miss the ``Session`` row with the same id
    is read from the database.

    Args:
        session_id: Redis key of the session, also matched against ``Session.id``.
        redis: Client exposing ``get(key)``; a missing key must yield ``None``.
        db: Database session exposing ``query(...)``.

    Returns:
        Whatever ``json.loads`` produces for the stored content string.

    Raises:
        HTTPException: 404 when the session is in neither store.
    """
    # One GET instead of EXISTS followed by GET: the key can expire between
    # the two calls, in which case GET returns None and json.loads would fail
    # with a TypeError instead of falling back to the database.
    session_content_str = redis.get(session_id)
    if session_content_str is None:
        session_db = db.query(Session).filter(Session.id == session_id).first()
        if not session_db:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
        session_content_str = session_db.content
    return json.loads(session_content_str)

def parseChunkDelta(chunk):
    """Extract the incremental text from one line of a streamed LLM response.

    Args:
        chunk: One raw response line as UTF-8 bytes, normally of the form
            ``data: {...json...}``.

    Returns:
        The ``choices[0]['delta']['content']`` string, or ``""`` when the line
        carries no text: a blank line, an OpenAI-style ``[DONE]`` terminator,
        no choices, no ``delta``, or a missing/null ``content``.

    Raises:
        json.JSONDecodeError: If the payload is not valid JSON.
    """
    decoded_data = chunk.decode('utf-8').strip()
    # Remove the "data:" field name explicitly instead of slicing a fixed six
    # characters; the space after the colon is optional.
    if decoded_data.startswith('data:'):
        decoded_data = decoded_data[5:].strip()
    if not decoded_data or decoded_data == '[DONE]':
        return ""
    parsed_data = json.loads(decoded_data)
    if not isinstance(parsed_data, dict):
        return ""
    choices = parsed_data.get('choices') or [{}]
    delta_content = choices[0].get('delta') or {}
    # "or" also maps a null content to the empty string, which callers concatenate.
    return delta_content.get('content') or ""

def split_string_with_punctuation(current_sentence,text,is_first):
    """Cut streamed text into sentences at punctuation marks.

    Characters of ``text`` are appended to ``current_sentence`` one by one.
    While ``is_first`` is true the first cut may also happen at a comma or an
    ASCII period; afterwards only sentence-ending marks cut.

    Args:
        current_sentence: Unfinished fragment carried over from the last call.
        text: Newly received text.
        is_first: True until the first cut of the current reply has been made.

    Returns:
        ``(sentences, current_sentence, is_first)``: the completed sentences
        (punctuation included), the remaining fragment, and the updated flag.
    """
    # Full-width marks are written as escapes so they cannot be folded into
    # their ASCII look-alikes; the previous literals listed the ASCII forms
    # twice, so full-width , ? ! never produced a cut.
    sentence_end = '\u3002?!\uff1f\uff01'
    first_clause_end = sentence_end + ',.\uff0c'
    result = []
    for char in text:
        current_sentence += char
        if char in (first_clause_end if is_first else sentence_end):
            result.append(current_sentence)
            current_sentence = ''
            is_first = False
    return result, current_sentence, is_first

def vad_preprocess(audio, frame_size=1280):
    """Split one fixed-size VAD frame off the front of a base64 audio string.

    Args:
        audio: Buffered base64 text.
        frame_size: Number of characters per frame (1280 by default).

    Returns:
        ``(frame, remainder)``. When fewer than ``frame_size`` characters are
        buffered, ``frame`` is ``frame_size`` ``'A'`` characters and
        ``remainder`` is empty.
    """
    if len(audio) < frame_size:
        # Same filler frame as before, but returned as a pair so callers can
        # always unpack ``frame, rest = vad_preprocess(...)``.
        return 'A' * frame_size, ''
    return audio[:frame_size], audio[frame_size:]


async def voice_call_audio_producer(ws,audio_q,future,input_finished_event):
    """Read client frames from the WebSocket and queue VAD-sized audio frames.

    The first frame that carries ``meta_info.session_id`` resolves ``future``.
    Audio text is buffered and pushed to ``audio_q`` in 1280-character frames.
    A frame with a truthy ``is_close`` ends the loop.

    Args:
        ws: WebSocket providing ``receive_text()``.
        audio_q: Queue that receives the base64 audio frames.
        future: Future resolved with the session id.
        input_finished_event: Set when this producer stops for any reason.
    """
    logger.debug("音频数据生产函数启动")
    audio_data = ""
    try:
        while not input_finished_event.is_set():
            voice_call_data_json = json.loads(await ws.receive_text())
            try:
                if not future.done(): # the first complete frame supplies the session_id
                    future.set_result(voice_call_data_json['meta_info']['session_id'])
                if voice_call_data_json["is_close"]:
                    break
                audio_data += voice_call_data_json["audio"]
            except KeyError:
                # A frame without the expected keys is treated as a heartbeat.
                # It is skipped inside the loop; previously the handler sat
                # outside the loop, so a single heartbeat ended the producer.
                logger.info("收到心跳包")
                continue
            # ">=" so that a buffer of exactly one frame is forwarded immediately.
            while len(audio_data) >= 1280:
                vad_frame,audio_data = vad_preprocess(audio_data)
                await audio_q.put(vad_frame) # store the frame in the audio queue
    finally:
        # Also reached when receive_text() raises (e.g. on disconnect), so the
        # downstream stages are always told that no more input will arrive.
        input_finished_event.set()

async def voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event):
    """Run VAD and streaming ASR over queued audio frames.

    Speech frames are passed to ``asr.streaming_recognize`` and their text is
    appended to the current utterance. Once the silence counter reaches 25 the
    recogniser is finalised and a non-empty utterance is put on
    ``asr_result_q``. ``asr_finished_event`` is set on exit.
    """
    logger.debug("音频数据消费者函数启动")
    vad = VAD()
    current_message = ""
    vad_count = 0
    poll_interval = 0.1  # seconds between re-checks of the exit condition
    try:
        while not (input_finished_event.is_set() and audio_q.empty()):
            try:
                # A bounded wait: an unbounded get() would block forever when
                # the input ends while the queue is already empty.
                audio_data = await asyncio.wait_for(audio_q.get(), timeout=poll_interval)
            except asyncio.TimeoutError:
                continue
            if vad.is_speech(audio_data):
                if vad_count > 0:
                    vad_count -= 1  # speech pays the silence counter back one step
                asr_result = asr.streaming_recognize(audio_data)
                current_message += ''.join(asr_result['text'])
            else:
                vad_count += 1
                if vad_count >= 25: # enough silent frames: the user is considered done
                    # NOTE(review): the result of this finalising call is not
                    # used; confirm it cannot carry trailing text.
                    asr_result = asr.streaming_recognize(audio_data, is_end=True)
                    if current_message:
                        logger.debug(f"检测到静默,用户输入为:{current_message}")
                        await asr_result_q.put(current_message)
                    current_message = ""
                    vad_count = 0
    finally:
        # Set even on failure so the next stage is not left waiting.
        asr_finished_event.set()

async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event):
    """Send each recognised utterance to the LLM and queue the streamed reply.

    For every text taken from ``asr_result_q`` the session history is loaded,
    the utterance is appended as a ``user`` message and the conversation is
    posted to ``Config.MINIMAX_LLM.URL`` with ``stream=True``. Each non-empty
    response line becomes ``{'message': <delta>, 'is_end': False}`` on
    ``llm_response_q``; ``{'message': "", 'is_end': True}`` closes a
    successful reply. ``llm_finished_event`` is set on exit.
    """
    logger.debug("asr结果消费以及llm返回生产函数启动")
    loop = asyncio.get_running_loop()
    poll_interval = 0.1  # seconds between re-checks of the exit condition
    headers = {
        'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
        'Content-Type': 'application/json'
    }
    try:
        while not (asr_finished_event.is_set() and asr_result_q.empty()):
            try:
                # A bounded wait: an unbounded get() would block forever when
                # ASR finishes while the queue is already empty.
                current_message = await asyncio.wait_for(asr_result_q.get(), timeout=poll_interval)
            except asyncio.TimeoutError:
                continue
            # Load the history only once an utterance is available, so that it
            # includes whatever was stored while this coroutine was waiting.
            session_content = get_session_content(session_id,redis,db)
            messages = json.loads(session_content["messages"])
            messages.append({'role': 'user', "content": current_message})
            # Persist the user turn; otherwise only assistant replies are ever
            # written back and later requests lose what the user said.
            session_content["messages"] = json.dumps(messages,ensure_ascii=False)
            redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
            payload = json.dumps({
                "model": llm_info["model"],
                "stream": True,
                "messages": messages,
                "max_tokens":10000,
                "temperature": llm_info["temperature"],
                "top_p": llm_info["top_p"]
            })
            # requests is blocking: run it in the default executor so the event
            # loop keeps receiving audio and sending speech meanwhile.
            response = await loop.run_in_executor(
                None,
                lambda: requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True),
            )
            try:
                if response.status_code == 200:
                    lines = response.iter_lines()
                    while True:
                        # next() with a None default: a StopIteration must not
                        # escape into the executor's future.
                        chunk = await loop.run_in_executor(None, next, lines, None)
                        if chunk is None:
                            break
                        if chunk:
                            chunk_data = parseChunkDelta(chunk)
                            llm_frame = {'message':chunk_data,'is_end':False}
                            await llm_response_q.put(llm_frame)
                    llm_frame = {'message':"",'is_end':True}
                    await llm_response_q.put(llm_frame)
                else:
                    logger.error("LLM request failed with status %s", response.status_code)
            finally:
                response.close()
    finally:
        # Set even on failure so the next stage is not left waiting.
        llm_finished_event.set()

async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event):
    """Split streamed LLM text into sentences and store each finished reply.

    Every frame from ``llm_response_q`` is run through
    ``split_string_with_punctuation``; completed sentences go to
    ``split_result_q``. On an ``is_end`` frame the full reply is appended to
    the session's messages as an ``assistant`` turn and written back to Redis.
    ``split_finished_event`` is set on exit.
    """
    logger.debug("llm结果返回函数启动")
    poll_interval = 0.1  # seconds between re-checks of the exit condition
    llm_response = ""
    current_sentence = ""
    is_first = True
    try:
        while not (llm_finished_event.is_set() and llm_response_q.empty()):
            try:
                # A bounded wait: an unbounded get() would block forever when
                # the LLM stage finishes while the queue is already empty.
                llm_frame = await asyncio.wait_for(llm_response_q.get(), timeout=poll_interval)
            except asyncio.TimeoutError:
                continue
            llm_response += llm_frame['message']
            sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first)
            for sentence in sentences:
                await split_result_q.put(sentence)
            if llm_frame['is_end']:
                # Text after the last punctuation mark was previously thrown
                # away here and therefore never spoken; forward it as well.
                if current_sentence.strip():
                    await split_result_q.put(current_sentence)
                is_first = True
                session_content = get_session_content(session_id,redis,db)
                messages = json.loads(session_content["messages"])
                messages.append({'role': 'assistant', "content": llm_response})
                session_content["messages"] = json.dumps(messages,ensure_ascii=False) # update the conversation
                redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) # update the session
                logger.debug(f"llm返回结果: {llm_response}")
                llm_response = ""
                current_sentence = ""
    finally:
        # Set even on failure so the next stage is not left waiting.
        split_finished_event.set()
        
async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
    """Synthesise each queued sentence and send audio plus text to the client.

    For every sentence from ``split_result_q`` the audio returned by
    ``tts.synthesize`` is sent as a binary frame, followed by a JSON text
    frame ``{"type": "text", "code": 200, "msg": <sentence>}``. When the
    splitter has finished and the queue is drained, the socket is closed.
    ``voice_call_end_event`` is set on exit.
    """
    logger.debug("语音合成及返回函数启动")
    poll_interval = 0.1  # seconds between re-checks of the exit condition
    try:
        while not (split_finished_event.is_set() and split_result_q.empty()):
            try:
                # A bounded wait: an unbounded get() would block forever when
                # splitting finishes while the queue is already empty.
                sentence = await asyncio.wait_for(split_result_q.get(), timeout=poll_interval)
            except asyncio.TimeoutError:
                continue
            _sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
            text_response = {"type": "text", "code": 200, "msg": sentence}
            await ws.send_bytes(audio) # binary audio stream
            await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # text frame
            logger.debug(f"websocket返回:{sentence}")
        # The sleep was previously created but never awaited, so no pause happened.
        await asyncio.sleep(0.5)
        await ws.close()
    finally:
        # Set even on failure so the call handler can finish.
        voice_call_end_event.set()


async def voice_call_handler(ws, db, redis):
    """Wire up and supervise the stages of one voice call.

    Creates the queues and events that link the stages, starts the audio
    producer and consumer, waits for the session id from the first client
    frame, loads ``tts_info`` and ``llm_info`` from the session, starts the
    LLM, splitting and TTS stages, and waits until the call has ended.
    """
    logger.debug("voice_call websocket 连接建立")
    audio_q = asyncio.Queue()  # audio frames
    asr_result_q = asyncio.Queue()  # speech-recognition results
    llm_response_q = asyncio.Queue()  # LLM response frames
    split_result_q = asyncio.Queue()  # sentences produced by the splitter

    input_finished_event = asyncio.Event()  # user input has ended
    asr_finished_event = asyncio.Event()  # speech recognition has ended
    llm_finished_event = asyncio.Event()  # LLM stage has ended
    split_finished_event = asyncio.Event()  # sentence splitting has ended
    voice_call_end_event = asyncio.Event()  # the voice call is over

    future = asyncio.get_running_loop().create_future()  # carries the session_id
    # Keep references to the tasks: the event loop itself only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    tasks = [
        asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,input_finished_event)), # audio producer
        asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,input_finished_event,asr_finished_event)), # audio consumer
    ]
    try:
        session_id = await future # session_id from the first client frame
        # Load the session once and read both settings from it.
        session_content = get_session_content(session_id,redis,db)
        tts_info = json.loads(session_content["tts_info"])
        llm_info = json.loads(session_content["llm_info"])

        tasks.append(asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,asr_finished_event,llm_finished_event))) # LLM stage
        tasks.append(asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,llm_finished_event,split_finished_event))) # sentence splitting
        tasks.append(asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event))) # TTS and reply

        # Wake up as soon as the call ends instead of polling every 3 seconds.
        await voice_call_end_event.wait()
    finally:
        # Stop whatever is still running (e.g. after an error above) and
        # collect the results so no task exception goes unretrieved.
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        try:
            await ws.close()
        except RuntimeError:
            # The TTS stage normally closes the socket already; Starlette
            # reports a second close as a RuntimeError.
            pass
        logger.debug("voice_call websocket 连接断开")
  • 6
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值