继承实现自己的 langchain.llms.base.LLM

替换langchain chatGPT实现自己的大模型接口

继承 langchain.llms.base.LLM

要实现自己的langchain.llms.base.LLM,可以按以下步骤进行:

  1. 选择继承: 继承langchain.llms.base.LLM类,这是LangChain框架中用于定义大语言模型(LLM)的基本类。
  2. 重写方法: 重写_call方法,这是主要的处理方法,用于对传递的提示问题进行分析,并返回相应的输出。你可以在这个方法中定义你的自定义逻辑。
  3. 实现其他必要方法: 根据需要,实现其他必要的方法。例如,根据具体情况,可能需要实现__init__方法来初始化模型。
  4. 编写自定义逻辑:_call方法中编写你的自定义逻辑,根据输入进行处理并返回相应的输出。这是你可以发挥创造力的地方,根据任务需求定义模型的行为。

下面是一个封装讯飞API简单的例子

#!/usr/bin/env python 
# -*- coding: utf-8 -*-
# @Time    : 2023/9/25 18:20
# @Author  : *****
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import logging
import ssl
from datetime import datetime
from time import mktime
from typing import Optional, List, Mapping, Any
from urllib.parse import urlencode
from urllib.parse import urlparse
from wsgiref.handlers import format_date_time

import langchain
import websocket
from langchain.cache import InMemoryCache
from langchain.llms.base import LLM

logging.basicConfig(level=logging.INFO)
# 启动llm的缓存
langchain.llm_cache = InMemoryCache()
result_list = []
SPARK_APPID = "0267a98f"
SPARK_API_SECRET = "xxxxxxxxxxxxxxx"
SPARK_API_KEY = "xxxxxxxxxxxxxxxxxxx"
gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"  # spark官方模型提供api接口
domain = "general"  # v1.5版本
gpt_url = "ws://spark-api.xf-yun.com/v2.1/chat"  # spark官方模型提供api接口
domain = "generalv2"  # v2.0版本

answer = ""


class Ws_Param(object):
    # 初始化
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url
    
    # 生成url
    def create_url(self):
        # 生成RFC1123格式的时间戳
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        
        # 拼接字符串
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"
        
        # 进行hmac-sha256进行加密
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
        
        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
        
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        
        # 将请求的鉴权参数组合为字典
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
            }
        # 拼接鉴权参数,生成url
        url = self.Spark_url + '?' + urlencode(v)
        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
        return url


# 收到websocket错误的处理
def on_error(ws, error):
    print("### error:", error)


# 收到websocket关闭的处理
def on_close(ws, one, two):
    print(" ")


# 收到websocket连接建立的处理
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question))
    ws.send(data)


# 收到websocket消息的处理
def on_message(ws, message):
    # print(message)
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        # print(content, end="")
        global answer
        answer += content
        # print(1)
        if status == 2:
            ws.close()


def gen_params(appid, domain, question):
    """
    通过appid和用户的提问来生成请参数
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
            },
        "parameter": {
            "chat": {
                "domain": domain,
                "random_threshold": 0.5,
                "max_tokens": 4096,
                "auditing": "default"
                }
            },
        "payload": {
            "message": {
                "text": question
                }
            }
        }
    return data




class Spark(LLM):
    '''
    根据源码解析在通过LLMS包装的时候主要重构两个部分的代码
    _call 模型调用主要逻辑,输入问题,输出模型相应结果
    _identifying_params 返回模型描述信息,通常返回一个字典,字典中包括模型的主要参数
    '''
    host = urlparse(gpt_url).netloc  # host目标机器解析
    path = urlparse(gpt_url).path  # 路径目标解析
    max_tokens = 4096
    temperature = 0.5
    
    # ws = websocket.WebSocketApp(url='')
    
    @property
    def _llm_type(self) -> str:
        # 模型简介
        return "Spark"
    
    def _post(self, prompt):
        # 模型请求响应
        wsParam = Ws_Param(SPARK_APPID, SPARK_API_KEY, SPARK_API_SECRET, gpt_url)
        websocket.enableTrace(False)
        wsUrl = wsParam.create_url()
        ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
        ws.appid = SPARK_APPID
        content = [{'content': prompt, 'role': 'user'}]
        ws.question = content
        ws.domain = domain
        # setattr(ws, "temperature", self.temperature)
        # setattr(ws, "max_tokens", self.max_tokens)
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
        return answer
    
    def _call(self, prompt: str,
              stop: Optional[List[str]] = None) -> str:
        # 启动关键的函数
        content = self._post(prompt)
        # content = "这是一个测试"
        return content
    
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """
        Get the identifying parameters.
        """
        _param_dict = {
            "url": gpt_url
            }
        return _param_dict


if __name__ == "__main__":
    llm = Spark(temperature=0.9)
    result = llm("你好啊", stop=["you"])
    print(result)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值