import torch
def batch_cosine_similarity(sentence_embedding_a, sentence_embedding_b):
    """
    :param sentence_embedding_a: [a_number, hidden_dim]
    :param sentence_embedding_b: [b_number, hidden_dim]
    :return: pairwise cosine similarity matrix, [a_number, b_number]
    """
    a_number = sentence_embedding_a.size(0)
    b_number = sentence_embedding_b.size(0)
    hidden_dim = sentence_embedding_a.size(1)  # read from the input instead of hardcoding 768
    # a_embedding: [a_number, hidden_dim] --> [a_number, b_number, hidden_dim] --> [a_number * b_number, hidden_dim]
    # b_embedding: [b_number, hidden_dim] --> [a_number, b_number, hidden_dim] --> [a_number * b_number, hidden_dim]
    a_embedding = sentence_embedding_a.unsqueeze(1).repeat(1, b_number, 1).view(-1, hidden_dim)
    b_embedding = sentence_embedding_b.unsqueeze(0).repeat(a_number, 1, 1).view(-1, hidden_dim)
    # similarity: [a_number * b_number]
    similarity = torch.cosine_similarity(a_embedding, b_embedding, dim=1)
    similarity = similarity.view(a_number, b_number)
    return similarity
PyTorch's built-in cosine_similarity function appears to compute similarity only between tensors of matching shape. To avoid the slowness of an explicit for loop over every pair, the vectors are tiled with repeat so that all pairs line up row by row, which turns the many-to-many cosine similarity computation into a single batched call.
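To make the many-to-many behavior concrete, here is a minimal usage sketch (the shapes and random data are made up for illustration) that checks the batched result against the naive nested for loop it replaces:

a = torch.randn(3, 768)   # 3 sentence embeddings
b = torch.randn(5, 768)   # 5 sentence embeddings
sim = batch_cosine_similarity(a, b)
print(sim.shape)          # torch.Size([3, 5])

# sanity check against the explicit (slow) pairwise loop
for i in range(a.size(0)):
    for j in range(b.size(0)):
        expected = torch.cosine_similarity(a[i], b[j], dim=0)
        assert torch.allclose(sim[i, j], expected, atol=1e-5)

As an aside, recent PyTorch releases allow cosine_similarity to broadcast its inputs, so (assuming a reasonably new PyTorch version) the same [a_number, b_number] matrix can be obtained without materializing the repeated copies:

sim = torch.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0), dim=-1)  # [a_number, b_number]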