OpenAI 的 embedding 是将文本映射为高维向量,默认的 ada-002 模型会将文本编码为 1536 维的向量。用户可以通过计算文本 embedding 之间的距离来衡量相似度。
embedding 的使用场景是可以根据用户提供的语料片段与 prompt 内容计算相关度,然后将最相关的语料片段作为上下文放到 prompt 中,以提高 completion 的准确率。
生成 embedding 文件
# imports
import openai # for generating embeddings
import pandas as pd # for DataFrames to store article sections and embeddings
import os
EMBEDDING_MODEL = "text-embedding-ada-002"  # OpenAI's best embeddings as of Apr 2023
BATCH_SIZE = 1000  # you can submit up to 2048 embedding inputs per request

# Read the whole corpus file as a single string; a_strings holds one entry
# per text to embed (here, just the full contents of a.txt).
with open('a.txt', 'r', encoding='utf-8') as corpus_file:
    a_strings = [corpus_file.read()]
# Collected embedding vectors, one per entry in a_strings.
embeddings = []

def em(save_path: str = "/Users/Projects/a_answer.csv") -> str:
    """Embed every string in a_strings in batches and save text+embedding to CSV.

    Parameters:
        save_path: destination CSV file (default keeps the original location).

    Returns:
        The path of the CSV that was written, so the __main__ guard below
        prints something useful instead of None.
    """
    # Credentials come from the environment instead of being hard-coded
    # (the original file had incomplete `openai.api_key =` assignments).
    openai.api_key = os.getenv("OPENAI_API_KEY")
    api_base = os.getenv("OPENAI_API_BASE")
    if api_base:  # only override the default endpoint when one is configured
        openai.api_base = api_base

    for batch_start in range(0, len(a_strings), BATCH_SIZE):
        batch_end = batch_start + BATCH_SIZE
        batch = a_strings[batch_start:batch_end]
        print(f"Batch {batch_start} to {batch_end-1}")
        response = openai.Embedding.create(model=EMBEDDING_MODEL, input=batch)
        # Double-check embeddings come back in the same order as the input.
        for i, be in enumerate(response["data"]):
            assert i == be["index"]
        embeddings.extend(e["embedding"] for e in response["data"])

    df = pd.DataFrame({"text": a_strings, "embedding": embeddings})
    df.to_csv(save_path, encoding='utf-8-sig', index=False)
    return save_path

if __name__ == '__main__':
    print(em())
import ast # for converting embeddings saved as strings back to arrays
import openai # for calling the OpenAI API
import pandas as pd # for storing text and embeddings data
import tiktoken # for counting tokens
from scipy import spatial # for calculating vector similarities for search
import os
EMBEDDING_MODEL = "text-embedding-ada-002"
GPT_MODEL = "gpt-3.5-turbo"

# Load the precomputed embeddings; the CSV stores each vector as a string,
# so parse it back into a Python list with ast.literal_eval.
df = pd.read_csv('/Users/Projects/a_answer.csv', encoding='utf-8-sig')
df['embedding'] = df['embedding'].map(ast.literal_eval)
def strings_ranked_by_relatedness(
    query: str,
    df: pd.DataFrame,
    relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),
    top_n: int = 100,
):
    """Return (strings, relatednesses), sorted from most related to least.

    Embeds *query* via the OpenAI API, scores it against each row's
    precomputed embedding with *relatedness_fn* (cosine similarity by
    default), and returns the top_n texts with their scores.
    """
    # Credentials come from the environment instead of being hard-coded
    # (the original file had incomplete `openai.api_key =` assignments).
    openai.api_key = os.getenv("OPENAI_API_KEY")
    api_base = os.getenv("OPENAI_API_BASE")
    if api_base:  # only override the default endpoint when one is configured
        openai.api_base = api_base

    query_embedding_response = openai.Embedding.create(
        model=EMBEDDING_MODEL,
        input=query,
    )
    query_embedding = query_embedding_response["data"][0]["embedding"]
    strings_and_relatednesses = [
        (row["text"], relatedness_fn(query_embedding, row["embedding"]))
        for _, row in df.iterrows()
    ]
    if not strings_and_relatednesses:
        # zip(*[]) would raise ValueError on an empty dataframe.
        return (), ()
    strings_and_relatednesses.sort(key=lambda pair: pair[1], reverse=True)
    strings, relatednesses = zip(*strings_and_relatednesses)
    return strings[:top_n], relatednesses[:top_n]
def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return how many tokens *text* encodes to under *model*'s tokenizer."""
    return len(tiktoken.encoding_for_model(model).encode(text))
def query_message(
    query: str,
    df: pd.DataFrame,
    model: str,
    token_budget: int,
) -> str:
    """Return a message for GPT, with relevant source texts pulled from a dataframe.

    Greedily appends the most related article sections until adding the next
    one would push the total token count past *token_budget*.
    """
    strings, relatednesses = strings_ranked_by_relatedness(query, df)
    # Instruction text: fixed the "use chiness." typo the model was receiving.
    introduction = 'Use the below articles to answer the subsequent question. Use Chinese. If the answer cannot be found in the articles, write "I could not find an answer."'
    question = f"\n\nQuestion: {query}"
    message = introduction
    for string in strings:
        next_article = f'\n\nmy article section:\n"""\n{string}\n"""'
        # Stop as soon as one more section would exceed the budget.
        if num_tokens(message + next_article + question, model=model) > token_budget:
            break
        message += next_article
    return message + question
def ask(
    query: str,
    df: pd.DataFrame = df,
    model: str = GPT_MODEL,
    token_budget: int = 4096 - 500,
    print_message: bool = False,
) -> str:
    """Answer *query* using GPT and a dataframe of relevant texts and embeddings.

    Parameters:
        query: the user's question.
        df: dataframe with 'text' and parsed 'embedding' columns.
        model: chat model name.
        token_budget: max tokens allowed for the assembled prompt context.
        print_message: when True, echo the assembled prompt before sending.
    """
    # Credentials come from the environment instead of being hard-coded
    # (the original file had incomplete `openai.api_key =` assignments).
    openai.api_key = os.getenv("OPENAI_API_KEY")
    api_base = os.getenv("OPENAI_API_BASE")
    if api_base:  # only override the default endpoint when one is configured
        openai.api_base = api_base

    message = query_message(query, df, model=model, token_budget=token_budget)
    if print_message:
        print(message)
    messages = [
        {"role": "system", "content": "you are chatgpt3.5"},
        {"role": "user", "content": message},
    ]
    # temperature=0 keeps answers deterministic for the same context.
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=0,
    )
    return response["choices"][0]["message"]["content"]
if __name__ == '__main__':
    # Simple REPL: keep answering questions until the user types "quit".
    while True:
        user_input = input("请输入提问(输入【quit】退出): ")
        if user_input.lower() == 'quit':
            break
        print(ask(user_input))