使用嵌入和最近邻搜索的推荐
推荐在网络上很普遍:
- 买了那个东西? 尝试这些类似的项目。
- 喜欢那本书吗? 试试这些类似的标题。
- 不是您正在寻找的帮助页面? 试试这些类似的页面。
此笔记本演示了如何使用嵌入来查找要推荐的相似项目。 特别是,我们使用 AG 的新闻文章语料库作为我们的数据集。
我们的模型将回答这个问题:给定一篇文章,还有哪些其他文章与它最相似?
1.导入
首先,让我们导入稍后需要的包和函数。 如果您没有这些,则需要安装它们。 您可以通过控制台运行pip install {package_name}
来安装,例如 pip install pandas
# imports
import pandas as pd
import pickle
from openai.embeddings_utils import (
get_embedding,
distances_from_embeddings,
tsne_components_from_embeddings,
chart_from_components,
indices_of_nearest_neighbors_from_distances,
)
# constants
EMBEDDING_MODEL = "text-embedding-ada-002"
2.加载数据
接下来我们加载AG新闻数据,看看长什么样子。
# load data (full dataset available at http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
dataset_path = "data/AG_news_samples.csv"
df = pd.read_csv(dataset_path)
# print dataframe
n_examples = 5
df.head(n_examples)
title | description | label_int | label |
---|---|---|---|
World Briefings | BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime M… | 1 | World |
Nvidia Puts a Firewall on a Motherboard (PC Wo… | PC World - Upcoming chip set will include buil… | 4 | Sci/Tech |
Olympic joy in Greek, Chinese press | Newspapers in Greece reflect a mixture of exhi… | 2 | Sports |
U2 Can iPod with Pictures | SAN JOSE, Calif. – Apple Computer (Quote, Cha… | 4 | Sci/Tech |
The Dream Factory | Any product, any shape, any size – manufactur… | 4 | Sci/Tech |
让我们看一下那些相同的例子,但没有被省略号截断。
# print the title, description, and label of each example
for idx, row in df.head(n_examples).iterrows():
print("")
print(f"Title: {row['title']}")
print(f"Description: {row['description']}")
print(f"Label: {row['label']}")
Title: World Briefings
Description: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the quot;alarming quot; growth of greenhouse gases.
Label: World
Title: Nvidia Puts a Firewall on a Motherboard (PC World)
Description: PC World - Upcoming chip set will include built-in security features for your PC.
Label: Sci/Tech
Title: Olympic joy in Greek, Chinese press
Description: Newspapers in Greece reflect a mixture of exhilaration that the Athens Olympics proved successful, and relief that they passed off without any major setback.
Label: Sports
Title: U2 Can iPod with Pictures
Description: SAN JOSE, Calif. -- Apple Computer (Quote, Chart) unveiled a batch of new iPods, iTunes software and promos designed to keep it atop the heap of digital music players.
Label: Sci/Tech
Title: The Dream Factory
Description: Any product, any shape, any size -- manufactured on your desktop! The future is the fabricator. By Bruce Sterling from Wired magazine.
Label: Sci/Tech
3.构建缓存以保存嵌入
在获取这些文章的嵌入之前,让我们设置一个缓存来保存我们生成的嵌入。 通常,保存嵌入是个好主意,以便以后可以重新使用它们。 如果您不保存它们,则每次重新计算它们时都会再次付费。
缓存是一个字典,将 (text, model) 的元组映射到嵌入,这是一个浮点数列表。 缓存保存为 Python pickle 文件。
# establish a cache of embeddings to avoid recomputing
# cache is a dict of tuples (text, model) -> embedding, saved as a pickle file
# set path to embedding cache
embedding_cache_path = "data/recommendations_embeddings_cache.pkl"
# load the cache if it exists, and save a copy to disk
try:
embedding_cache = pd.read_pickle(embedding_cache_path)
except FileNotFoundError:
embedding_cache = {}
with open(embedding_cache_path, "wb") as embedding_cache_file:
pickle.dump(embedding_cache, embedding_cache_file)
# define a function to retrieve embeddings from the cache if present, and otherwise request via the API
def embedding_from_string(
string: str,
model: str = EMBEDDING_MODEL,
embedding_cache=embedding_cache
) -> list:
"""Return embedding of given string, using a cache to avoid recomputing."""
if (string, model) not in embedding_cache.keys():
embedding_cache[(string, model)] = get_embedding(string, model)
with open(embedding_cache_path, "wb") as embedding_cache_file:
pickle.dump(embedding_cache, embedding_cache_file)
return embedding_cache[(string, model)]
让我们通过嵌入来检查它是否有效。
# as an example, take the first description from the dataset
example_string = df["description"].values[0]
print(f"\nExample string: {example_string}")
# print the first 10 dimensions of the embedding
example_embedding = embedding_from_string(example_string)
print(f"\nExample embedding: {example_embedding[:10]}...")
Example string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the quot;alarming quot; growth of greenhouse gases.
Example embedding: [-0.01071077398955822, -0.022362446412444115, -0.00883542187511921, -0.0254171434789896, 0.031423427164554596, 0.010723662562668324, -0.016717055812478065, 0.004195375367999077, -0.008074969984591007, -0.02142154797911644]...
4.基于嵌入推荐相似文章
要查找类似的文章,让我们遵循一个三步计划:
- 获取所有文章描述的相似度嵌入
- 计算来源标题与所有其他文章之间的距离
- 打印出最接近源标题的其他文章
def print_recommendations_from_strings(
strings: list[str],
index_of_source_string: int,
k_nearest_neighbors: int = 1,
model=EMBEDDING_MODEL,
) -> list[int]:
"""Print out the k nearest neighbors of a given string."""
# get embeddings for all strings
embeddings = [embedding_from_string(string, model=model) for string in strings]
# get the embedding of the source string
query_embedding = embeddings[index_of_source_string]
# get distances between the source embedding and other embeddings (function from embeddings_utils.py)
distances = distances_from_embeddings(query_embedding, embeddings, distance_metric="cosine")
# get indices of nearest neighbors (function from embeddings_utils.py)
indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)
# print out source string
query_string = strings[index_of_source_string]
print(f"Source string: {query_string}")
# print out its k nearest neighbors
k_counter = 0
for i in indices_of_nearest_neighbors:
# skip any strings that are identical matches to the starting string
if query_string == strings[i]:
continue
# stop after printing out k articles
if k_counter >= k_nearest_neighbors:
break
k_counter += 1
# print out the similar strings and their distances
print(
f"""
--- Recommendation #{k_counter} (nearest neighbor {k_counter} of {k_nearest_neighbors}) ---
String: {strings[i]}
Distance: {distances[i]:0.3f}"""
)
return indices_of_nearest_neighbors
5.示例建议
让我们寻找与第一篇相似的文章,那篇文章是关于托尼·布莱尔的。
article_descriptions = df["description"].tolist()
tony_blair_articles = print_recommendations_from_strings(
strings=article_descriptions, # let's base similarity off of the article description
index_of_source_string=0, # let's look at articles similar to the first one about Tony Blair
k_nearest_neighbors=5, # let's look at the 5 most similar articles
)
Source string: BRITAIN: BLAIR WARNS OF CLIMATE THREAT Prime Minister Tony Blair urged the international community to consider global warming a dire threat and agree on a plan of action to curb the quot;alarming quot; growth of greenhouse gases.
--- Recommendation #1 (nearest neighbor 1 of 5) ---
String: THE re-election of British Prime Minister Tony Blair would be seen as an endorsement of the military action in Iraq, Prime Minister John Howard said today.
Distance: 0.153
--- Recommendation #2 (nearest neighbor 2 of 5) ---
String: LONDON, England -- A US scientist is reported to have observed a surprising jump in the amount of carbon dioxide, the main greenhouse gas.
Distance: 0.160
--- Recommendation #3 (nearest neighbor 3 of 5) ---
String: The anguish of hostage Kenneth Bigley in Iraq hangs over Prime Minister Tony Blair today as he faces the twin test of a local election and a debate by his Labour Party about the divisive war.
Distance: 0.160
--- Recommendation #4 (nearest neighbor 4 of 5) ---
String: Israel is prepared to back a Middle East conference convened by Tony Blair early next year despite having expressed fears that the British plans were over-ambitious and designed
Distance: 0.171
--- Recommendation #5 (nearest neighbor 5 of 5) ---
String: AFP - A battle group of British troops rolled out of southern Iraq on a US-requested mission to deadlier areas near Baghdad, in a major political gamble for British Prime Minister Tony Blair.
Distance: 0.173
让我们看看我们的推荐器在第二篇关于 NVIDIA 的新芯片组的安全性更高的示例文章中是如何做的。
chipset_security_articles = print_recommendations_from_strings(
strings=article_descriptions, # let's base similarity off of the article description
index_of_source_string=1, # let's look at articles similar to the second one about a more secure chipset
k_nearest_neighbors=5, # let's look at the 5 most similar articles
)
Source string: PC World - Upcoming chip set will include built-in security features for your PC.
--- Recommendation #1 (nearest neighbor 1 of 5) ---
String: PC World - Updated antivirus software for businesses adds intrusion prevention features.
Distance: 0.112
--- Recommendation #2 (nearest neighbor 2 of 5) ---
String: PC World - The one-time World Class Product of the Year PDA gets a much-needed upgrade.
Distance: 0.145
--- Recommendation #3 (nearest neighbor 3 of 5) ---
String: PC World - Send your video throughout your house--wirelessly--with new gateways and media adapters.
Distance: 0.153
--- Recommendation #4 (nearest neighbor 4 of 5) ---
String: PC World - Symantec, McAfee hope raising virus-definition fees will move users to\ suites.
Distance: 0.157
--- Recommendation #5 (nearest neighbor 5 of 5) ---
String: Gateway computers will be more widely available at Office Depot, in the PC maker #39;s latest move to broaden distribution at retail stores since acquiring rival eMachines this year.
Distance: 0.168
从打印的距离,您可以看到 #1 推荐比所有其他推荐更接近(0.11 对 0.14+)。 #1 推荐看起来与起始文章非常相似 - 这是 PC World 的另一篇关于提高计算机安全性的文章。 不错!
附录:在更复杂的推荐系统中使用嵌入
构建推荐系统的一种更复杂的方法是训练一个机器学习模型,该模型接收数十或数百个信号,例如项目受欢迎程度或用户点击数据。 即使在这个系统中,嵌入对于推荐系统来说也是一个非常有用的信号,特别是对于那些还没有用户数据的“冷启动”项目(例如,一个全新的产品添加到目录中还没有任何点击)。
附录:使用嵌入可视化相似文章
为了了解我们最近的邻居推荐器在做什么,让我们可视化文章嵌入。 虽然我们无法绘制每个嵌入向量的 2048 维,但我们可以使用 t-SNE 或 PCA 等技术将嵌入压缩为 2 或 3 维,我们可以绘制图表。
在可视化最近邻之前,让我们使用 t-SNE 可视化所有文章描述。 请注意,t-SNE 不是确定性的,这意味着结果可能因运行而异。
# get embeddings for all article descriptions
embeddings = [embedding_from_string(string) for string in article_descriptions]
# compress the 2048-dimensional embeddings into 2 dimensions using t-SNE
tsne_components = tsne_components_from_embeddings(embeddings)
# get the article labels for coloring the chart
labels = df["label"].tolist()
chart_from_components(
components=tsne_components,
labels=labels,
strings=article_descriptions,
width=600,
height=500,
title="t-SNE components of article descriptions",
)
正如您在上图中所见,即使是高度压缩的嵌入也能很好地按类别对文章描述进行聚类。 值得强调的是:这种聚类是在不知道标签本身的情况下完成的!
此外,如果您仔细观察最令人震惊的异常值,它们通常是由于错误标记而不是嵌入不良造成的。 例如,绿色体育集群中的大部分蓝色世界点似乎都是体育故事。
接下来,让我们根据它们是源文章、最近的邻居还是其他来重新着色这些点。
# create labels for the recommended articles
def nearest_neighbor_labels(
list_of_indices: list[int],
k_nearest_neighbors: int = 5
) -> list[str]:
"""Return a list of labels to color the k nearest neighbors."""
labels = ["Other" for _ in list_of_indices]
source_index = list_of_indices[0]
labels[source_index] = "Source"
for i in range(k_nearest_neighbors):
nearest_neighbor_index = list_of_indices[i + 1]
labels[nearest_neighbor_index] = f"Nearest neighbor (top {k_nearest_neighbors})"
return labels
tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5)
chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5
)
# a 2D chart of nearest neighbors of the Tony Blair article
chart_from_components(
components=tsne_components,
labels=tony_blair_labels,
strings=article_descriptions,
width=600,
height=500,
title="Nearest neighbors of the Tony Blair article",
category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)
查看上面的二维图表,我们可以看到有关托尼·布莱尔的文章在世界新闻集群中有些靠得很近。 有趣的是,虽然 5 个最近的邻居(红色)在高维空间中最近,但它们并不是这个压缩二维空间中的最近点。 将嵌入压缩到 2 维会丢弃它们的大部分信息,并且 2D 空间中的最近邻似乎不如完整嵌入空间中的那些相关。
# a 2D chart of nearest neighbors of the chipset security article
chart_from_components(
components=tsne_components,
labels=chipset_security_labels,
strings=article_descriptions,
width=600,
height=500,
title="Nearest neighbors of the chipset security article",
category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)
对于芯片组安全性示例,完整嵌入空间中的 4 个最近邻在这个压缩的 2D 可视化中仍然是最近邻。 第五个显示为更远,尽管在完整嵌入空间中更近。
如果需要,您还可以使用函数 chart_from_components_3D 制作嵌入的交互式 3D 图。 (这样做将需要使用 n_components=3 重新计算 t-SNE 组件。)