faiss与torch的Topk

Topk检索

暴力搜索中,faiss.IndexFlatIP和torch.topk+torch.matmul都是常用的方法。然而,实际使用时,无论是运行效率还是计算精度,两者是较大存在差异的。
为了兼顾更多使用场景,faiss使用cpu版本(为了兼容HNSW等方法),torch使用gpu版本,以此对比cpu版faiss创建索引与gpu版torch计算出topk的的性能与效果差异。

使用示例

faiss.IndexFlatIP

import faiss


def search_in_faiss_cpu(corpus_vectors, query_vectors, topk):
    start_time = datetime.datetime.now()
    faiss_indexer.add(corpus_vectors.detach().cpu().numpy())
    dis, index = faiss_indexer.search(query_vectors.detach().cpu().numpy(), k=topk)
    index = index.tolist()
    dis = dis.tolist()
    faiss_indexer.reset()
    cost_time = datetime.datetime.now() - start_time
    return dis, index, str(cost_time)

torch.topk+torch.matmul

import torch


def search_in_torch_topk(corpus_vectors, query_vectors, topk):
    start_time = datetime.datetime.now()
    calc = torch.matmul(query_vectors, torch.transpose(corpus_vectors, 0, 1))
    dis, index = torch.topk(calc, k=topk, dim=-1)
    index = index.detach().cpu().tolist()
    dis = dis.detach().cpu().tolist()
    cost_time = datetime.datetime.now() - start_time
    return dis, index, str(cost_time)

差异测试

corpus_size = 10000
query_size = 10000
dim = 768  # same with the pretrained model output embedding dim
topk = 10
device = 'cuda:0'
faiss_indexer = faiss.IndexFlatIP(dim)

for epoch in range(10):
    corpus_vectors = torch.randn(corpus_size, dim, device=device)
    query_vectors = torch.randn(query_size, dim, device=device)

    faiss_dis, faiss_result, faiss_t = search_in_faiss_cpu(corpus_vectors, query_vectors, topk)
    torch_dis, torch_result, torch_t = search_in_torch_topk(corpus_vectors, query_vectors, topk)
    # print(faiss_result, torch_result)
    assert len(faiss_result) == len(torch_result)
    for i in range(len(faiss_result)):
        for j in range(len(faiss_result[i])):
            if faiss_result[i][j] != torch_result[i][j]:
                print(f"i={i}, j={j}(index/distance): faiss: {faiss_result[i][j]}/{faiss_dis[i][j]}, torch: {torch_result[i][j]}/{torch_dis[i][j]}")
    print(f"Epoch {epoch}: Faiss:", faiss_t, ", Torch:", torch_t)
Epoch 0: Faiss: 0:00:03.098337 , Torch: 0:00:00.530352
i=2477, j=4(index/distance): faiss: 1303/90.24628448486328, torch: 5946/90.24634552001953
i=2477, j=5(index/distance): faiss: 5946/90.24626159667969, torch: 1303/90.24632263183594
i=3909, j=6(index/distance): faiss: 9539/83.25800323486328, torch: 112/83.2580795288086
i=3909, j=7(index/distance): faiss: 112/83.25799560546875, torch: 9539/83.25801086425781
i=6506, j=0(index/distance): faiss: 346/105.05818176269531, torch: 2258/105.05816650390625
i=6506, j=1(index/distance): faiss: 2258/105.05810546875, torch: 346/105.05811309814453
Epoch 1: Faiss: 0:00:03.195171 , Torch: 0:00:00.078587
i=707, j=7(index/distance): faiss: 8988/88.88170623779297, torch: 9912/88.88166809082031
...

由此可见,faiss(cpu)的索引创建所耗费的资源远不如直接在gpu上使用torch计算内积与topk(显存充足的情况下),且在实际计算中,faiss的距离计算存在精度差异问题,与torch直接内积得出结果并不完全相同。

结论

在显存充足且候选集合数量级仅为万时,直接使用torch暴力求解更高效,效率差异巨大,不必使用faiss创建索引。

待验证

  1. 当面对大corpus时,faiss提前创建好索引,再针对query批量搜索会不会效率更高(torch需要每次把query与corpus两两计算,消耗确实巨大,需验证gpu本身带来的性能提升能否抵消两两计算带来的消耗)。
  2. 设置场景:输入向量流,使用torch实现算法中的topk问题,从而最终获取topk,可能效率也会特别高(毕竟faiss的方法,现在看来存在计算精度问题,我相信torch自身计算一定是准的)(挖坑)。
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值