使用 Qdrant 和 FiftyOne 进行最近邻嵌入搜索

神经网络嵌入是输入数据的低维表示,可用于各种应用。嵌入具有一些有趣的功能,因为它们可以捕获数据点的语义。这对于图像和视频等非结构化数据特别有用,因此您不仅可以编码像素相似性,还可以编码一些更复杂的关系。

对这些嵌入执行搜索会产生许多用例,例如分类、构建​​推荐系统,甚至是异常检测。对嵌入执行最近邻搜索以完成这些任务的主要好处之一是无需为每个新问题创建自定义网络;您通常可以使用预先训练的模型。
无需任何进一步的微调就可以使用由一些公开可用的模型生成的嵌入。

虽然有很多涉及嵌入的强大用例,但在对嵌入执行搜索的工作流程中存在一些挑战。具体来说,在大型数据集上执行最近邻搜索,然后有效地对搜索结果采取行动,例如,执行自动标记数据等工作流程,既是技术挑战,也是工具挑战。为此,Qdrant 和 FiftyOne 可以帮助简化这些工作流程

Qdrant是一个开源向量数据库,旨在对密集神经嵌入执行近似最近邻搜索 (ANN),这对于任何预期可扩展到大量数据的生产就绪系统都是必需的。

FiftyOne是一个开源数据集管理和模型评估工具,可让您有效地管理和可视化数据集、生成嵌入并改进模型结果。

在本文中,我们将 MNIST 数据集加载到 FiftyOne 中,并基于 ANN 进行分类。数据点将通过从我们的训练数据集中选择 K 个最近点中最常见的地面实况标签进行分类。换句话说,对于每个测试示例,我们将使用选定的距离函数选择其K个最近邻,然后通过投票选择最佳标签。向量空间中的所有搜索都将使用 Qdrant 完成以加快速度。然后,我们将在 FiftyOne 中评估此分类的结果。

安装

如果您想开始使用 Qdrant 的语义搜索,您需要运行它的一个实例,因为此工具以客户端-服务器方式工作。最简单的方法是使用官方的 Docker 镜像并使用一个命令启动 Qdrant:

docker run -p “6333:6333” -p “6334:6334” -d qdrant/qdrant

运行命令后,我们将运行 Qdrant 服务器,HTTP API 暴露在 6333 端口,gRPC 接口暴露在 6334。

我们还需要安装一些 Python 包。我们将使用 FiftyOne 来可视化数据、它们的真实标签以及我们的嵌入相似性模型预测的标签。嵌入将由 MobileNet v2 创建,可在 torchvision 中使用。当然,我们还需要以某种方式与 Qdrant 服务器进行通信,并且由于我们将使用 Python,qdrant_client因此这是一种首选方式。

pip install fiftyone
pip install torchvision
pip install qdrant_client

加工管道

  • 加载数据集
  • 生成嵌入
  • 将嵌入加载到 Qdrant
  • 最近邻分类
  • 五十一的评估

加载数据集

我们需要采取几个步骤才能使事情顺利进行。首先,我们需要加载MNIST数据集中提取训练示例,因为我们将在搜索操作中使用它们。为了让一切变得更快,我们不会使用所有示例,而是仅使用 2500 个样本。我们可以使用FiftyOne数据集Zoo在一行代码中加载我们想要的 MNIST 子集。

import fiftyone as fo
import fiftyone.zoo as foz

# Load the data
dataset = foz.load_zoo_dataset("mnist", max_samples=2500)

# Get all training samples
train_view = dataset.match_tags(tags=["train"])

让我们从查看FiftyOne app中的数据集开始。

# Visualize the dataset in FiftyOne
session = fo.launch_app(train_view)

生成嵌入

下一步是在数据集中的样本上生成嵌入。这始终可以在 FiftyOne 之外使用您的自定义模型完成。但是,FiftyOne 还在 FiftyOne 模型动物园中提供了各种模型,可以直接使用这些模型来生成嵌入。

在此示例中,我们使用在ImageNet上训练的 MobileNetv2 来计算每个图像的嵌入。

# Compute embeddings
model = foz.load_zoo_model("mobilenet-v2-imagenet-torch")

train_embeddings = train_view.compute_embeddings(model)

将嵌入加载到 Qdrant

Qdrant 不仅允许存储向量,还允许存储一些相应的属性——每个数据点都有一个相关的向量,并且可以选择附加一个 JSON 有效负载。我们想用它来传递真实标签,以确保我们以后可以做出预测。

ground_truth_labels = train_view.values("ground_truth.label")
train_payload = [
    {"ground_truth": gt} for gt in ground_truth_labels
]

创建嵌入后,我们可以开始与 Qdrant 服务器通信。的实例QdrantClient很有帮助,因为它包含了所有必需的方法。让我们连接并创建一个名为向量大小的点集合,“mnist.”该向量大小取决于模型输出,因此如果我们想改天尝试不同的模型,我们将需要导入不同的模型,但其余部分保持不变。最终,在确保集合存在之后,我们可以发送所有向量及其包含真实标签的有效负载。

import qdrant_client as qc
from qdrant_client.http.models import Distance

# Load the train embeddings into Qdrant
def create_and_upload_collection(
    embeddings, payload, collection_name="mnist"
):
    client = qc.QdrantClient(host="localhost")
    client.recreate_collection(
        collection_name=collection_name,
        vector_size=embeddings.shape[1],
        distance=Distance.COSINE,
    )
    client.upload_collection(
        collection_name=collection_name,
        vectors=embeddings,
        payload=payload,
    )
    return client
    
client = create_and_upload_collection(train_embeddings, train_payload)

最近邻分类

现在对数据集执行推理。我们可以为我们的测试数据集创建嵌入,但忽略基本事实并尝试使用 ANN 找到它,然后比较两者是否匹配。让我们一步一步,从创建嵌入开始

# Assign the labels to test embeddings by selecting
# the most common label among the neighbours of each sample
test_view = dataset.match_tags(tags=["test"])
test_embeddings = test_view.compute_embeddings(model)

是时候来点魔法了。让我们遍历测试数据集的样本和相应的嵌入,并使用搜索操作从训练集中找到 15 个最接近的嵌入。我们还需要选择有效载荷,因为它们包含在特定点附近找到最常见标签所需的地面实况标签。Python 的Counter类将有助于避免任何样板代码。最常见的标签将作为一个存储“ann_prediction”在 FiftyOne 中的每个测试样本上。

这包含在下面的函数中,它以嵌入向量作为输入,使用 Qdrant 搜索功能找到测试嵌入的最近邻居,生成类预测,并返回可以存储在FiftyOne数据集中的FiftyOne分类对象.

import collections
from tqdm import tqdm

def generate_fiftyone_classification(
    embedding, collection_name="mnist"
):
    search_results = client.search(
        collection_name=collection_name,
        query_vector=embedding,
        with_payload=True,
        top=15,
    )
    # Count the occurrences of each class and select the most common label
    # with the confidence estimated as the number of occurrences of 
    # the most common label divided by a total number of results.
    counter = collections.Counter(
        [point.payload["ground_truth"] for point in search_results]
    )
    predicted_class, occurences_num = counter.most_common(1)[0]
    confidence = occurences_num / sum(counter.values())
    prediction = fo.Classification(
        label=predicted_class, confidence=confidence
    )
    return prediction
    
predictions = []

# Call Qdrant to find the closest data points
for embedding in tqdm(test_embeddings):
    prediction = generate_fiftyone_classification(embedding)
    predictions.append(prediction)
    
test_view.set_values("ann_prediction", predictions)
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

wouderw

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值