一起抓宝吧!- 构建一个宝可梦go问答机器人

NVIDIA AI-AGENT夏季训练营

项目名称:一起抓宝吧!- 构建一个宝可梦go问答机器人
报告日期:2024年8月18日
项目负责人:starlette

项目概述

  • 抽取游戏的官方文档,并且根据文档来回答一些用户的问题。
  • 添加了tts功能。

破壳梦·走是一个非常受欢迎的游戏。老少皆宜。但是最近几年添加了太多新功能,导致玩家可能会对玩法感到困惑。官网的介绍是基于关键词来进行检索的。不但要靠人眼来检索,还需要一篇篇去check,到底哪一条知识和我们的需要的知识相关。“一起抓宝吧!”项目旨在构建一个宝可梦GO问答机器人,帮助玩家解答关于游戏的常见问题。该机器人通过抽取官方文档的内容来回答用户的问题,并集成了文本转语音(TTS)功能,提升了用户体验。

技术方案与实施步骤

模型选择(必写)

在pogo问答机器人中,我们选择了由英伟达NIM提供的llama-3-instruct。它作为一个8b量级的轻量级小模型,比较容易在商用级的设别上进行部署。并且它由英文文档训练而成,在处理英文问题上表现出了较高的效率和准确性。

RAG模型采用了text embedding+向量数据库的dense retrieve的方式。由于我的小霸王学习机显卡比较低端,所以文本embedding方面,采用了BGE1.5-en-small系列的小模型。在CPU上也可以方便推理。

向量数据库方面使用了 Chroma,其高效的检索能力支持了RAG模型的快速响应。

数据的构建(必写)

  1. 数据来源:宝可梦GO的官方网站的帮助文档。

  2. 数据处理:我们选用的bge-small的embedding的推荐文本chunk长度是512。对于宝可梦帮助文档,我们没有做更多的处理,因为大部分文章的长度都短于512个字符。

  • 优势:简化了数据预处理,提升了数据检索效率。

功能整合

  • TTS的引入

玩家可能会觉得文本太长了不想读,所以我们引入了英伟达的TTS,将回答的文本转成了语音。

实施步骤

环境搭建

  1. 构建conda环境
conda create -n chatnpc python=3.10
conda activate chatnpc
  1. 必要的库/环境安装
# for send request and webdemo
pip install openai chromadb gradio

# for text embedding
pip install sentence_transformer
  1. TTS功能(可选)

使用了英伟达全家桶中的fastpitch-hifigan-tts来进行推理。软件包安装可以参见这里

说是这么说,但是按照官网的方法安装的不是很成功。debug部分参见下文的问题和解决方案。

代码实现

数据添加

  1. 文档处理

将pogo的文档们变成json文件的格式。

    {
        "file_name": "2473-the-go-battle-league-leaderboard-1688983775.txt",
        "content": "The GO Battle League Leaderboard\n\nThe GO Battle League Leaderboard is a page on the Pok\u00e9mon GO website that displays the top 500 Trainers in the world by their GO Battle League ranking. To access the GO Battle League Leaderboard, go to \npokemongolive.com/leaderboard\nThe Leaderboard lists each top-500 Trainer\u2019s nickname, team, rating, rank, and number of battles played.\n\nThe Leaderboard reflects the previous day\u2019s ratings and is scheduled to refresh every day between approximately 8:00 p.m. GMT and 10:00 p.m. GMT. However, the page may sometimes not update due to maintenance or system issues.\nThe following Trainers will be excluded from the Leaderboard without prior notice:\nTrainers below rank 7\nTrainers with active disciplinary actions\nTrainers with inappropriate nicknames\nTrainers with child accounts\n \nIf you are listed on the Leaderboard and would like to be removed, please contact support through the Pok\u00e9mon GO app:\nFrom the Map View, open the Main Menu \nAt the top right, tap Settings \nTap on Help in the top right corner to access the help center.\nOn iOS, tap Contact Us in the upper right corner. On Android, tap this button: \nOnce you\u2019ve been removed from the Leaderboard, you will no longer be listed unless you request to be added again at a later date.",
        "title": "The GO Battle League Leaderboard"
    },
  1. 将pogo玩法知识添加进数据库。

因为宝可梦游戏玩法知识更新的并不频繁,所以可以一次性将它们全部添加进知识库,而不需要过度考虑修改等操作。这里是相关代码。

def add_documents(client: ClientAPI,
                  data: Dict[str, str],
                  embedding_function: EmbeddingFunction,
                  collection_name: str, 
                  fields_to_add: Union[str, List[str]],
                  required_document_keys: Optional[List[str]]=None,
                  use_uuid_from_file: bool=False,
                  skip_non_string_fields: bool=True) -> chromadb.Collection:
    """
    用预设的embedding_function将玩法数据转换成embedding,并存入collection。
    Args: 
        :param Dict[str, Any] data: 玩法数据
        :param EmbeddingFunction embedding_function: embedding function
        :param str collection_name: collection的名字
        :param str text_to_add_key: data里的key,用来生成embedding。
        :param Optional[List[str]] required_document_keys: 加到metadata里的key。
        :param bool use_uuid_from_file: 是否使用data里的uuid。
        :param bool skip_non_string_fields: 是否跳过非字符串的field。
    Returns:
        :return chromadb.Collection: The collection.
    """
    collection = client.get_or_create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"},
        embedding_function=embedding_function
    )
    # 处理待添加的数据:
    if isinstance(fields_to_add, str):
        fields_to_add = [fields_to_add]
    
    for item in data:
        # 检查item[field]是否是List,如果是,就text_to_add.extend(item[field])。
        text_to_add = []
        for field in fields_to_add:
            if isinstance(item[field], list):
                text_to_add.extend(item[field])
            else:
                text_to_add.append(item[field])
        
        # 构建待添加的metadata。
        if required_document_keys is not None:
            metadata = {key: item[key] for key in required_document_keys}
        else:
            if skip_non_string_fields:
                metadata = {key: item[key] for key in item if isinstance(item[key], str)}
            else:
                metadata = item
        
        # 根据fields to add生成符合对应长度的metadatas:
        # id_batch = [str(uuid.uuid4()) for _ in range(len(text_to_add))]
        metadata_batch = [metadata for _ in range(len(text_to_add))]

        # 查看数据中是否有`uuid`字段,如果有,则添加数据里对应的uuid。
        if "uuid" in item and use_uuid_from_file:
            id_batch = [item["uuid"] for _ in range(len(text_to_add))]
        else:
            id_batch = [str(uuid.uuid4()) for _ in range(len(text_to_add))]

        # add to collection:
        collection.add(
            documents=text_to_add,
            metadatas=metadata_batch,
            ids=id_batch
        )
    return collection


def main():
    database_path = "./pogo_database/"
    client = chromadb.PersistentClient(path=database_path)
    model_path = "path_to_embedding_model"
    with open("pogo_data.json", "r") as fin:
        qa_data = json.load(fin)

    embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_path,
                                                                                      device="cpu", 
                                                                                      normalize_embeddings=True)
    collection_name = "pogo-text"
    fields_to_add = "content"
    required_document_keys = ["title", "content"]

    collection = add_documents(
        client=client,
        data=qa_data,
        embedding_function=embedding_function,
        collection_name=collection_name,
        fields_to_add=fields_to_add,
        required_document_keys=required_document_keys, 
        use_uuid_from_file=False
    )

TTS请求的发送

  1. 安装TTS相关的环境

  2. 写对应的脚本,来发送请求。

def init_riva_client():
    metadata = [
        ["function-id", "0149dedb-2be8-4195-b9a0-e57e0e14f972"],
        ["authorization", "Bearer nvapi-your_nv_api_key"]
    ]
    auth = riva.client.Auth(use_ssl=True, 
                            uri="grpc.nvcf.nvidia.com:443", 
                            metadata_args=metadata)
    service = riva.client.SpeechSynthesisService(auth)
    return service

def synthesize_text(text, voice=None, language_code='en-US', sample_rate_hz=44100, quality=None):
    service = init_riva_client()

    # save params:
    filename = f"{int(time.time())}.wav"
    output = wave.open(filename, "wb")
    output.setnchannels(1)
    output.setsampwidth(2)
    output.setframerate(sample_rate_hz)

    response = service.synthesize(
        text, voice, language_code, sample_rate_hz=sample_rate_hz,
        quality=20 if quality is None else quality,
        custom_dictionary={}
    )

    output.writeframes(response.audio)
    output.close()

    return filename

gradio webUI的搭建

我们用gradio.Chatbot()来构建聊天框,并且在收到了请求之后,将文本的最后一句话用英伟达TTS来处理。

def create_ui():
    with gr.Blocks(title="Chatbot") as demo:
        gr.HTML("""<h1 align="center">复制来的</h1>""")

        with gr.Row():
            with gr.Column(scale=4):
                system_message = gr.Textbox("You are helpful assistant to help user to answer pokemon go related questions.", label="System Message", placeholder="Input the system message")
                chatbot = gr.Chatbot()
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10)
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("Submit", variant="primary")

            with gr.Column(scale=1):
                emptyBtn = gr.Button("Clear History")

        history = gr.State([])
        audio_player = gr.Audio()

        submitBtn.click(
            handle_submit, 
            [user_input, chatbot],
            [chatbot, audio_player]
        )
        submitBtn.click(lambda: "", inputs=[], outputs=[user_input])
        emptyBtn.click(lambda: ([], ""), outputs=[chatbot, history, audio_player], show_progress=True)
    
    return demo

集成与部署

该服务可以在本地部署。部署时配置必要的服务和端口,确保系统稳定运行。

python core.py

项目成果与展示

应用场景展示

在用户遇到了不理解的宝可梦go的相关问题的时候,可以询问bot,并获得对应的解答。

功能演示

实现功能:

  • 问答功能:用户可以输入问题,机器人根据官方文档给出答案。
    破壳梦·爬相关的RAG功能 - 破壳梦打架板的leaderboard上哪里看

  • 语音回答:通过TTS功能,机器人能够将回答转换为语音,方便用户收听。

TTS功能的实现

问题与解决方案

问题分析

  1. Text embedding跑不起来

由于我的电脑是小霸王学习机,曾经尝试过huggingface榜单上的stella_en_400m模型。虽然在sentence transformer里指定使用了cpu进行推理,但是仍然报错:

NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(1, 296, 16, 64) (torch.float32)
     key         : shape=(1, 296, 16, 64) (torch.float32)
     value       : shape=(1, 296, 16, 64) (torch.float32)
     attn_bias   : <class 'xformers.ops.fmha.attn_bias.BlockDiagonalMask'>
     p           : 0.0
`decoderF` is not supported because:
    device=cpu (supported: {'cuda'})
    attn_bias type is <class 'xformers.ops.fmha.attn_bias.BlockDiagonalMask'>
`flshattF@2.5.6-pt` is not supported because:
    device=cpu (supported: {'cuda'})
    dtype=torch.float32 (supported: {torch.float16, torch.bfloat16})
`cutlassF-pt` is not supported because:
    device=cpu (supported: {'cuda'})
`smallkF` is not supported because:
    max(query.shape[-1] != value.shape[-1]) > 32
    device=cpu (supported: {'cuda'})
    attn_bias type is <class 'xformers.ops.fmha.attn_bias.BlockDiagonalMask'>
    unsupported embed per head: 64

从报错的消息来看,必须使用cuda+xFormer才可以顺利推理。于是作罢,采用了一个对小霸王学习机也友好的小模型来推理。

  1. 英伟达Riva-python-client装不上。
(chatnpc) string@Workstation:~/work/nv_homework$ pip install -r https://raw.githubusercontent.com/nvidia-riva/python-clients/main/requirements.txt
WARNING: Retrying (Retry(total=4, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<pip._vendor.urllib3.connection.HTTPSConnection object at 0x7f9c29198af0>: Failed to establish a new connection: [Errno 111] Connection refused')': /nvidia-riva/python-clients/main/requirements.txt

总之就是会报错。

按照python-clients的文档来安装。其中git submodule initgit submodule update --remote --recursive这两步很重要,会拉一些传输协议相关的proto文件。

git clone https://github.com/nvidia-riva/python-clients.git
cd python-clients
git submodule init
git submodule update --remote --recursive
pip install -r requirements.txt
python3 setup.py bdist_wheel
pip install --force-reinstall dist/*.whl

项目总结与展望

项目评估

  1. 成功点
  • 高效回答用户问题,减少玩家的查询时间。

  • 成功集成TTS功能,提升用户互动体验。

  1. 不足
  • retrieve精度(best@1不一定retrieve到命中的知识)

  • 用户的问题可能和我们的pogo无关,比如用户可能问“原神怎么你了”,或者“巫师联盟怎么关服了”这样的问题。可能在回答之前需要再增加一层embedding过滤,过滤掉这种和pogo无关的问题。

  • 如果这个bot在宝可梦go community使用,用户大多会问一些和官方文档无关的问题,而更加倾向于问一些和当地游戏玩法相关的问题。比如:“我在哪里转poke stop可以获得rare candy?”或者“现在哪里有气球皮卡丘的团体战?”,可能需要引入function call的功能,接入当地的pogo地图,来进行检索。

未来方向

  1. 增加多路召回的方法

现在的retrieve的用的是text embedding,可能效果不好。。希望增加一些传统的NLP的方法,比如BM25之类的,

  1. 增强用户个性化

引入用户历史数据分析功能,根据用户的历史查询记录和偏好提供个性化的回答和建议。这可以提升用户体验和互动性。

附件与参考资料

  • 宝可梦go官方文档:https://niantic.helpshift.com/hc/en/6-pokemon-go/

  • ChromaDB文档:https://www.trychroma.com/

  • 英伟达TTS API:https://build.nvidia.com/nvidia/fastpitch-hifigan-tts/api

  • 英伟达提供的llama 8b instruct API: https://build.nvidia.com/meta/llama-3_1-8b-instruct

完整代码

(太懒了懒得传github了……)

  • core.py
import json 
from typing import List, Dict, Any

import chromadb 
from chromadb.utils import embedding_functions
from chromadb.api import ClientAPI

from openai import OpenAI
import gradio as gr

from create_ui import *

client = chromadb.PersistentClient(path="./pogo-db")
model_path = "model_path"
embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_path,
                                                                                      device="cpu", 
                                                                                      normalize_embeddings=True)

chat_client = OpenAI(
    base_url = "https://integrate.api.nvidia.com/v1",
    api_key = "nvapi-your-nv-api-key"
)

def filter_retrieved_items(retrieve_result: Dict[str, Any],
                           max_distance: float=0.4) -> List[str]:
    """
    根据max_distance过滤掉distance太大的。
    Args: 
        :param Dict[str, Any] retrieve_result: The retrieved result.
        :param float max_distance: The maximum distance allowed.
    Returns:
        :return List[str]: The filtered retrieved text.
    """
    passages = retrieve_result["metadatas"][0]
    retrieved_text = []
    for index, passage in enumerate(passages):
        if retrieve_result["distances"][0][index] < max_distance:
            # 过滤掉distance太大的
            retrieved_text.append(f'## {passage["title"]}\n{passage["content"]}\n')
    # print(retrieved_text)
    return retrieved_text


def rag_from_database(user_query: str,
                      db_client: chromadb.PersistentClient,
                      embedding_function,
                      collection_name: str=None,
                      top_n: int=1) -> List[Dict[str, Any]]:
    """
    Args: 
        :param str user_query: The user query.
        :param int top_n: The number of results to return.
    Returns:
        :return List[Dict[str, Any]]: The retrieved items. [{"content": "xxx", "distance": 0.1}, ...]
    """
    # 从数据库里检索相关的知识
    if collection_name is None:
        collection_name = "pogo-helper"
    collection = db_client.get_collection(collection_name)
    if (isinstance(user_query, str)):
        # obviously str here LOL
        user_query = [user_query]

    query_embedding = embedding_function(user_query)

    retrieve_result = collection.query(
        query_embeddings=query_embedding,
        n_results=top_n
    )
    retrieved_text = filter_retrieved_items(retrieve_result, max_distance=0.5)
    return retrieved_text


def predict(user_input: str,
            chatbot: gr.Chatbot=[]):
    retrieved_text = rag_from_database(user_input, client, embedding_function)
    system_message = "You are helpful assistant to help user to answer pokemon go related questions, please make sure your answer is shorter than 50 words."
    if len(retrieved_text) == 0:
        # 如果检索不到,就直接用chatbot回答
        completion = chat_client.chat.completions.create(
            model="meta/llama-3.1-8b-instruct",
            messages=[{"role": "system", "content": system_message}, {"role":"user","content":user_input}],
            temperature=0.2,
            top_p=0.7,
            max_tokens=1024,
            stream=False
        )
    else:
        # 构建system message,把检索到的知识加到system message里,然后用chatbot回答
        system_message += f"Retrieved Text:\n\n{''.join(retrieved_text)}"
        completion = chat_client.chat.completions.create(
            model="meta/llama-3.1-8b-instruct",
            messages=[{"role": "system", "content": system_message}, {"role":"user","content":user_input}],
            temperature=0.2,
            top_p=0.7,
            max_tokens=1024,
            stream=False
        )
    response = completion.choices[0].message.content
    chatbot.append([user_input, response])
    return chatbot


if __name__ == "__main__":
    # create gradio ui
    demo = create_ui()
    demo.launch(server_port=8080, share=False, server_name="0.0.0.0")

create_ui.py

import gradio as gr 

from core import predict
from tts import process_text


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], []


# Function to update chatbot and play audio
def handle_submit(user_input, chatbot):
    response_with_history = predict(user_input, chatbot)
    audio_path = process_text(response_with_history[-1][1])
    return chatbot, audio_path


def create_ui():
    with gr.Blocks(title="Chatbot") as demo:
        gr.HTML("""<h1 align="center">Pokemon Go Helper</h1>""")

        with gr.Row():
            with gr.Column(scale=4):
                system_message = gr.Textbox("You are helpful assistant to help user to answer pokemon go related questions.", label="System Message", placeholder="Input the system message")
                chatbot = gr.Chatbot()
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10)
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("Submit", variant="primary")

            with gr.Column(scale=1):
                emptyBtn = gr.Button("Clear History")

        history = gr.State([])
        audio_player = gr.Audio()

        submitBtn.click(
            handle_submit, 
            [user_input, chatbot],
            [chatbot, audio_player]
        )
        submitBtn.click(lambda: "", inputs=[], outputs=[user_input])
        emptyBtn.click(lambda: ([], ""), outputs=[chatbot, history, audio_player], show_progress=True)
    
    return demo
  • tts.py
import time
import wave
import json
from pathlib import Path
import io

import gradio as gr
import riva.client
from riva.client.argparse_utils import add_connection_argparse_parameters


def read_file_to_dict(file_path):
    result_dict = {}
    with open(file_path, 'r') as file:
        for line_number, line in enumerate(file, start=1):
            line = line.strip()
            try:
                key, value = line.split('  ', 1)  # Split by double space
                result_dict[str(key.strip())] = str(value.strip())
            except ValueError:
                print(f"Warning: Malformed line {line}")
                continue
    if not result_dict:
        raise ValueError("Error: No valid entries found in the file.")
    return result_dict


def init_riva_client():
    metadata = [
        ["function-id", "0149dedb-2be8-4195-b9a0-e57e0e14f972"],
        ["authorization", "Bearer nvapi-your-nv-api-key"]
    ]
    auth = riva.client.Auth(use_ssl=True, 
                            uri="grpc.nvcf.nvidia.com:443", 
                            metadata_args=metadata)
    service = riva.client.SpeechSynthesisService(auth)
    return service


def synthesize_text(text, voice=None, language_code='en-US', sample_rate_hz=44100, quality=None):
    service = init_riva_client()

    # save params:
    filename = f"{int(time.time())}.wav"
    output = wave.open(filename, "wb")
    output.setnchannels(1)
    output.setsampwidth(2)
    output.setframerate(sample_rate_hz)

    response = service.synthesize(
        text, voice, language_code, sample_rate_hz=sample_rate_hz,
        quality=20 if quality is None else quality,
        custom_dictionary={}
    )

    output.writeframes(response.audio)
    output.close()

    return filename

def process_text(text):
    audio_stream = synthesize_text(text)
    return audio_stream


# Launch the app
if __name__ == "__main__":
    # check if it's possible to generate and save audio lol
    text = "Hello, how are you doing today?"
    audio_stream = process_text(text)
  • 12
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值