LangChain is an LLM application development framework that integrates multiple modules. Most traditional tutorials are built around OpenAI's ChatGPT, which is inconvenient to access from within China, so this article uses the domestic LLM ChatGLM instead. This makes it easier for users in China to learn and develop with LangChain, and helps move LLMs toward real applications.
Environment:
- Development language: Python 3.9
- Development platform: PyCharm
- GPU: RTX 3090
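The packages used in this article can be installed with pip, optionally through the Tsinghua mirror (versions are not pinned here, and a CUDA-enabled torch build is needed for the .cuda() call in the server code):

pip install flask transformers requests langchain -i https://pypi.tuna.tsinghua.edu.cn/simple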
Invoking the Service
The model is exposed as a Flask HTTP service; the server-side code is shown below.
import json
from flask import Flask
from flask import request
from transformers import AutoTokenizer, AutoModel
from wsgiref.simple_server import make_server

# This program starts the model server.
# Persona preset: injected into the model as prior conversation history.
character = []

# Load the model. Use the absolute path of the model directory here; the
# ChatGLM2-6B weights can be downloaded from https://huggingface.co/
tokenizer = AutoTokenizer.from_pretrained("./model/ChatGLM2-6B",
                                          trust_remote_code=True)
model = AutoModel.from_pretrained("./model/ChatGLM2-6B",
                                  trust_remote_code=True).cuda()
model.eval()

# Start the server.
app = Flask(__name__)

@app.route("/", methods=["POST", "GET"])
def root():
    """root
    """
    return "Welcome to ChatGLM2-6B model."

@app.route("/chat", methods=["POST"])
def chat():
    """chat
    """
    data_seq = request.get_data()
    data_dict = json.loads(data_seq)
    human_input = data_dict["human_input"]
    response, history = model.chat(tokenizer, human_input, history=character)
    result_dict = {
        "response": response
    }
    # Log the conversation so far in a readable form.
    prompt = ""
    for question, answer in history:
        prompt += f"Question: {question}\n\n"
        prompt += f"Answer: {answer}\n\n"
    print(prompt)
    # Serialize the result as JSON.
    result_seq = json.dumps(result_dict, ensure_ascii=False)
    return result_seq

# Entry point
if __name__ == "__main__":
    server = make_server('127.0.0.1', 8595, app)
    server.serve_forever()
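Once the server is running, you can smoke-test the /chat endpoint from another process. A minimal sketch (the port 8595 and the human_input field match the server code above):

import json
import requests

# Quick smoke test for the /chat endpoint started above.
resp = requests.post("http://127.0.0.1:8595/chat",
                     json={"human_input": "你好"},
                     timeout=60)
print(json.loads(resp.text)["response"])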
Basic Usage
import logging
import requests
from typing import Optional, List, Dict, Mapping, Any

import langchain
from langchain.llms.base import LLM
from langchain.cache import InMemoryCache

# To install a missing package, use: pip install <package> -i https://pypi.tuna.tsinghua.edu.cn/simple

class ChatGLM(LLM):
    # URL of the model service
    url = "http://127.0.0.1:8595/chat"

    @property
    def _llm_type(self) -> str:
        return "ChatGLM2-6B"

    def _construct_query(self, prompt: str) -> Dict:
        """Build the request body.
        """
        query = {
            "human_input": prompt
        }
        return query

    @classmethod
    def _post(cls, url: str,
              query: Dict) -> Any:
        """Send a POST request.
        """
        _headers = {"Content-Type": "application/json"}
        with requests.session() as sess:
            resp = sess.post(url,
                             json=query,
                             headers=_headers,
                             timeout=60)
        return resp

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        """_call
        :param **kwargs:
        """
        # Build the request body.
        query = self._construct_query(prompt=prompt)
        # Send it to the model service.
        resp = self._post(url=self.url,
                          query=query)
        if resp.status_code == 200:
            resp_json = resp.json()
            predictions = resp_json["response"]
            return predictions
        else:
            return "Model request failed."

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.
        """
        _param_dict = {
            "url": self.url
        }
        return _param_dict

if __name__ == "__main__":  # Simple interactive test of the model
    question_list = []
    response_list = []
    logging.basicConfig(level=logging.INFO)
    # Enable the LLM cache so repeated prompts are served from memory.
    langchain.llm_cache = InMemoryCache()
    llm = ChatGLM()
    while True:
        human_input = input("User: ")
        # Typing "stop" ends the conversation.
        if human_input.strip() == "stop":
            break
        question_list.append({'question': human_input})
        response = llm(human_input)
        print(f"小U: {response}")
        response_list.append({'Answer': response})
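With ChatGLM wrapped as a LangChain LLM, it can be dropped into the rest of the framework like any other model. A minimal sketch, assuming the same legacy langchain version as above (where PromptTemplate and LLMChain are importable from the top-level package):

from langchain import PromptTemplate, LLMChain

# Build a simple chain on top of the ChatGLM wrapper defined above.
template = "Explain the following concept in one sentence: {concept}"
prompt = PromptTemplate(template=template, input_variables=["concept"])
chain = LLMChain(llm=ChatGLM(), prompt=prompt)
print(chain.run(concept="LangChain"))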