import csv
from datetime import datetime
import numpy as np
import pandas as pd
import os
from pathlib import Path
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from concurrent.futures import ThreadPoolExecutor, wait
import multiprocessing
# Query data from QE platform
def concurrency_encode_text(text_arr):
    """Encode ``text_arr`` in parallel and return the concatenated result.

    The input is cut into at most ``thread_num`` contiguous, near-equal
    slices. Each slice is copied and handed to ``encode_text`` on the
    module-level ``thread_pool_executor``; the per-slice results are then
    concatenated along axis 0 in submission order (i.e. input order).

    Relies on the module-level names ``thread_num``,
    ``thread_pool_executor`` and ``encode_text``.

    Args:
        text_arr: Sliceable container exposing ``.size`` and whose slices
            expose ``.copy()`` (presumably a 1-D NumPy array or pandas
            Series of strings -- TODO confirm against the caller).

    Returns:
        ``np.concatenate`` of the arrays returned by ``encode_text`` for
        each submitted slice.

    Raises:
        ZeroDivisionError: If ``thread_num`` is 0.
        Exception: Whatever a worker's ``encode_text`` call raised; it is
            re-raised when that worker's result is collected.
    """
    record_count = text_arr.size
    # Integer ceiling division: avoids float rounding on very large inputs
    # and does not over-allocate a slice when the count divides evenly.
    task_count = -(-record_count // thread_num)
    future_tasks = []
    for index in range(thread_num):
        first_index = task_count * index
        # Later slices that would be empty are skipped, but the first slice
        # is always submitted so an empty input still goes through
        # encode_text and np.concatenate receives at least one array.
        if future_tasks and first_index >= record_count:
            break
        second_index = min(first_index + task_count, record_count)
        task_text_arr = text_arr[first_index:second_index].copy()
        future_tasks.append(thread_pool_executor.submit(encode_text, task_text_arr))
        # NOTE: log text (including the "starr" spelling) is kept unchanged
        # so existing log consumers are not affected.
        print("sub task starr, batch {}, start {}, end {}".format(index, first_index, second_index))
    # Block until every worker has finished before collecting results.
    wait(future_tasks)
    return np.concatenate([task.result() for task in future_tasks])
def encode_text(text_arr):
print("encode text array size {} records".format(text_arr.size))
text_embedding = np.empty((0,512), dtype="float32")
batch_size = 1000
batch_count = int(text_arr.size / batch_size) + 1
for batch_index in range(batch_count):
first_index = batch_size * batch_index
second_index = batch_size * (batch_index + 1)
if second_index > text_arr.size:
second_index = text_arr.size
bat
基于Transformer对文本进行向量搜索
最新推荐文章于 2023-04-17 15:47:35 发布