Mini-Omni's contribution is giving any large language model the ability to "listen" and "speak".
✅ Real-time speech-to-speech conversation, with no separate ASR or TTS models required.
✅ Thinks while it talks: text and audio are generated simultaneously.
✅ Streaming audio output.
✅ "Audio-to-text" and "audio-to-audio" batch inference to further boost performance.
How to run it:
Step 1: Create the omni virtual environment, clone the code, and install the dependencies:
conda create -n omni python=3.10
conda activate omni
git clone https://github.com/gpt-omni/mini-omni.git
cd mini-omni
pip install -r requirements.txt
Step 2: Point at a download mirror and download the model into the ./checkpoint directory:
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download --resume-download gpt-omni/mini-omni --local-dir ./checkpoint
Step 3: The code defaults to the GPU ('cuda:0'). A Mac M1 has no CUDA GPU; PyTorch's default accelerator there is mps, but the Whisper model does not run on mps, so we run on the CPU instead. Change the 'cuda:0' value to 'cpu' in server.py, inference.py, and webui/omni_gradio.py.
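If you prefer not to hard-code the device string, a small helper can choose it at startup. This is only a sketch of the idea (pick_device is not part of the upstream code); it deliberately skips mps because of the Whisper limitation described above:

import torch

def pick_device() -> str:
    # prefer CUDA when a GPU is present, otherwise fall back to the CPU
    # (mps is skipped on purpose: the Whisper encoder does not run on it)
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"

device = pick_device()  # pass this wherever the literal 'cuda:0' appeared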
Step 4: Install ffmpeg
# install ffmpeg
brew install ffmpeg
# check the version
ffmpeg -version
Step 5: Start the server:
python3 server.py --ip '0.0.0.0' --port 60808
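Expect the first start to take a while: as the server.py listing under Step 7 shows, the server loads the checkpoints and runs OmniInference.warm_up() before it begins serving requests on port 60808.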
Step 6: Start the gradio client:
API_URL=http://0.0.0.0:60808/chat python3 webui/omni_gradio.py
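Besides the gradio UI, you can sanity-check the /chat endpoint (defined in server.py under Step 7) directly from Python. A minimal sketch, assuming a short recorded question saved as question.wav (the file names here are placeholders):

import base64
import requests

API_URL = "http://0.0.0.0:60808/chat"

# chat() expects a base64-encoded WAV file in the "audio" field of the JSON body
with open("question.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

# the reply is streamed back as audio/wav, so write the chunks out as they arrive
with requests.post(API_URL, json={"audio": audio_b64}, stream=True) as resp:
    resp.raise_for_status()
    with open("answer.wav", "wb") as out:
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                out.write(chunk)
print("saved answer.wav")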
Step 7: Modified code
1. server.py
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from inference import OmniInference
import flask
import base64
import tempfile
import traceback
from flask import Flask, Response, stream_with_context
class OmniChatServer(object):
    def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
                 ckpt_dir='./checkpoint', device='cpu') -> None:
        server = Flask(__name__)
        # CORS(server, resources=r"/*")
        # server.config["JSON_AS_ASCII"] = False

        self.client = OmniInference(ckpt_dir, device)
        self.client.warm_up()

        server.route("/chat", methods=["POST"])(self.chat)

        if run_app:
            server.run(host=ip, port=port, threaded=False)
        else:
            self.server = server

    def chat(self) -> Response:
        req_data = flask.request.get_json()
        try:
            data_buf = req_data["audio"].encode("utf-8")
            data_buf = base64.b64decode(data_buf)
            stream_stride = req_data.get("stream_stride", 4)
            max_tokens = req_data.get("max_tokens", 2048)

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                f.write(data_buf)
                audio_generator = self.client.run_AT_batch_stream(f.name, stream_stride, max_tokens)
                return Response(stream_with_context(audio_generator), mimetype="audio/wav")
        except Exception as e:
            print(traceback.format_exc())


# CUDA_VISIBLE_DEVICES=1 gunicorn -w 2 -b 0.0.0.0:60808 'server:create_app()'
def create_app():
    server = OmniChatServer(run_app=False)
    return server.server


def serve(ip='0.0.0.0', port=60808, device='cpu'):
    OmniChatServer(ip, port=port, run_app=True, device=device)


if __name__ == "__main__":
    import fire
    fire.Fire(serve)
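Note on the structure: the create_app() factory is what the commented gunicorn command above expects (it returns the Flask app without starting it), while the fire-based serve() entry point is what python3 server.py uses for the local CPU setup described in Step 5.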
2. inference.py
import os
import lightning as L
import torch
import time
from snac import SNAC
from litgpt import Tokenizer
from litgpt.utils import (
    num_parameters,
)
from litgpt.generate.base import (
    generate_AA,
    generate_ASR,
    generate_TA,
    generate_TT,
    generate_AT,
    generate_TA_BATCH,
    next_token_batch
)
import soundfile as sf
from litgpt.model import GPT, Config
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
from utils.snac_utils import get_snac, generate_audio_data
import whisper
from tqdm import tqdm
from huggingface_hub import snapshot_download
torch.set_printoptions(sci_mode=False)
# TODO
text_vocabsize = 151936
text_specialtokens = 64
audio_vocabsize = 4096
audio_specialtokens = 64
padded_text_vocabsize = text_vocabsize + text_specialtokens
padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
_eot = text_vocabsize
_pad_t = text_vocabsize + 1
_input_t = text_vocabsize + 2
_answer_t = text_vocabsize + 3
_asr = text_vocabsize + 4
_eoa = audio_vocabsize
_pad_a = audio_vocabsize + 1
_input_a = audio_vocabsize + 2
_answer_a = audio_vocabsize + 3
_split = audio_vocabsize + 4
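# Note: layershift() (imported from utils.snac_utils) appears to shift an audio token id
# into the i-th SNAC codebook's own slice of the combined vocabulary, roughly
# id + padded_text_vocabsize + i * padded_audio_vocabsize, so the seven audio
# streams and the text stream built below never collide.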
def get_input_ids_TA(text, text_tokenizer):
    input_ids_item = [[] for _ in range(8)]
    text_tokens = text_tokenizer.encode(text)
    for i in range(7):
        input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
            layershift(_answer_a, i)
        ]
        input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
    input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
    input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
    return input_ids_item
def get_input_ids_TT(text, text_tokenizer):
    input_ids_item = [[] for i in range(8)]
    text_tokens = text_tokenizer.encode(text).tolist()
    for i in range(7):
        input_ids_item[i] = torch.tensor(
            [layershift(_pad_a, i)] * (len(text_tokens) + 3)
        ).unsqueeze(0)
    input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
    input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
    return input_ids_item
def get_input_ids_whisper(
    mel, leng, whispermodel, device,
    special_token_a=_answer_a, special_token_t=_answer_t,
):
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]

    T = audio_feature.size(0)
    input_ids = []
    for i in range(7):
        input_ids_item = []
        input_ids_item.append(layershift(_input_a, i))
        input_ids_item += [layershift(_pad_a, i)] * T
        input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
    input_ids.append(input_id_T.unsqueeze(0))
    return audio_feature.unsqueeze(0), input_ids
def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]

    T = audio_feature.size(0)
    input_ids_AA = []
    for i in range(7):
        input_ids_item = []
        input_ids_item.append(layershift(_input_a, i))
        input_ids_item += [layershift(_pad_a, i)] * T