Unity接入大模型(小羊驼Vicuna,vLLM,ChatGPT等)

实现在Unity内部的大模型访问,我也是第一次接触Unity中通过大模型url访问。此博客面向新手,旨在给大家简单理解大模型POST和GET过程,还有实现简单的大模型访问。

参考博客:什么是chatGPT?Unity结合OpenAI官方api实现类似chatGPT的AI聊天机器人

附带源码地址:OpenAIChatRobotMaster: 使用unity实现的基于OpenAI官方api的AI聊天机器人示例

参考的博客主要用于访问CHATGPT,但是我目前需求是访问自己的大模型URL,其中碰到的问题以及代码的一些详细解读。

————————————

半年后补上:新的接入方式,从0开始。

目录

一、效果展示

二、源码POST解读

三、源码修改

四、代码评价

五、整个代码


一、效果展示

 UI改了一下,具体效果如上,这是使用了我们实验室自己部署的vicuna-13b大模型

二、源码POST解读

首先原博客的模型是chatgpt的:text-davinci-003模型,模型的请求体和响应体如下:

 所以原博客的unity代码中有这样的封装:

[System.Serializable]public class PostData{
    public string model;
    public string prompt;
    public int max_tokens; 
    public float temperature;
    public int top_p;
    public float frequency_penalty;
    public float presence_penalty;
    public string stop;
}
[System.Serializable]public class TextCallback{
    public string id;
    public string created;
    public string model;
    public List<TextSample> choices;
    [System.Serializable]public class TextSample{
        public string text;
        public string index;
        public string finish_reason;
    }
}

代码逻辑我做了一个图,大家可以看看,可以方便理解源代码:
 

逻辑概述:输入框的文本信息_msg一方面渲染到了聊天框m_PostChatPrefab,另一方面被封装到了_postData类里的prompt中。将json信息传递到LLM中,最后返回的_msg转json格式得到_textback格式。我们需要得到的是其中的choices[0]。将得到的choices[0]渲染到聊天框m_RobotChatPrefab中。

三、源码修改

具体修改部分主要是请求体和响应体的格式。源码中post过程没有错误答应部分,我自己中途打印了一些中间变量,方便查错:

本次以vicuna-13b大模型为例,官方文档没有请求体和响应体的格式,所以通过postman来查看的格式,如下:

 参照上面的内容修改代码如下:

[System.Serializable]public class PostData{
    public string model;
    public List<Messages> messages;
    [System.Serializable]
    public class Messages
    {
        public string role;
        public string content;
    }
    public int web;
    public int account_id;
    public int conversation_id;
    public float temperature;
    public int top_p;
    public int n;
    public int max_tokens;
    public string stop;
    public bool stream;
    public float frequency_penalty;
    public float presence_penalty;
    public string user;
}
[System.Serializable]public class TextCallback{
    public string id;
    public string created;
    public string model;
    public List<TextSample> choices;
    [System.Serializable]public class TextSample
    {
        public string index;
        public Messages message;
        [System.Serializable]
        public class Messages
        {
            public string role;
            public string content;
        }
        public string finish_reason;
    }
}

其中的_postData传参代码修改如下:

PostData _postData = new PostData{
    model = m_PostDataSetting.model,
    messages = new List<PostData.Messages>
    {
        new PostData.Messages
        {
            role = "user",
            content = _postWord
        }
    },
    web = m_PostDataSetting.web,
    account_id = m_PostDataSetting.account_id,
    conversation_id = m_PostDataSetting.conversation_id,
    temperature = m_PostDataSetting.temperature,
    top_p = m_PostDataSetting.top_p,
    n = m_PostDataSetting.n,
    max_tokens = m_PostDataSetting.max_tokens,
    stop = "string",
    stream= false,
    frequency_penalty = m_PostDataSetting.frequency_penalty,
    presence_penalty = m_PostDataSetting.presence_penalty,
    user = "string"
};

心得:

其中报错内容有1.打印出的_jsonText没有message

2.TextCallback.TextSample.Messages定义不正确

原因:

string _jsonText = JsonUtility.ToJson (_postData);中:

JsonUtility.ToJson()方法具有一些限制,它只能序列化Unity支持的类型,并且不能序列化嵌套的自定义类型(如Messages类)。

所以在新加的public class Messages前面添加[System.Serializable],就可以序列化了。

四、代码评价

代码方面通俗易懂,但是因为后续工作需求,代码还有许多功能需要增加:

1.此代码将每一次的对话都直接渲染到聊天框中,没有在内部进行存储,导致的结果就是不能进行多轮对话,后续我将朝这个方向进行改进。

2.此对话响应方式是响应全部结束后才渲染出来,不能一个字一个字的流式响应。

针对这些需求,后续会对代码进行修改。

五、整个代码

    private string m_ApiUrl = "http://*********/completion";


    //配置参数,用于存储聊天界面的一些配置信息。
    [SerializeField] private PostData m_PostDataSetting;
    //输入的信息,用于获取用户输入的聊天内容。
    [SerializeField] private InputField m_InputWord;
    //聊天文本放置的层,用于存储聊天气泡的位置信息。
    [SerializeField] private RectTransform m_rootTrans;
    //发送聊天气泡,用于存储用户发送的聊天气泡的预制体。
    [SerializeField] private ChatPrefab m_PostChatPrefab;
    //回复的聊天气泡,用于存储机器人回复的聊天气泡的预制体。
    [SerializeField] private ChatPrefab m_RobotChatPrefab;
    //滚动条,用于控制聊天界面的滚动。
    [SerializeField] private ScrollRect m_ScroTectObject;
    /// <summary>
    /// 发送信息UI
    /// </summary>
    public void SendData()
    {
        if (m_InputWord.text.Equals(""))
            return;
        //将输入框中的文本作为消息进行处理,
        string _msg = m_InputWord.text;
        // 以m_PostChatPrefab预制体为模板,生成聊天记录
        ChatPrefab _chat = Instantiate(m_PostChatPrefab, m_rootTrans.transform);
        _chat.SetText(_msg);
        //重新计算容器尺寸
        LayoutRebuilder.ForceRebuildLayoutImmediate(m_rootTrans);
        //使用协程(TurnToLastLine())确保聊天框始终显示最新的聊天记录。
        StartCoroutine(TurnToLastLine());
        //获取发送的数据并将其传递给回调函数(CallBack)。
        StartCoroutine(GetPostData(_msg, CallBack));
        //清空输入框文本
        m_InputWord.text = "";
    }
    /// <summary>
    /// AI回复的信息UI
    /// </summary>
    /// <param name="_callback"></param>
    private void CallBack(string _callback)
    {
        //去除字符串两侧的空格
        _callback = _callback.Trim();
        //将该字符串传递给 ChatPrefab 类的实例变量 _chat。
        ChatPrefab _chat = Instantiate(m_RobotChatPrefab, m_rootTrans.transform);
        _chat.SetText(_callback);
        //重新计算容器尺寸
        LayoutRebuilder.ForceRebuildLayoutImmediate(m_rootTrans);
        //将页面滚动到最后一行,
        StartCoroutine(TurnToLastLine());
    }

    /// <summary>
    ///UI协程函数, 将文本框滚动到最后一行消息的位置。
    /// </summary>
    /// <returns></returns>
    private IEnumerator TurnToLastLine()
    {
        yield return new WaitForEndOfFrame();
        //滚动到最近的消息
        m_ScroTectObject.verticalNormalizedPosition = 0;
    }

    /// <summary>
    /// 设置AI模型类型model
    /// </summary>
    /// <param name="_modelType"></param>
    public void SetAIModel(Toggle _modelType)
    {
        if (_modelType.isOn)
        {
            m_PostDataSetting.model = _modelType.name;
        }
    }
    //---------------------------------------------------------------------------------------------------------------
    /// <summary>
    /// 用于存储向AI模型发送的参数数据。
    /// </summary>
    [System.Serializable]
    public class PostData
    {
        public string model;
        public List<Messages> messages;

        [System.Serializable]
        public class Messages
        {
            public string role;
            public string content;
        }
        public int web;
        public int account_id;
        public int conversation_id;
        public float temperature;
        public int top_p;
        public int n;
        public int max_tokens;
        public string stop;
        public bool stream;
        public float frequency_penalty;
        public float presence_penalty;
        //public string user;
    }

    /// <summary>
    /// 向AI模型发送数据
    /// </summary>
    /// <param name="_postWord"></param>
    /// <param name="_callback"></param>
    /// <returns></returns>
	private IEnumerator GetPostData(string _postWord, System.Action<string> _callback)
    {
        //UnityWebRequest发送POST请求,接口:m_ApiUrl
        var request = new UnityWebRequest(m_ApiUrl, "POST");
        PostData _postData = new PostData
        {
            model = m_PostDataSetting.model,
            messages = new List<PostData.Messages>{
                new PostData.Messages
                {
                    role = "user",
                    content = _postWord
                }
            },
            web = m_PostDataSetting.web,
            account_id = m_PostDataSetting.account_id,
            conversation_id = m_PostDataSetting.conversation_id,
            temperature = m_PostDataSetting.temperature,
            top_p = m_PostDataSetting.top_p,
            n = m_PostDataSetting.n,
            max_tokens = m_PostDataSetting.max_tokens,
            stop = m_PostDataSetting.stop,
            stream = false,
            frequency_penalty = m_PostDataSetting.frequency_penalty,
            presence_penalty = m_PostDataSetting.presence_penalty,
            //user = m_PostDataSetting.user
        };
        //将 _postData 转换成 JSON 格式的字符串
        string _jsonText = JsonUtility.ToJson(_postData);
        byte[] data = System.Text.Encoding.UTF8.GetBytes(_jsonText);
        //请求数据data上传
        request.uploadHandler = (UploadHandler)new UploadHandlerRaw(data);
        //设置请求的下载处理器DownloadHandlerBuffer,,,返回的数据存储在缓存区
        request.downloadHandler = (DownloadHandler)new DownloadHandlerBuffer();
        //设置请求头。。告诉服务器上传的数据为 JSON 格式。
        request.SetRequestHeader("Content-Type", "application/json");
        //request.SetRequestHeader("Authorization",string.Format("Bearer {0}",m_OpenAI_Key));
        //异步发送请求并等待响应
        yield return request.SendWebRequest();
        Debug.Log("Response Code: " + request.responseCode);

        if (request.responseCode == 200)
        {
            string _msg = request.downloadHandler.text;
            Debug.Log(" _msg: " + _msg);
            //将_msg转化为TextCallback数据结构
            if (!string.IsNullOrEmpty(_msg))
            {
                TextCallback _textback = JsonUtility.FromJson<TextCallback>(_msg);
                Debug.Log("_textback: " + _textback);
                Debug.Log("_textback.choices[0]: " + _textback.choices[0]);
                Debug.Log("_textback.choices[0].message: " + _textback.choices[0].message);
                Debug.Log("_textback.choices[0].message.content: " + _textback.choices[0].message.content);
                if (_textback != null && _textback.choices != null && _textback.choices.Count > 0)
                {
                    _callback(_textback.choices[0].message.content);
                }
                else { Debug.LogError("Request Error: Invalid response data."); }
            }
            else
            { Debug.LogError("Request Error: Empty response data."); }
        }
        else
        {
            Debug.LogError("Request Error: " + request.responseCode);
        }
    }
    /// <summary>
    /// 用于退出应用程序,
    /// </summary>
    public void Quit()
    { Application.Quit(); }
    void Update()
    {
        if (Input.GetKeyDown(KeyCode.Escape))
        { Application.Quit(); }
        if (Input.GetKeyDown(KeyCode.Return))
        { SendData(); }
    }
    [System.Serializable]
    public class TextCallback
    {
        public string id;
        public string created;
        public string model;
        public List<TextSample> choices;
        [System.Serializable]
        public class TextSample
        {
            public string index;
            public Messages message;
            [System.Serializable]
            public class Messages
            {
                public string role;
                public string content;
            }
            public string finish_reason;
        }
    }

  • 19
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值